Squashed 'third_party/allwpilib/' changes from 66b57f032..e473a00f9
e473a00f9 [wpiutil] Base64: Add unsigned span/vector variants (#3702)
52f2d580e [wpiutil] raw_uv_ostream: Add reset() (#3701)
d7b1e3576 [wpiutil] WebSocket: move std::function (#3700)
93799fbe9 [examples] Fix description of TrapezoidProfileSubsystem (#3699)
b84644740 [wpimath] Document pose estimator states, inputs, and outputs (#3698)
2dc35c139 [wpimath] Fix classpaths for JNI class loads (#3697)
2cb171f6f [docs] Set Doxygen extract_all to true and fix Doxygen failures (#3695)
a939cd9c8 [wpimath] Print uncontrollable/unobservable models in LQR and KF (#3694)
d5270d113 [wpimath] Clean up C++ StateSpaceUtil tests (#3692)
b20903960 [wpimath] Remove redundant discretization tests from StateSpaceUtilTest (#3689)
c0cb545b4 [wpilibc] Add deprecated Doxygen attribute to SpeedController (#3691)
35c9f66a7 [wpilib] Rename PneumaticsHub to PneumaticHub (#3686)
796d03d10 [wpiutil] Remove unused LLVM header (#3688)
8723caf78 [wpilibj] Make Java TrapezoidProfile.Constraints an immutable class (#3687)
187f50a34 [wpimath] Catch incorrect parameters to state-space models earlier (#3680)
8d04606c4 Replace instances of frc-characterization with SysId (NFC) (#3681)
b82d4f6e5 [hal, cscore, ntcore] Use WPI common handle type base
87e34967e [wpiutil] Add synchronization primitives
e32499c54 [wpiutil] Add ParallelTcpConnector (#3655)
aa0b49228 [wpilib] Remove redundant "quick turn" docs for curvature drive (NFC) (#3674)
57301a7f9 [hal] REVPH: Start closed-loop compressor control on init (#3673)
d1842ea8f [wpilib] Improve interrupt docs (NFC) (#3679)
558151061 [wpiutil] Add DsClient (#3654)
181723e57 Replace `.to<double>()` and `.template to<double>()` with `.value()` (#3667)
6bc1db44b [commands] Add pointer overload of AddRequirements (artf6003) (#3669)
737b57ed5 [wpimath] Update to drake v0.35.0 (#3665)
4d287d1ae [build] Upgrade WPIJREArtifact to JRE 2022-11.0.12u5 (#3666)
f26eb5ada [hal] Fix another typo (get -> gets) (NFC) (#3663)
94ed275ba [hal] Fix misspelling (numer -> number) (NFC) (#3662)
ac2f44da3 [wpiutil] uv: use move for std::function (#3653)
75fa1fbfb [wpiutil] json::serializer: Optimize construction (#3647)
5e689faea [wpiutil] Import MessagePack implementation (mpack) (#3650)
649a50b40 [wpiutil] Add LEB128 byte-by-byte reader (#3652)
e94397a97 [wpiutil] Move json_serializer.h to public headers (#3646)
4ec58724d [wpiutil] uv::Tcp: Clarify SetNoDelay documentation (#3649)
8cb294aa4 [wpiutil] WebSocket: Make Shutdown() public (#3651)
2b3a9a52b [wpiutil] json: Fix map iterator key() for std::string_view (#3645)
138cbb94b [wpiutil] uv::Async: Add direct call for no-parameter specialization (#3648)
e56d6dea8 [ci] Update testbench pool image to ubuntu-latest (#3643)
43f30e44e [build] Enable comments in doxygen source files (#3644)
9e6db17ef [build] Enable doxygen preprocessor expansion of WPI_DEPRECATED (#3642)
0e631ad2f Add WPILib version to issue template (#3641)
6229d8d2f [build] Docs: set case_sense_names to false (#3392)
4647d09b5 [docs] Fix Doxygen warnings, add CI docs lint job (#3639)
4ad3a5402 [hal] Fix PWM allocation channel (#3637)
05e5feac4 [docs] Fix brownout docs (NFC) (#3638)
67df469c5 [examples] Remove old command-based templates and examples (#3263)
689e9ccfb [hal, wpilib] Add brownout voltage configuration (#3632)
9cd4bc407 [docs] Add useLocal to avoid using installer artifacts (#3634)
61996c2bb [cscore] Fix Java direct callback notifications (#3631)
6d3dd99eb [build] Update to newest native-utils (#3633)
f0b484892 [wpiutil] Fix StringMap iterator equality check (#3629)
8352cbb7a Update development build instructions for 2022 (#3616)
6da08b71d [examples] Fix Intermediate Vision Java Example description (#3628)
5d99059bf [wpiutil] Remove optional.h (#3627)
fa41b106a [glass, wpiutil] Add missing format args (#3626)
4e3fd7d42 [build] Enable Zc:__cplusplus for Windows (#3625)
791d8354d [build] Suppress deprecation/removal warnings for old commands (#3618)
10f19e6fc [hal, wpilib] Add REV PneumaticsHub (#3600)
4c61a1305 [ntcore] Revert to per-element copy for toNative() (#3621)
7b3f62244 [wpiutil] SendableRegistry: Print exception stacktrace (#3620)
d347928e4 [hal] Use better error for when console out is enabled while attempting to use onboard serial port (#3622)
cc31079a1 [hal] Use setcap instead of setuid for setting thread priorities (#3613)
4676648b7 [wpimath] Upgrade to Drake v0.34.0 (#3607)
c7594c911 [build] Allow building wpilibc in cmake without cscore and opencv (#3605)
173cb7359 [wpilib] Add TimesliceRobot (#3502)
af295879f [hal] Set error status for I2C port out of range (#3603)
95dd20a15 [build] Enable spotbugs (#3601)
b65fce86b [wpilib] Remove Timer lock in wpilibj and update docs (#3602)
3b8d3bbcb Remove unused and add missing deprecated.h includes (#3599)
f9e976467 [examples] Rename DriveTrain classes to Drivetrain (#3594)
118a27be2 [wpilib] Add Timer tests (#3595)
59c89428e [wpilib] Deprecate Timer::HasPeriodPassed() (#3598)
202ca5e78 Force C++17 in .clang-format (#3597)
d6f185d8e Rename tests for consistency (#3592)
54ca474db [ci] Enable asan and tsan in CI for tests that pass (#3591)
1ca383b23 Add Debouncer (#3590)
179fde3a7 [build] Update to 2022 native utils and gradle 7 (#3588)
50198ffcf [examples] Add Mechanism2d visualization to Elevator Sim (#3587)
a446c2559 [examples] Synchronize C++ and Java Mechanism2d examples (#3589)
a7fb83103 [ci] clang-tidy: Generate compilation commands DB with Gradle (#3585)
4f5e0c9f8 [examples] Update ArmSimulation example to use Mechanism2d (#3572)
8164b91dc [CI] Print CMake test output on failure (#3583)
4d5fca27e [wpilib] Improve Mechanism2D documentation (NFC) (#3584)
fe59e4b9f Make C++ test names more consistent (#3586)
5c8868549 [wpilibc] Fix C++ MechanismRoot2d to use same NT entries as Java/Glass (#3582)
9359431ba [wpimath] Clean up Eigen usage
72716f51c [wpimath] Upgrade to Eigen 3.4
382deef75 [wpimath] Explicitly export wpimath symbols
161e21173 [ntcore] Match standard handle layout, only allow 16 instances (#3577)
263a24811 [wpimath] Use jinja for codegen (#3574)
725251d29 [wpilib] Increase tolerances of DifferentialDriveSimTest (#3581)
4dff87301 [wpimath] Make LinearFilter::Factorial() constexpr (#3579)
60ede67ab [hal, wpilib] Switch PCM to be a single object that is allowed to be duplicated (#3475)
906bfc846 [build] Add CMake build support for sanitizers (#3576)
0d4f08ad9 [hal] Simplify string copy of joystick name (#3575)
a52bf87b7 [wpiutil] Add Java function package (#3570)
40c7645d6 [wpiutil] UidVector: Return old object from erase() (#3571)
5b886a23f [wpiutil] jni_util: Add size, operator[] to JArrayRef (#3569)
65797caa7 [sim] Fix halsim_ds_socket stringop overflow warning from GCC 10 (#3568)
66abb3988 [hal] Update runtime enum to allow selecting roborio 2 (#3565)
95a12e0ee [hal] UidSetter: Don't revert euid if its already current value (#3566)
27951442b [wpimath] Use external Eigen headers only (#3564)
c42e053ae [docs] Update to doxygen 1.9.2 (#3562)
e7048c8c8 [docs] Disable doxygen linking for common words that are also classes (#3563)
d8e0b6c97 [wpilibj] Fix java async interrupts (#3559)
5e6c34c61 Update to 2022 roborio image (#3537)
828f073eb [wpiutil] Fix uv::Buffer memory leaks caught by asan (#3555)
2dd5701ac [cscore] Fix mutex use-after-free in cscore test (#3557)
531439198 [ntcore] Fix NetworkTables memory leaks caught by asan (#3556)
3d9a4d585 [wpilibc] Fix AnalogTriggerOutput memory leak reported by asan (#3554)
54eda5928 [wpiutil] Ignore ubsan vptr upcast warning in SendableHelper moves (#3553)
5a4f75c9f [wpilib] Replace Speed controller comments with motor controller (NFC) (#3551)
7810f665f [wpiutil] Fix bug in uleb128 (#3540)
697e2dd33 [wpilib] Fix errant jaguar reference in comments (NFC) (#3550)
936c64ff5 [docs] Enable -linksource for javadocs (#3549)
1ea654954 [build] Upgrade CMake build to googletest 1.11.0 (#3548)
32d9949e4 [wpimath] Move controller tests to wpimath (#3541)
01ba56a8a [hal] Replace strncpy with memcpy (#3539)
e109c4251 [build] Rename makeSim flag to forceGazebo to better describe what it does (#3535)
e4c709164 [docs] Use a doxygen theme and add logo (#3533)
960b6e589 [wpimath] Fix Javadoc warning (#3532)
82eef8d5e [hal] Remove over current fault HAL functions from REV PDH (#3526)
aa3848b2c [wpimath] Move RobotDriveBase::ApplyDeadband() to MathUtil (#3529)
3b5d0d141 [wpimath] Add LinearFilter::BackwardFiniteDifference() (#3528)
c8fc715fe [wpimath] Upgrade drake files to v0.33.0 (#3531)
e5fe3a8e1 [build] Treat javadoc warnings as errors in CI and fix warnings (#3530)
e0c6cd3dc [wpimath] Add an operator for composing two Transform2ds (#3527)
2edd510ab [sim] Add sim wrappers for sensors that use SimDevice (#3517)
2b3e2ebc1 [hal] Fix HAL Notifier thread priority setting (#3522)
ab4cb5932 [gitignore] Update gitignore to ignore bazel / clion files (#3524)
57c8615af [build] Generate spotless patch on failure (#3523)
b90317321 Replace std::cout and std::cerr with fmt::print() (#3519)
10cc8b89c [hal] [wpilib] Add initial support for the REV PDH (#3503)
5d9ae3cdb [hal] Set HAL Notifier thread as RT by default (#3482)
192d251ee [wpilibcIntegrationTests] Properly disable DMA integration tests (#3514)
031962608 [wpilib] Add PS4Controller, remove Hand from GenericHID/XboxController (#3345)
25f6f478a [wpilib] Rename DriverStation::IsOperatorControl() to IsTeleop() (#3505)
e80f09f84 [wpilibj] Add unit tests (#3501)
c159f91f0 [wpilib] Only read DS control word once in IterativeRobotBase (#3504)
eb790a74d Add rio development docs documenting myRobot deploy tasks (#3508)
e47451f5a [wpimath] Replace auto with Eigen types (#3511)
252b8c83b Remove Java formatting from build task in CI (#3507)
09666ff29 Shorten Gazebo CI build (#3506)
baf2e501d Update myRobot to use 2021 java (#3509)
5ac60f0a2 [wpilib] Remove IterativeRobotBase mode init prints (#3500)
fb2ee8ec3 [wpilib] Add TimedRobot functions for running code on mode exit (#3499)
94e0db796 [wpilibc] Add more unit tests (#3494)
b25324695 [wpilibj] Add units to parameter names (NFC) (#3497)
1ac73a247 [hal] Rename PowerDistributionPanel to PowerDistribution (#3466)
2014115bc [examples] frisbeebot: Fix typo and reflow comments (NFC) (#3498)
4a944dc39 [examples] Consistently use 0 for controller port (#3496)
3838cc4ec Use unicode characters in docs equations (#3487)
85748f2e6 [examples] Add C++ TankDrive example (#3493)
d7b8aa56d [wpilibj] Rename DriverStation In[Mode] functions to follow style guide (#3488)
16e096cf8 [build] Fix CMake Windows CI (#3490)
50af74c38 [wpimath] Clean up NumericalIntegration and add Discretization tests (#3489)
bfc209b12 Automate fmt update (#3486)
e7f9331e4 [build] Update to Doxygen 1.9.1 (#3008)
ab8e8aa2a [wpimath] Update drake with upstream (#3484)
1ef826d1d [wpimath] Fix IOException path in WPIMath JNI (#3485)
52bddaa97 [wpimath] Disable iostream support for units and enable fmtlib (#3481)
e4dc3908b [wpiutil] Upgrade to fmtlib 8.0.1 (#3483)
1daadb812 [wpimath] Implement Dormand-Prince integration method (#3476)
9c2723391 [cscore] Add [[nodiscard]] to GrabFrame functions (#3479)
7a8796414 [wpilib] Add Notifier integration tests (#3480)
f8f13c536 [wpilibcExamples] Prefix decimal numbers with 0 (#3478)
1adb69c0f [ntcore] Use "NetworkTables" instead of "Network Tables" in NT specs (#3477)
5f5830b96 Upload wpiformat diff if one exists (#3474)
9fb4f35bb [wpimath] Add tests for DARE overload with Q, R, and N matrices (#3472)
c002e6f92 Run wpiformat (#3473)
c154e5262 [wpilib] Make solenoids exclusive use, PCM act like old sendable compressor (#3464)
6ddef1cca [hal] JNI setDIO: use a boolean and not a short (#3469)
9d68d9582 Remove extra newlines after open curly braces (NFC) (#3471)
a4233e1a1 [wpimath] Add script for updating Drake (#3470)
39373c6d2 Update README.md for new GCC version requirement (#3467)
d29acc90a [wpigui] Add option to reset UI on exit (#3463)
a371235b0 [ntcore] Fix dangling pointer in logger (#3465)
53b4891a5 [wpilibcintegrationtests] Fix deprecated Preferences usage (#3461)
646ded912 [wpimath] Remove incorrect discretization in pose estimators (#3460)
ea0b8f48e Fix some deprecation warnings due to fmtlib upgrade (#3459)
2067d7e30 [wpilibjexamples] Add wpimathjni, wpiutiljni to library path (#3455)
866571ab4 [wpiutil] Upgrade to fmtlib 8.0.0 (#3457)
4e1fa0308 [build] Skip PDB copy on windows build servers (#3458)
b45572167 [build] Change CI back to 18.04 docker images (#3456)
57a160f1b [wpilibc] Fix LiveWindow deprecation warning in RobotBase skeleton template (#3454)
29ae8640d [HLT] Implement duty cycle cross connect tests (#3453)
ee6377e54 [HLT] Add relay and analog cross connects (#3452)
b0f1ae7ea [build] CMake: Build the HAL even if WITH_CSCORE=OFF (#3449)
7aae2b72d Replace std::to_string() with fmt::format() (#3451)
73fcbbd74 [HLT] Add relay digital cross connect tests (#3450)
e7bedde83 [HLT] Add PWM tests that use DMA as the back end (#3447)
7253edb1e [wpilibc] Timer: Fix deprecated warning (#3446)
efa28125c [wpilibc] Add message to RobotBase on how to read stacktrace (#3444)
9832fcfe1 [hal] Fix DIO direction getter (#3445)
49c71f9f2 [wpilibj] Clarify robot quit message (#3364)
791770cf6 [wpimath] Move controller from wpilibj to wpimath (#3439)
9ce9188ff [wpimath] Add ReportWarning to MathShared (#3441)
362066a9b [wpilib] Deprecate getInstance() in favor of static functions (#3440)
26ff9371d Initial commit of cross connect integration test project (#3434)
4a36f86c8 [hal] Add support for DMA to Java (#3158)
85144e47f [commands] Unbreak build (#3438)
b417d961e Split Sendable into NT and non-NT portions (#3432)
ef4ea84cb [commands] Change grouping decorator impl to flatten nested group structures (#3335)
b422665a3 [examples] Invert right side of drive subsystems (#3437)
186dadf14 [hal] Error if attempting to set DIO output on an input port (#3436)
04e64db94 Remove redundant C++ lambda parentheses (NFC) (#3433)
f60994ad2 [wpiutil] Rename Java package to edu.wpi.first.util (#3431)
cfa1ca96f [wpilibc] Make ShuffleboardValue non-copyable (#3430)
4d9ff7643 Fix documentation warnings generated by JavaDoc (NFC) (#3428)
9e1b7e046 [build] Fix clang-tidy and clang-format (#3429)
a77c6ff3a [build] Upgrade clang-format and clang-tidy (NFC) (#3422)
099fde97d [wpilib] Improve PDP comments (NFC) (#3427)
f8fc2463e [wpilibc, wpiutil] Clean up includes (NFC) (#3426)
e246b7884 [wpimath] Clean up member initialization in feedforward classes (#3425)
c1e128bd5 Disable frivolous PMD warnings and enable PMD in ntcore (#3419)
8284075ee Run "Lint and Format" CI job on push as well as pull request (#3412)
f7db09a12 [wpimath] Move C++ filters into filter folder to match Java (#3417)
f9c3d54bd [wpimath] Reset error covariance in pose estimator ResetPosition() (#3418)
0773f4033 [hal] Ensure HAL status variables are initialized to zero (#3421)
d068fb321 [build] Upgrade CI to use 20.04 docker images (#3420)
8d054c940 [wpiutil] Remove STLExtras.h
80f1d7921 [wpiutil] Split function_ref to a separate header
64f541325 Use wpi::span instead of wpi::ArrayRef across all libraries (#3414)
2abbbd9e7 [build] clang-tidy: Remove bugprone-exception-escape (#3415)
a5c471af7 [wpimath] Add LQR template specialization for 2x2 system
edd2f0232 [wpimath] Add DARE solver for Q, R, and N with LQR ctor overloads
b2c3b2dd8 Use std::string_view and fmtlib across all libraries (#3402)
4f1cecb8e [wpiutil] Remove Path.h (#3413)
b336eac34 [build] Publish halsim_ws_core to Maven
2a09f6fa4 [build] Also build sim modules as static libraries
0e702eb79 [hal] Add a unified PCM object (#3331)
dea841103 [wpimath] Add fmtlib formatter overloads for Eigen::Matrix and units (#3409)
82856cf81 [wpiutil] Improve wpi::circular_buffer iterators (#3410)
8aecda03e [wpilib] Fix a documentation typo (#3408)
5c817082a [wpilib] Remove InterruptableSensorBase and replace with interrupt classes (#2410)
15c521a7f [wpimath] Fix drivetrain system identification (#3406)
989de4a1b [build] Force all linker warnings to be fatal for rio builds (#3407)
d9eeb45b0 [wpilibc] Add units to Ultrasonic class API (#3403)
fe570e000 [wpiutil] Replace llvm filesystem with C++17 filesystem (#3401)
01dc0249d [wpimath] Move SlewRateLimiter from wpilib to wpimath (#3399)
93523d572 [wpilibc] Clean up integration tests (#3400)
4f7a4464d [wpiutil] Rewrite StringExtras for std::string_view (#3394)
e09293a15 [wpilibc] Transition C++ classes to units::second_t (#3396)
827b17a52 [build] Create run tasks for Glass and OutlineViewer (#3397)
a61037996 [wpiutil] Avoid MSVC warning on span include (#3393)
4e2c3051b [wpilibc] Use std::string_view instead of Twine (#3380)
50915cb7e [wpilibc] MotorSafety::GetDescription(): Return std::string (#3390)
f4e2d26d5 [wpilibc] Move NullDeleter from frc/Base.h to wpi/NullDeleter.h (#3387)
cb0051ae6 [wpilibc] SimDeviceSim: use fmtlib (#3389)
a238cec12 [wpiutil] Deprecate wpi::math constants in favor of wpi::numbers (#3383)
393bf23c0 [ntcore, cscore, wpiutil] Standardize template impl files on .inc extension (NFC) (#3124)
e7d9ba135 [sim] Disable flaky web server integration tests (#3388)
0a0003c11 [wpilibjExamples] Fix name of Java swerve drive pose estimator example (#3382)
7e1b27554 [wpilibc] Use default copies and moves when possible (#3381)
fb2a56e2d [wpilibc] Remove START_ROBOT_CLASS macro (#3384)
84218bfb4 [wpilibc] Remove frc namespace shim (#3385)
dd7824340 [wpilibc] Remove C++ compiler version static asserts (#3386)
484cf9c0e [wpimath] Suppress the -Wmaybe-uninitialized warning in Eigen (#3378)
a04d1b4f9 [wpilibc] DriverStation: Remove ReportError and ReportWarning
831c10bdf [wpilibc] Errors: Use fmtlib
87603e400 [wpiutil] Import fmtlib (#3375)
442621672 [wpiutil] Add ArrayRef/std::span/wpi::span implicit conversions
bc15b953b [wpiutil] Add std::span implementation
6d20b1204 [wpiutil] StringRef, Twine, raw_ostream: Add std::string_view support (#3373)
2385c2a43 [wpilibc] Remove Utility.h (#3376)
87384ea68 [wpilib] Fix PIDController continuous range error calculations (#3170)
04dae799a [wpimath] Add SimpleMotorFeedforward::Calculate(velocity, nextVelocity) overload (#3183)
0768c3903 [wpilib] DifferentialDrive: Remove right side inversion (#3340)
8dd8d4d2d [wpimath] Fix redundant nested math package introduced by #3316 (#3368)
49b06beed [examples] Add Field2d to RamseteController example (#3371)
4c562a445 [wpimath] Fix typo in comment of update_eigen.py (#3369)
fdbbf1188 [wpimath] Add script for updating Eigen
f1e64b349 [wpimath] Move Eigen unsupported folder into eigeninclude
224f3a05c [sim] Fix build error when building with GCC 11.1 (#3361)
ff56d6861 [wpilibj] Fix SpeedController deprecated warnings (#3360)
1873fbefb [examples] Fix Swerve and Mecanum examples (#3359)
80b479e50 [examples] Fix SwerveBot example to use unique encoder ports (#3358)
1f7c9adee [wpilibjExamples] Fix pose estimator examples (#3356)
9ebc3b058 [outlineviewer] Change default size to 600x400 (#3353)
e21b443a4 [build] Gradle: Make C++ examples runnable (#3348)
da590120c [wpilibj] Add MotorController.setVoltage default (#3347)
561d53885 [build] Update opencv to 4.5.2, imgui/implot to latest (#3344)
44ad67ca8 [wpilibj] Preferences: Add missing Deprecated annotation (#3343)
3fe8fc75a [wpilibc] Revert "Return reference from GetInstance" (#3342)
3cc2da332 Merge branch '2022'
a3cd90dd7 [wpimath] Fix classpath used by generate_numbers.py (#3339)
d6cfdd3ba [wpilib] Preferences: Deprecate Put* in favor of Set* (#3337)
ba08baabb [wpimath] Update Drake DARE solver to v0.29.0 (#3336)
497b712f6 [wpilib] Make IterativeRobotBase::m_period private with getter
f00dfed7a [wpilib] Remove IterativeRobot base class
3c0846168 [hal] Use last error reporting instead of PARAMETER_OUT_OF_RANGE (#3328)
5ef2b4fdc [wpilibj] Fix @deprecated warning for SerialPort constructor (#3329)
23d2326d1 [hal] Report previous allocation location for indexed resource duplicates (#3322)
e338f9f19 [build] Fix wpilibc runCpp task (#3327)
c8ff626fe [wpimath] Move Java classes to edu.wpi.first.math (#3316)
4e424d51f [wpilibj] DifferentialDrivetrainSim: Rename constants to match the style guide (#3312)
6b50323b0 [cscore] Use Lock2DSize if possible for Windows USB cameras (#3326)
65c148536 [wpilibc] Fix "control reaches end of non-void function" warning (#3324)
f99f62bee [wpiutil] uv Handle: Use malloc/free instead of new/delete (#3325)
365f5449c [wpimath] Fix MecanumDriveKinematics (#3266)
ff52f207c [glass, wpilib] Rewrite Mechanism2d (#3281)
ee0eed143 [wpimath] Add DCMotor factory function for Romi motors (#3319)
512738072 [hal] Add HAL_GetLastError to enable better error messages from HAL calls (#3320)
ced654880 [glass, outlineviewer] Update Mac icons to macOS 11 style (#3313)
936d3b9f8 [templates] Add Java template for educational robot (#3309)
6e31230ad [examples] Fix odometry update in SwerveControllerCommand example (#3310)
05ebe9318 Merge branch 'main' into 2022
aaf24e255 [wpilib] Fix initial heading behavior in HolonomicDriveController (#3290)
8d961dfd2 [wpilibc] Remove ErrorBase (#3306)
659b37ef9 [wpiutil] StackTrace: Include offset on Linux (#3305)
0abf6c904 [wpilib] Move motor controllers to motorcontrol package (#3302)
4630191fa [wpiutil] circular_buffer: Use value initialization instead of passing zero (#3303)
b7b178f49 [wpilib] Remove Potentiometer interface
687066af3 [wpilib] Remove GyroBase
6b168ab0c [wpilib] Remove PIDController, PIDOutput, PIDSource
948625de9 [wpimath] Document conversion from filter cutoff frequency to time constant (#3299)
3848eb8b1 [wpilibc] Fix flywhel -> flywheel typo in FlywheelSim (#3298)
3abe0b9d4 [cscore] Move java package to edu.wpi.first.cscore (#3294)
d7fabe81f [wpilib] Remove RobotDrive (#3295)
1dc81669c [wpilib] Remove GearTooth (#3293)
01d0e1260 [wpilib] Revert move of RomiGyro into main wpilibc/j (#3296)
397e569aa [ntcore] Remove "using wpi" from nt namespace
79267f9e6 [ntcore] Remove NetworkTable -> nt::NetworkTable shim
48ebe5736 [ntcore] Remove deprecated Java interfaces and classes
c2064c78b [ntcore] Remove deprecated ITable interfaces
36608a283 [ntcore] Remove deprecated C++ APIs
a1c87e1e1 [glass] LogView: Add "copy to clipboard" button (#3274)
fa7240a50 [wpimath] Fix typo in quintic spline basis matrix
ffb4d38e2 [wpimath] Add derivation for spline basis matrices
f57c188f2 [wpilib] Add AnalogEncoder(int) ctor (#3273)
8471c4fb2 [wpilib] FieldObject2d: Add setTrajectory() method (#3277)
c97acd18e [glass] Field2d enhancements (#3234)
ffb590bfc [wpilib] Fix Compressor sendable properties (#3269)
6137f98eb [hal] Rename SimValueCallback2 to SimValueCallback (#3212)
a6f653969 [hal] Move registerSimPeriodic functions to HAL package (#3211)
10c038d9b [glass] Plot: Fix window creation after removal (#3264)
2d2eaa3ef [wpigui] Ensure window will be initially visible (#3256)
4d28b1f0c [wpimath] Use JNI for trajectory serialization (#3257)
3de800a60 [wpimath] TrajectoryUtil.h: Comment formatting (NFC) (#3262)
eff592377 [glass] Plot: Don't overwrite series ID (#3260)
a79faace1 [wpilibc] Return reference from GetInstance (#3247)
9550777b9 [wpilib] PWMSpeedController: Use PWM by composition (#3248)
c8521a3c3 [glass] Plot: Set reasonable default window size (#3261)
d71eb2cf3 [glass] Plot: Show full source name as tooltip and in popup (#3255)
160fb740f [hal] Use std::lround() instead of adding 0.5 and truncating (#3012)
48e9f3951 [wpilibj] Remove wpilibj package CameraServer (#3213)
8afa596fd [wpilib] Remove deprecated Sendable functions and SendableBase (#3210)
d3e45c297 [wpimath] Make C++ geometry classes immutable (#3249)
2c98939c1 [glass] StringChooser: Don't call SameLine() at end
a18a7409f [glass] NTStringChooser: Clear value of deleted entries
2f19cf452 [glass] NetworkTablesHelper: listen to delete events
da96707dc Merge branch 'main' into 2022
c3a8bdc24 [build] Fix clang-tidy action (#3246)
21624ef27 Add ImGui OutlineViewer (#3220)
1032c9b91 [wpiutil] Unbreak wpi::Format on Windows (#3242)
2e07902d7 [glass] NTField2D: Fix name lookup (#3233)
6e23e1840 [wpilibc] Remove WPILib.h (#3235)
3e22e4506 [wpilib] Make KoP drivetrain simulation weight 60 lbs (#3228)
79d1bd6c8 [glass] NetworkTablesSetting: Allow disable of server option (#3227)
fe341a16f [examples] Use more logical elevator setpoints in GearsBot (#3198)
62abf46b3 [glass] NetworkTablesSettings: Don't block GUI (#3226)
a95a5e0d9 [glass] Move NetworkTablesSettings to libglassnt (#3224)
d6f6ceaba [build] Run Spotless formatter (NFC) (#3221)
0922f8af5 [commands] CommandScheduler.requiring(): Note return can be null (NFC) (#2934)
6812302ff [examples] Make DriveDistanceOffboard example work in sim (#3199)
f3f86b8e7 [wpimath] Add pose estimator overload for vision + std dev measurement (#3200)
1a2680b9e [wpilibj] Change CommandBase.withName() to return CommandBase (#3209)
435bbb6a8 [command] RamseteCommand: Output 0 if interrupted (#3216)
3cf44e0a5 [hal] Add function for changing HAL Notifier thread priority (#3218)
40b367513 [wpimath] Units.java: Add kg-lb conversions (#3203)
9f563d584 [glass] NT: Fix return value in StringToDoubleArray (#3208)
af4adf537 [glass] Auto-size plots to fit window (#3193)
2560146da [sim] GUI: Add option to show prefix in Other Devices (#3186)
eae3a6397 gitignore: Ignore .cache directory (#3196)
959611420 [wpilib] Require non-zero positive value for PIDController.period (#3175)
9522f2e8c [wpimath] Add methods to concatenate trajectories (#3139)
e42a0b6cf [wpimath] Rotation2d comment formatting (NFC) (#3162)
d1c7032de [wpimath] Fix order of setting gyro offset in pose estimators (#3176)
d241bc81a [sim] Add DoubleSolenoidSim and SolenoidSim classes (#3177)
cb7f39afa [wpilibc] Add RobotController::GetBatteryVoltage() to C++ (#3179)
99b5ad9eb [wpilibj] Fix warnings that are not unused variables or deprecation (#3161)
c14b23775 [build] Fixup doxygen generated include dirs to match what users would need (#3154)
d447c7dc3 [sim] Add SimDeviceSim ctor overloads (#3134)
247420c9c [build] Remove jcenter repo (#3157)
04b112e00 [build] Include debug info in plugin published artifacts (#3149)
be0ce9900 [examples] Use PWMSparkMax instead of PWMVictorSPX (#3156)
69e8d0b65 [wpilib] Move RomiGyro into main wpilibc/j (#3143)
94e685e1b [wpimath] Add custom residual support to EKF (#3148)
5899f3dd2 [sim] GUI: Make keyboard settings loading more robust (#3167)
f82aa1d56 [wpilib] Fix HolonomicDriveController atReference() behavior (#3163)
fe5c2cf4b [wpimath] Remove ControllerUtil.java (#3169)
43d40c6e9 [wpiutil] Suppress unchecked cast in CombinedRuntimeLoader (#3155)
3d44d8f79 [wpimath] Fix argument order in UKF docs (NFC) (#3147)
ba6fe8ff2 [cscore] Add USB camera change event (#3123)
533725888 [build] Tweak OpenCV cmake search paths to work better on Linux (#3144)
29bf9d6ef [cscore] Add polled support to listener
483beb636 [ntcore] Move CallbackManager to wpiutil
fdaec7759 [examples] Instantiate m_ramseteController in example (#3142)
8494a5761 Rename default branch to main (#3140)
45590eea2 [wpigui] Hardcode window scale to 1 on macOS (#3135)
834a64920 [build] Publish libglass and libglassnt to Maven (#3127)
2c2ccb361 [wpimath] Fix Rotation2d equality operator (#3128)
fb5c8c39a [wpigui] clang-tidy: readability-braces-around-statements
f7d39193a [wpigui] Fix copyright in pfd and wpigui_metal.mm
aec796b21 [ntcore] Fix conditional jump on uninitialized value (#3125)
fb13bb239 [sim] GUI: Add right click popup for keyboard joystick settings (#3119)
c517ec677 [build] Update thirdparty-imgui to 1.79-2 (#3118)
e8cbf2a71 [wpimath] Fix typo in SwerveDrivePoseEstimator doc (NFC) (#3112)
e9c86df46 [wpimath] Add tests for swerve module optimization (#3100)
6ba8c289c [examples] Remove negative of ArcadeDrive(fwd, ..) in the C++ Getting Started Example (#3102)
3f1672e89 [hal] Add SimDevice createInt() and createLong() (#3110)
15be5cbf1 [examples] Fix segfault in GearsBot C++ example (#3111)
4cf0e5e6d Add quick links to API documentation in README (#3082)
6b1898f12 Fix RT priority docs (NFC) (#3098)
b3426e9c0 [wpimath] Fix missing whitespace in pose estimator doc (#3097)
38c1a1f3e [examples] Fix feildRelative -> fieldRelative typo in XControllerCommand examples (#3104)
4488e25f1 [glass] Shorten SmartDashboard window names (#3096)
cfdb3058e [wpilibj] Update SimDeviceSimTest (#3095)
64adff5fe [examples] Fix typo in ArcadeDrive constructor parameter name (#3092)
6efc58e3d [build] Fix issues with build on windows, deprecations, and native utils (#3090)
f393989a5 [wpimath, wpiutil] Add wpi::array for compile time size checking (#3087)
d6ed20c1e [build] Set macOS deployment target to 10.14 (#3088)
7c524014c [hal] Add [[nodiscard]] to HAL_WaitForNotifierAlarm() (#3085)
406d055f0 [wpilib] Fixup wouldHitLowerLimit in elevator and arm simulation classes. (#3076)
04a90b5dd [examples] Don't continually set setpoint in PotentiometerPID Examples (#3084)
8c5bfa013 [sim] GUI: Add max value setting for keyboard joysticks (#3083)
bc80c5535 [hal] Add SimValue reset() function (#3064)
9c3b51ca0 [wpilib] Document simulation APIs (#3079)
26584ff14 [wpimath] Add model description to LinearSystemId Javadocs (#3080)
42c3d5286 [examples] Sync Java and C++ trajectories in sim example (#3081)
64e72f710 [wpilibc] Add missing function RoboRioSim::ResetData (#3073)
e95503798 [wpimath] Add optimize() to SwerveModuleState (#3065)
fb99910c2 [hal] Add SimInt and SimLong wrappers for int/long SimValue (#3066)
e620bd4d3 [doc] Add machine-readable websocket specification (#3059)
a44e761d9 [glass] Add support for plot Y axis labels
ea1974d57 [wpigui] Update imgui and implot to latest
85a0bd43c [wpimath] Add RKF45 integration (#3047)
278e0f126 [glass] Use .controllable to set widgets' read-only state (#3035)
d8652cfd4 [wpimath] Make Java DCMotor API consistent with C++ and fix motor calcs (#3046)
377b7065a [build] Add toggleOffOn to Java spotless (#3053)
1e9c79c58 [sim] Use plant output to retrieve simulated position (#3043)
78147aa34 [sim] GUI: Fix Keyboard Joystick (#3052)
cd4a2265b [ntcore] Fix NetworkTableEntry::GetRaw() (#3051)
767ac1de1 [build] Use deploy key for doc publish (#3048)
d762215d1 [build] Add publish documentation script (#3040)
1fd09593c [examples] Add missing TestInit method to GettingStarted Example (#3039)
e45a0f6ce [examples] Add RomiGyro to the Romi Reference example (#3037)
94f852572 Update imaging link and fix typo (#3038)
d73cf64e5 [examples] Update RomiReference to match motor directions (#3036)
f945462ba Bump copyright year to 2021 (#3033)
b05946175 [wpimath] Catch Drake JNI exceptions and rethrow them (#3032)
62f0f8190 [wpimath] Deduplicate angle modulus functions (#2998)
bf8c0da4b [glass] Add "About" popup with version number (#3031)
dfdd6b389 [build] Increase Gradle heap size in Gazebo build (#3028)
f5e0fc3e9 Finish clang-tidy cleanups (#3003)
d741101fe [sim] Revert accidental commit of WSProvider_PDP.h (#3027)
e1620799c [examples] Add C++ RomiReference example (#2969)
749c7adb1 [command] Fix use-after-free in CommandScheduler (#3024)
921a73391 [sim] Add WS providers for AddressableLED, PCM, and Solenoid (#3026)
26d0004fe [build] Split Actions into different yml files (#3025)
948af6d5b [wpilib] PWMSpeedController.get(): Apply Inversion (#3016)
670a187a3 [wpilibc] SuppliedValueWidget.h: Forward declare ShuffleboardContainer (#3021)
be9f72502 [ntcore] NetworkTableValue: Use std::forward instead of std::move (#3022)
daf3f4cb1 [cscore] cscore_raw_cv.h: Fix error in PutFrame() (#3019)
5acda4cc7 [wpimath] ElevatorFeedforward.h: Add time.h include
8452af606 [wpimath] units/radiation.h: Add mass.h include
630d44952 [hal] ErrorsInternal.h: Add stdint.h include
7372cf7d9 [cscore] Windows NetworkUtil.cpp: Add missing include
b7e46c558 Include .h from .inc/.inl files (NFC) (#3017)
bf8f8710e [examples] Update Romi template and example (#2996)
6ffe5b775 [glass] Ensure NetworkTableTree parent context menu has an id (#3015)
be0805b85 [build] Update to WPILibVersioningPlugin 4.1.0 (#3014)
65b2359b2 [build] Add spotless for other files (#3007)
8651aa73e [examples] Enable NT Flush in Field2d examples (#3013)
78b542737 [build] Add Gazebo build to Actions CI (#3004)
fccf86532 [sim] DriverStationGui: Fix two bugs (#3010)
185741760 [sim] WSProvider_Joystick: Fix off-by-1 in incoming buttons (#3011)
ee7114a58 [glass] Add drive class widgets (#2975)
00fa91d0d [glass] Use ImGui style for gyro widget colors (#3009)
b7a25bfc3 ThirdPartyNotices: Add portable file dialogs license (#3005)
a2e46b9a1 [glass] modernize-use-nullptr (NFC) (#3006)
a751fa22d [build] Apply spotless for java formatting (#1768)
e563a0b7d [wpimath] Make LinearSystemLoop move-constructible and move-assignable (#2967)
49085ca94 [glass] Add context menus to remove and add NetworkTables values (#2979)
560a850a2 [glass] Add NetworkTables Log window (#2997)
66782e231 [sim] Create Left/Right drivetrain current accessors (#3001)
b60eb1544 clang-tidy: bugprone-virtual-near-miss
cbe59fa3b clang-tidy: google-explicit-constructor
c97c6dc06 clang-tidy: google-readability-casting (NFC)
32fa97d68 clang-tidy: modernize-use-nullptr (NFC)
aee460326 clang-tidy: modernize-pass-by-value
29c7da5f1 clang-tidy: modernize-make-unique
6131f4e32 clang-tidy: modernize-concat-nested-namespaces (NFC)
67e03e625 clang-tidy: modernize-use-equals-default
b124f9101 clang-tidy: modernize-use-default-member-init
d11a3a638 clang-tidy: modernize-use-override (NFC)
4cc0706b0 clang-tidy: modernize-use-using (NFC)
885f5a978 [wpilibc] Speed up ScopedTracerTest (#2999)
60b596457 [wpilibj] Fix typos (NFC) (#3000)
6e1919414 [build] Bring naming checkstyle rules up to date with Google Style guide (#1781)
8c8ec5e63 [wpilibj] Suppress unchecked cast warnings (#2995)
b8413ddd5 [wpiutil] Add noexcept to timestamp static functions (#2994)
5d976b6e1 [glass] Load NetworkTableView settings on first draw (#2993)
2b4317452 Replace NOLINT(runtime/explicit) comments with NOLINT (NFC) (#2992)
1c3011ba4 [glass] Fix handling of "/" NetworkTables key (#2991)
574a42f3b [hal] Fix UnsafeManipulateDIO status check (#2987)
9005cd59e [wpilib] Clamp input voltage in sim classes (#2955)
dd494d4ab [glass] NetworkTablesModel::Update(): Avoid use-after-move (#2988)
7cca469a1 [wpimath] NormalizeAngle: Make inline, remove unnamed namespace (#2986)
2aed432b4 Add braces to C++ single-line loops and conditionals (NFC) (#2973)
0291a3ff5 [wpiutil] StringRef: Add noexcept to several constructors (#2984)
5d7315280 [wpimath] Update UnitsTest.cpp copyright (#2985)
254931b9a [wpimath] Remove LinearSystem from LinearSystemLoop (#2968)
aa89744c9 Update OtherVersions.md to include wpimath info (#2983)
1cda3f5ad [glass] Fix styleguide (#2976)
8f1f64ffb Remove year from file copyright message (NFC) (#2972)
2bc0a7795 [examples] Fix wpiformat warning about utility include (#2971)
4204da6ad [glass] Add application icon
7ac39b10f [wpigui] Add icon support
6b567e006 [wpimath] Add support for varying vision standard deviations in pose estimators (#2956)
df299d6ed [wpimath] Add UnscentedKalmanFilter::Correct() overload (#2966)
4e34f0523 [examples] Use ADXRS450_GyroSim class in simulation example (#2964)
9962f6fd7 [wpilib] Give Field2d a default Sendable name (#2953)
f9d492f4b [sim] GUI: Show "Other Devices" window by default (#2961)
a8bb2ef1c [sim] Fix ADXRS450_GyroSim and DutyCycleEncoderSim (#2963)
240c629cd [sim] Try to guess "Map Gamepad" setting (#2960)
952567dd3 [wpilibc] Add missing move constructors and assignment operators (#2959)
10b396b4c [sim] Various WebSockets fixes and enhancements (#2952)
699bbe21a [examples] Fix comments in Gearsbot to match implementation (NFC) (#2957)
27b67deca [glass] Add more widgets (#2947)
581b7ec55 [wpilib] Add option to flush NetworkTables every iterative loop
acfbb1a44 [ntcore] DispatcherBase::Flush: Use wpi::Now()
d85a6d8fe [ntcore] Reduce limit on flush and update rate to 5 ms
20fbb5c63 [sim] Fix stringop truncation warning from GCC 10 (#2945)
1051a06a7 [glass] Show NT timestamps in seconds (#2944)
98dfc2620 [glass] Fix plots (#2943)
1ba0a2ced [sim] GUI: Add keyboard virtual joystick support (#2940)
4afb13f98 [examples] Replace M_PI with wpi::math::pi (#2938)
b27d33675 [examples] Enhance Romi templates (#2931)
00b9ae77f [sim] Change default WS port number to 3300 (#2932)
65219f309 [examples] Update Field2d position in periodic() (#2928)
f78d1d434 [sim] Process WS Encoder reset internally (#2927)
941edca59 [hal] Add Java SimDeviceDataJNI.getSimDeviceName (#2924)
a699435ed [wpilibj] Fix FlywheelSim argument order in constructor (#2922)
66d641718 [examples] Add tasks to run Java examples (#2920)
558e37c41 [examples] Add simple differential drive simulation example (#2918)
4f40d991e [glass] Switch name of Glass back to glass (#2919)
549af9900 [build] Update native-utils to 2021.0.6 (#2914)
b33693009 [glass] Change basename of glass to Glass (#2915)
c9a0edfb8 [glass] Package macOS application bundle
2c5668af4 [wpigui] Add platform-specific preferences save
751dea32a [wpilibc] Try to work around ABI break introduced in #2901 (#2917)
cd8f4bfb1 [build] Package up msvc runtime into maven artifact (#2913)
a6cfcc686 [wpilibc] Move SendableChooser Doxygen comments to header (NFC) (#2911)
b8c4f603d [wpimath] Upgrade to Eigen 3.3.9 (#2910)
0075e4b39 [wpilibj] Fix NPE in Field2d (#2909)
125af556c [simulation] Fix halsim_gui ntcore and wpiutil deps (#2908)
963ad5c25 [wpilib] Add noise to Differential Drive simulator (#2903)
387f56cb7 [examples] Add Romi reference Java example and templates (#2905)
b3deda38c [examples] Zero motors on disabledInit() in sim physics examples (#2906)
2a5ca7745 [glass] Add glass: an application for display of robot data
727940d84 [wpilib] Move Field2d to SmartDashboard
8cd42478e [wpilib] SendableBuilder: Make GetTable() visible
c11d34b26 [command] Use addCommands in command group templates (#2900)
339d7445b [sim] Add HAL hooks for simulationPeriodic (#2881)
d16f05f2c [wpilib] Fix SmartDashboard update order (#2896)
5427b32a4 [wpiutil] unique_function: Restrict implicit conversion (#2899)
f73701239 [ntcore] Add missing SetDefault initializer_list functions (#2898)
f5a6fc070 [sim] Add initialized flag for all solenoids on a PCM (#2897)
bdf5ba91a [wpilibj] Fix typo in ElevatorSim (#2895)
bc8f33877 [wpilib] Add pose estimators (#2867)
3413bfc06 [wpilib] PIDController: Recompute the error in AtSetpoint() (#2822)
2056f0ce0 [wpilib] Fix bugs in Hatchbot examples (#2893)
5eb8cfd69 [wpilibc] Fix MatchDataSender (#2892)
e6a425448 [build] Delete test folders after tests execute (#2891)
d478ad00d [imgui] Allow usage of imgui_stdlib (#2889)
53eda861d [build] Add unit-testing infrastructure to examples (#2863)
cc1d86ba6 [sim] Add title to simulator GUI window (#2888)
f0528f00e [build] CMake: Use project-specific binary and source dirs (#2886)
5cd2ad124 [wpilibc] Add Color::operator!= (#2887)
6c00e7a90 [build] CI CMake: build with GUI enabled (#2884)
53170bbb5 Update roboRIO toolchain installation instructions (#2883)
467258e05 [sim] GUI: Add option to not zero disconnected joysticks (#2876)
129be23c9 Clarify JDK installation instructions in readme (#2882)
8e9290e86 [build] Add separate CMake setting for wpimath (#2885)
7cf5bebf8 [wpilibj] Cache NT writes from DriverStation (#2780)
f7f9087fb [command] Fix timing issue in RamseteCommand (#2871)
256e7904f [wpilibj] SimDeviceSim: Fix sim value changed callback (#2880)
c8ea1b6c3 [wpilib] Add function to adjust LQR controller gain for pure time delay (#2878)
2816b06c0 [sim] HAL_GetControlWord: Fully zero output (#2873)
4c695ea08 Add toolchain installation instructions to README (#2875)
a14d51806 [wpimath] DCMotor: fix doc typo (NFC) (#2868)
017097791 [build] CMake: build sim extensions as shared libs (#2866)
f61726b5a [build] Fix cmake-config files (#2865)
fc27fdac5 [wpilibc] Cache NT values from driver station (#2768)
47c59859e [sim] Make SimDevice callbacks synchronous (#2861)
6e76ab9c0 [build] Turn on WITH_GUI for Windows cmake CI
5f78b7670 [build] Set GLFW_INSTALL to OFF
5e0808c84 [wpigui] Fix Windows cmake build
508f05a47 [imgui] Fix typo in Windows CMake target sources
Change-Id: I1737b45965f31803a96676bedc7dc40e337aa321
git-subtree-dir: third_party/allwpilib
git-subtree-split: e473a00f9785f9949e5ced30901baeaf426d2fc9
Signed-off-by: Austin Schuh <austin.linux@gmail.com>
diff --git a/wpimath/.styleguide b/wpimath/.styleguide
index c82141b..b9044c9 100644
--- a/wpimath/.styleguide
+++ b/wpimath/.styleguide
@@ -8,6 +8,10 @@
\.cpp$
}
+modifiableFileExclude {
+ \.patch$
+}
+
generatedFileExclude {
src/main/native/cpp/drake/
src/main/native/eigeninclude/
@@ -31,6 +35,7 @@
}
includeOtherLibs {
+ ^fmt/
^wpi/
}
diff --git a/wpimath/CMakeLists.txt b/wpimath/CMakeLists.txt
index a129376..2a4090f 100644
--- a/wpimath/CMakeLists.txt
+++ b/wpimath/CMakeLists.txt
@@ -11,11 +11,11 @@
find_package(Java REQUIRED)
find_package(JNI REQUIRED)
include(UseJava)
- set(CMAKE_JAVA_COMPILE_FLAGS "-Xlint:unchecked")
+ set(CMAKE_JAVA_COMPILE_FLAGS "-encoding" "UTF8" "-Xlint:unchecked")
- if(NOT EXISTS "${CMAKE_BINARY_DIR}/wpimath/thirdparty/ejml/ejml-simple-0.38.jar")
+ if(NOT EXISTS "${WPILIB_BINARY_DIR}/wpimath/thirdparty/ejml/ejml-simple-0.38.jar")
set(BASE_URL "https://search.maven.org/remotecontent?filepath=")
- set(JAR_ROOT "${CMAKE_BINARY_DIR}/wpimath/thirdparty/ejml")
+ set(JAR_ROOT "${WPILIB_BINARY_DIR}/wpimath/thirdparty/ejml")
message(STATUS "Downloading EJML jarfiles...")
@@ -37,15 +37,15 @@
message(STATUS "All files downloaded.")
endif()
- file(GLOB EJML_JARS "${CMAKE_BINARY_DIR}/wpimath/thirdparty/ejml/*.jar")
- file(GLOB JACKSON_JARS "${CMAKE_BINARY_DIR}/wpiutil/thirdparty/jackson/*.jar")
+ file(GLOB EJML_JARS "${WPILIB_BINARY_DIR}/wpimath/thirdparty/ejml/*.jar")
+ file(GLOB JACKSON_JARS "${WPILIB_BINARY_DIR}/wpiutil/thirdparty/jackson/*.jar")
set(CMAKE_JAVA_INCLUDE_PATH wpimath.jar ${EJML_JARS} ${JACKSON_JARS})
- execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/wpimath/generate_numbers.py ${CMAKE_BINARY_DIR}/wpimath RESULT_VARIABLE generateResult)
+ execute_process(COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/generate_numbers.py ${WPILIB_BINARY_DIR}/wpimath RESULT_VARIABLE generateResult)
if(NOT (generateResult EQUAL "0"))
# Try python
- execute_process(COMMAND python ${CMAKE_SOURCE_DIR}/wpimath/generate_numbers.py ${CMAKE_BINARY_DIR}/wpimath RESULT_VARIABLE generateResult)
+ execute_process(COMMAND python ${CMAKE_CURRENT_SOURCE_DIR}/generate_numbers.py ${WPILIB_BINARY_DIR}/wpimath RESULT_VARIABLE generateResult)
if(NOT (generateResult EQUAL "0"))
message(FATAL_ERROR "python and python3 generate_numbers.py failed")
endif()
@@ -53,7 +53,7 @@
set(CMAKE_JNI_TARGET true)
- file(GLOB_RECURSE JAVA_SOURCES src/main/java/*.java ${CMAKE_BINARY_DIR}/wpimath/generated/*.java)
+ file(GLOB_RECURSE JAVA_SOURCES src/main/java/*.java ${WPILIB_BINARY_DIR}/wpimath/generated/*.java)
if(${CMAKE_VERSION} VERSION_LESS "3.11.0")
set(CMAKE_JAVA_COMPILE_FLAGS "-h" "${CMAKE_CURRENT_BINARY_DIR}/jniheaders")
@@ -92,10 +92,13 @@
file(GLOB_RECURSE wpimath_native_src src/main/native/cpp/*.cpp)
list(REMOVE_ITEM wpimath_native_src ${wpimath_jni_src})
+set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS FALSE)
add_library(wpimath ${wpimath_native_src})
+set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS TRUE)
set_target_properties(wpimath PROPERTIES DEBUG_POSTFIX "d")
set_property(TARGET wpimath PROPERTY FOLDER "libraries")
+target_compile_definitions(wpimath PRIVATE WPILIB_EXPORTS)
target_compile_features(wpimath PUBLIC cxx_std_17)
if (MSVC)
@@ -106,7 +109,7 @@
if (NOT USE_VCPKG_EIGEN)
install(DIRECTORY src/main/native/eigeninclude/ DESTINATION "${include_dest}/wpimath")
- target_include_directories(wpimath PUBLIC
+ target_include_directories(wpimath SYSTEM PUBLIC
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/src/main/native/eigeninclude>
$<INSTALL_INTERFACE:${include_dest}/wpimath>)
else()
@@ -131,8 +134,8 @@
set (wpimath_config_dir share/wpimath)
endif()
-configure_file(wpimath-config.cmake.in ${CMAKE_BINARY_DIR}/wpimath-config.cmake )
-install(FILES ${CMAKE_BINARY_DIR}/wpimath-config.cmake DESTINATION ${wpimath_config_dir})
+configure_file(wpimath-config.cmake.in ${WPILIB_BINARY_DIR}/wpimath-config.cmake )
+install(FILES ${WPILIB_BINARY_DIR}/wpimath-config.cmake DESTINATION ${wpimath_config_dir})
install(EXPORT wpimath DESTINATION ${wpimath_config_dir})
if (WITH_TESTS)
diff --git a/wpimath/build.gradle b/wpimath/build.gradle
index c023a1a..79db4bc 100644
--- a/wpimath/build.gradle
+++ b/wpimath/build.gradle
@@ -1,3 +1,6 @@
+import com.hubspot.jinjava.Jinjava;
+import com.hubspot.jinjava.JinjavaConfig;
+
ext {
useJava = true
useCpp = true
@@ -5,24 +8,11 @@
groupId = 'edu.wpi.first.wpimath'
nativeName = 'wpimath'
- devMain = 'edu.wpi.first.wpiutil.math.DevMain'
+ devMain = 'edu.wpi.first.math.DevMain'
}
apply from: "${rootDir}/shared/jni/setupBuild.gradle"
-nativeUtils.exportsConfigs {
- wpimath {
- x86ExcludeSymbols = ['_CT??_R0?AV_System_error', '_CT??_R0?AVexception', '_CT??_R0?AVfailure',
- '_CT??_R0?AVruntime_error', '_CT??_R0?AVsystem_error', '_CTA5?AVfailure',
- '_TI5?AVfailure', '_CT??_R0?AVout_of_range', '_CTA3?AVout_of_range',
- '_TI3?AVout_of_range', '_CT??_R0?AVbad_cast']
- x64ExcludeSymbols = ['_CT??_R0?AV_System_error', '_CT??_R0?AVexception', '_CT??_R0?AVfailure',
- '_CT??_R0?AVruntime_error', '_CT??_R0?AVsystem_error', '_CTA5?AVfailure',
- '_TI5?AVfailure', '_CT??_R0?AVout_of_range', '_CTA3?AVout_of_range',
- '_TI3?AVout_of_range', '_CT??_R0?AVbad_cast']
- }
-}
-
cppHeadersZip {
from('src/main/native/eigeninclude') {
into '/'
@@ -48,11 +38,10 @@
api "com.fasterxml.jackson.core:jackson-databind:2.10.0"
}
-def wpilibNumberFileInput = file("src/generate/GenericNumber.java.in")
-def natFileInput = file("src/generate/Nat.java.in")
-def natGetterInput = file("src/generate/NatGetter.java.in")
-def wpilibNumberFileOutputDir = file("$buildDir/generated/java/edu/wpi/first/wpiutil/math/numbers")
-def wpilibNatFileOutput = file("$buildDir/generated/java/edu/wpi/first/wpiutil/math/Nat.java")
+def wpilibNumberFileInput = file("src/generate/GenericNumber.java.jinja")
+def natFileInput = file("src/generate/Nat.java.jinja")
+def wpilibNumberFileOutputDir = file("$buildDir/generated/java/edu/wpi/first/math/numbers")
+def wpilibNatFileOutput = file("$buildDir/generated/java/edu/wpi/first/math/Nat.java")
def maxNum = 20
task generateNumbers() {
@@ -68,10 +57,17 @@
}
wpilibNumberFileOutputDir.mkdirs()
+ def config = new JinjavaConfig()
+ def jinjava = new Jinjava(config)
+
+ def template = wpilibNumberFileInput.text
+
for(i in 0..maxNum) {
def outputFile = new File(wpilibNumberFileOutputDir, "N${i}.java")
- def read = wpilibNumberFileInput.text.replace('${num}', i.toString())
- outputFile.write(read)
+ def replacements = new HashMap<String,?>()
+ replacements.put("num", i)
+ def output = jinjava.render(template, replacements)
+ outputFile.write(output)
}
}
}
@@ -79,7 +75,7 @@
task generateNat() {
description = "Generates Nat.java"
group = "WPILib"
- inputs.files([natFileInput, natGetterInput])
+ inputs.file natFileInput
outputs.file wpilibNatFileOutput
dependsOn generateNumbers
@@ -88,19 +84,16 @@
wpilibNatFileOutput.delete()
}
- def template = natFileInput.text + "\n"
+ def config = new JinjavaConfig()
+ def jinjava = new Jinjava(config)
- def importsString = "";
+ def template = natFileInput.text
- for(i in 0..maxNum) {
- importsString += "import edu.wpi.first.wpiutil.math.numbers.N${i};\n"
- template += natGetterInput.text.replace('${num}', i.toString()) + "\n"
- }
- template += "}\n" // Close the class body
+ def replacements = new HashMap<String,?>()
+ replacements.put("nums", 0..maxNum)
- template = template.replace('{{REPLACEWITHIMPORTS}}', importsString)
-
- wpilibNatFileOutput.write(template)
+ def output = jinjava.render(template, replacements)
+ wpilibNatFileOutput.write(output)
}
}
diff --git a/wpimath/generate_numbers.py b/wpimath/generate_numbers.py
index 701f3e6..c52ddb4 100644
--- a/wpimath/generate_numbers.py
+++ b/wpimath/generate_numbers.py
@@ -1,5 +1,26 @@
+# Copyright (c) FIRST and other WPILib contributors.
+# Open Source Software; you can modify and/or share it under the terms of
+# the WPILib BSD license file in the root directory of this project.
+
import os
import sys
+from jinja2 import Environment, FileSystemLoader
+
+
+def output(outPath, outfn, contents):
+ if not os.path.exists(outPath):
+ os.makedirs(outPath)
+
+ outpathname = f"{outPath}/{outfn}"
+
+ if os.path.exists(outpathname):
+ with open(outpathname, "r") as f:
+ if f.read() == contents:
+ return
+
+ # File either doesn't exist or has different contents
+ with open(outpathname, "w") as f:
+ f.write(contents)
def main():
@@ -8,51 +29,21 @@
dirname, _ = os.path.split(os.path.abspath(__file__))
cmake_binary_dir = sys.argv[1]
- with open(f"{dirname}/src/generate/GenericNumber.java.in",
- "r") as templateFile:
- template = templateFile.read()
- rootPath = f"{cmake_binary_dir}/generated/main/java/edu/wpi/first/wpiutil/math/numbers"
+ env = Environment(loader=FileSystemLoader(f"{dirname}/src/generate"),
+ autoescape=False,
+ keep_trailing_newline=True)
- if not os.path.exists(rootPath):
- os.makedirs(rootPath)
+ template = env.get_template("GenericNumber.java.jinja")
+ rootPath = f"{cmake_binary_dir}/generated/main/java/edu/wpi/first/math/numbers"
- for i in range(MAX_NUM + 1):
- new_contents = template.replace("${num}", str(i))
+ for i in range(MAX_NUM + 1):
+ contents = template.render(num=i)
+ output(rootPath, f"N{i}.java", contents)
- if os.path.exists(f"{rootPath}/N{i}.java"):
- with open(f"{rootPath}/N{i}.java", "r") as f:
- if f.read() == new_contents:
- continue
-
- # File either doesn't exist or has different contents
- with open(f"{rootPath}/N{i}.java", "w") as f:
- f.write(new_contents)
-
- with open(f"{dirname}/src/generate/Nat.java.in", "r") as templateFile:
- template = templateFile.read()
- outputPath = f"{cmake_binary_dir}/generated/main/java/edu/wpi/first/wpiutil/math/Nat.java"
- with open(f"{dirname}/src/generate/NatGetter.java.in",
- "r") as getterFile:
- getter = getterFile.read()
-
- importsString = ""
-
- for i in range(MAX_NUM + 1):
- importsString += f"import edu.wpi.first.wpiutil.math.numbers.N{i};\n"
- template += getter.replace("${num}", str(i))
-
- template += "}\n"
-
- template = template.replace('{{REPLACEWITHIMPORTS}}', importsString)
-
- if os.path.exists(outputPath):
- with open(outputPath, "r") as f:
- if f.read() == template:
- return 0
-
- # File either doesn't exist or has different contents
- with open(outputPath, "w") as f:
- f.write(template)
+ template = env.get_template("Nat.java.jinja")
+ rootPath = f"{cmake_binary_dir}/generated/main/java/edu/wpi/first/math"
+ contents = template.render(nums=range(MAX_NUM + 1))
+ output(rootPath, "Nat.java", contents)
if __name__ == "__main__":
diff --git a/wpimath/src/dev/java/edu/wpi/first/math/DevMain.java b/wpimath/src/dev/java/edu/wpi/first/math/DevMain.java
new file mode 100644
index 0000000..c7b779e
--- /dev/null
+++ b/wpimath/src/dev/java/edu/wpi/first/math/DevMain.java
@@ -0,0 +1,15 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math;
+
+public final class DevMain {
+ /** Main entry point. */
+ public static void main(String[] args) {
+ System.out.println("Hello World!");
+ System.out.println(MathUtil.angleModulus(-5.0));
+ }
+
+ private DevMain() {}
+}
diff --git a/wpimath/src/dev/java/edu/wpi/first/wpiutil/math/DevMain.java b/wpimath/src/dev/java/edu/wpi/first/wpiutil/math/DevMain.java
deleted file mode 100644
index 9a5d378..0000000
--- a/wpimath/src/dev/java/edu/wpi/first/wpiutil/math/DevMain.java
+++ /dev/null
@@ -1,21 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2017-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpiutil.math;
-
-public final class DevMain {
- /**
- * Main entry point.
- */
- public static void main(String[] args) {
- System.out.println("Hello World!");
- System.out.println(MathUtil.normalizeAngle(-5.0));
- }
-
- private DevMain() {
- }
-}
diff --git a/wpimath/src/dev/native/cpp/main.cpp b/wpimath/src/dev/native/cpp/main.cpp
index 030c8f9..54952b7 100644
--- a/wpimath/src/dev/native/cpp/main.cpp
+++ b/wpimath/src/dev/native/cpp/main.cpp
@@ -1,12 +1,10 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2017-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
-#include <iostream>
+#include <fmt/core.h>
+#include <wpi/numbers>
-#include <wpi/math>
-
-int main() { std::cout << wpi::math::pi << std::endl; }
+int main() {
+ fmt::print("{}\n", wpi::numbers::pi);
+}
diff --git a/wpimath/src/generate/GenericNumber.java.in b/wpimath/src/generate/GenericNumber.java.in
deleted file mode 100644
index 5a36582..0000000
--- a/wpimath/src/generate/GenericNumber.java.in
+++ /dev/null
@@ -1,34 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpiutil.math.numbers;
-
-import edu.wpi.first.wpiutil.math.Nat;
-import edu.wpi.first.wpiutil.math.Num;
-
-/**
- * A class representing the number ${num}.
-*/
-public final class N${num} extends Num implements Nat<N${num}> {
- private N${num}() {
- }
-
- /**
- * The integer this class represents.
- *
- * @return The literal number ${num}.
- */
- @Override
- public int getNum() {
- return ${num};
- }
-
- /**
- * The singleton instance of this class.
- */
- public static final N${num} instance = new N${num}();
-}
diff --git a/wpimath/src/generate/GenericNumber.java.jinja b/wpimath/src/generate/GenericNumber.java.jinja
new file mode 100644
index 0000000..5e4be85
--- /dev/null
+++ b/wpimath/src/generate/GenericNumber.java.jinja
@@ -0,0 +1,31 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.numbers;
+
+import edu.wpi.first.math.Nat;
+import edu.wpi.first.math.Num;
+
+/**
+ * A class representing the number {{ num }}.
+*/
+public final class N{{ num }} extends Num implements Nat<N{{ num }}> {
+ private N{{ num }}() {
+ }
+
+ /**
+ * The integer this class represents.
+ *
+ * @return The literal number {{ num }}.
+ */
+ @Override
+ public int getNum() {
+ return {{ num }};
+ }
+
+ /**
+ * The singleton instance of this class.
+ */
+ public static final N{{ num }} instance = new N{{ num }}();
+}
diff --git a/wpimath/src/generate/Nat.java.in b/wpimath/src/generate/Nat.java.in
deleted file mode 100644
index 666bd1c..0000000
--- a/wpimath/src/generate/Nat.java.in
+++ /dev/null
@@ -1,27 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpiutil.math;
-
-//CHECKSTYLE.OFF: ImportOrder
-{{REPLACEWITHIMPORTS}}
-//CHECKSTYLE.ON
-
-/**
- * A natural number expressed as a java class.
- * The counterpart to {@link Num} that should be used as a concrete value.
- *
- * @param <T> The {@link Num} this represents.
- */
-@SuppressWarnings({"MethodName", "unused"})
-public interface Nat<T extends Num> {
- /**
- * The number this interface represents.
- *
- * @return The number backing this value.
- */
- int getNum();
diff --git a/wpimath/src/generate/Nat.java.jinja b/wpimath/src/generate/Nat.java.jinja
new file mode 100644
index 0000000..31451d2
--- /dev/null
+++ b/wpimath/src/generate/Nat.java.jinja
@@ -0,0 +1,32 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math;
+
+//CHECKSTYLE.OFF: ImportOrder
+{% for num in nums %}
+import edu.wpi.first.math.numbers.N{{ num }};
+{%- endfor %}
+//CHECKSTYLE.ON
+
+/**
+ * A natural number expressed as a java class.
+ * The counterpart to {@link Num} that should be used as a concrete value.
+ *
+ * @param <T> The {@link Num} this represents.
+ */
+@SuppressWarnings({"MethodName", "unused"})
+public interface Nat<T extends Num> {
+ /**
+ * The number this interface represents.
+ *
+ * @return The number backing this value.
+ */
+ int getNum();
+{% for num in nums %}
+ static Nat<N{{ num }}> N{{ num }}() {
+ return N{{ num }}.instance;
+ }
+{% endfor %}
+}
diff --git a/wpimath/src/generate/NatGetter.java.in b/wpimath/src/generate/NatGetter.java.in
deleted file mode 100644
index d268fab..0000000
--- a/wpimath/src/generate/NatGetter.java.in
+++ /dev/null
@@ -1,3 +0,0 @@
- static Nat<N${num}> N${num}() {
- return N${num}.instance;
- }
diff --git a/wpimath/src/main/java/edu/wpi/first/math/Drake.java b/wpimath/src/main/java/edu/wpi/first/math/Drake.java
index 4d66102..766f7e9 100644
--- a/wpimath/src/main/java/edu/wpi/first/math/Drake.java
+++ b/wpimath/src/main/java/edu/wpi/first/math/Drake.java
@@ -1,20 +1,13 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
package edu.wpi.first.math;
import org.ejml.simple.SimpleMatrix;
-import edu.wpi.first.wpiutil.math.Matrix;
-import edu.wpi.first.wpiutil.math.Num;
-
public final class Drake {
- private Drake() {
- }
+ private Drake() {}
/**
* Solves the discrete alegebraic Riccati equation.
@@ -27,33 +20,101 @@
*/
@SuppressWarnings({"LocalVariableName", "ParameterName"})
public static SimpleMatrix discreteAlgebraicRiccatiEquation(
- SimpleMatrix A,
- SimpleMatrix B,
- SimpleMatrix Q,
- SimpleMatrix R) {
+ SimpleMatrix A, SimpleMatrix B, SimpleMatrix Q, SimpleMatrix R) {
var S = new SimpleMatrix(A.numRows(), A.numCols());
- WPIMathJNI.discreteAlgebraicRiccatiEquation(A.getDDRM().getData(), B.getDDRM().getData(),
- Q.getDDRM().getData(), R.getDDRM().getData(), A.numCols(), B.numCols(),
- S.getDDRM().getData());
+ WPIMathJNI.discreteAlgebraicRiccatiEquation(
+ A.getDDRM().getData(),
+ B.getDDRM().getData(),
+ Q.getDDRM().getData(),
+ R.getDDRM().getData(),
+ A.numCols(),
+ B.numCols(),
+ S.getDDRM().getData());
return S;
}
/**
* Solves the discrete alegebraic Riccati equation.
*
+ * @param <States> Number of states.
+ * @param <Inputs> Number of inputs.
+ * @param A System matrix.
+ * @param B Input matrix.
+ * @param Q State cost matrix.
+ * @param R Input cost matrix.
+ * @return Solution of DARE.
+ */
+ @SuppressWarnings({"ParameterName", "MethodTypeParameterName"})
+ public static <States extends Num, Inputs extends Num>
+ Matrix<States, States> discreteAlgebraicRiccatiEquation(
+ Matrix<States, States> A,
+ Matrix<States, Inputs> B,
+ Matrix<States, States> Q,
+ Matrix<Inputs, Inputs> R) {
+ return new Matrix<>(
+ discreteAlgebraicRiccatiEquation(
+ A.getStorage(), B.getStorage(), Q.getStorage(), R.getStorage()));
+ }
+
+ /**
+ * Solves the discrete algebraic Riccati equation.
+ *
* @param A System matrix.
* @param B Input matrix.
* @param Q State cost matrix.
* @param R Input cost matrix.
+ * @param N State-input cross-term cost matrix.
+ * @return Solution of DARE.
+ */
+ @SuppressWarnings({"LocalVariableName", "ParameterName"})
+ public static SimpleMatrix discreteAlgebraicRiccatiEquation(
+ SimpleMatrix A, SimpleMatrix B, SimpleMatrix Q, SimpleMatrix R, SimpleMatrix N) {
+ // See
+ // https://en.wikipedia.org/wiki/Linear%E2%80%93quadratic_regulator#Infinite-horizon,_discrete-time_LQR
+ // for the change of variables used here.
+ var scrA = A.minus(B.mult(R.solve(N.transpose())));
+ var scrQ = Q.minus(N.mult(R.solve(N.transpose())));
+
+ var S = new SimpleMatrix(A.numRows(), A.numCols());
+ WPIMathJNI.discreteAlgebraicRiccatiEquation(
+ scrA.getDDRM().getData(),
+ B.getDDRM().getData(),
+ scrQ.getDDRM().getData(),
+ R.getDDRM().getData(),
+ A.numCols(),
+ B.numCols(),
+ S.getDDRM().getData());
+ return S;
+ }
+
+ /**
+ * Solves the discrete algebraic Riccati equation.
+ *
+ * @param <States> Number of states.
+ * @param <Inputs> Number of inputs.
+ * @param A System matrix.
+ * @param B Input matrix.
+ * @param Q State cost matrix.
+ * @param R Input cost matrix.
+ * @param N State-input cross-term cost matrix.
* @return Solution of DARE.
*/
@SuppressWarnings({"ParameterName", "MethodTypeParameterName"})
- public static <States extends Num, Inputs extends Num> Matrix<States, States>
- discreteAlgebraicRiccatiEquation(Matrix<States, States> A,
- Matrix<States, Inputs> B,
- Matrix<States, States> Q,
- Matrix<Inputs, Inputs> R) {
- return new Matrix<>(discreteAlgebraicRiccatiEquation(A.getStorage(), B.getStorage(),
- Q.getStorage(), R.getStorage()));
+ public static <States extends Num, Inputs extends Num>
+ Matrix<States, States> discreteAlgebraicRiccatiEquation(
+ Matrix<States, States> A,
+ Matrix<States, Inputs> B,
+ Matrix<States, States> Q,
+ Matrix<Inputs, Inputs> R,
+ Matrix<States, Inputs> N) {
+ // See
+ // https://en.wikipedia.org/wiki/Linear%E2%80%93quadratic_regulator#Infinite-horizon,_discrete-time_LQR
+ // for the change of variables used here.
+ var scrA = A.minus(B.times(R.solve(N.transpose())));
+ var scrQ = Q.minus(N.times(R.solve(N.transpose())));
+
+ return new Matrix<>(
+ discreteAlgebraicRiccatiEquation(
+ scrA.getStorage(), B.getStorage(), scrQ.getStorage(), R.getStorage()));
}
}
diff --git a/wpimath/src/main/java/edu/wpi/first/math/MatBuilder.java b/wpimath/src/main/java/edu/wpi/first/math/MatBuilder.java
new file mode 100644
index 0000000..e5b1952
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/MatBuilder.java
@@ -0,0 +1,53 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math;
+
+import java.util.Objects;
+import org.ejml.simple.SimpleMatrix;
+
+/**
+ * A class for constructing arbitrary RxC matrices.
+ *
+ * @param <R> The number of rows of the desired matrix.
+ * @param <C> The number of columns of the desired matrix.
+ */
+public class MatBuilder<R extends Num, C extends Num> {
+ final Nat<R> m_rows;
+ final Nat<C> m_cols;
+
+ /**
+ * Fills the matrix with the given data, encoded in row major form. (The matrix is filled row by
+ * row, left to right with the given data).
+ *
+ * @param data The data to fill the matrix with.
+ * @return The constructed matrix.
+ */
+ @SuppressWarnings("LineLength")
+ public final Matrix<R, C> fill(double... data) {
+ if (Objects.requireNonNull(data).length != this.m_rows.getNum() * this.m_cols.getNum()) {
+ throw new IllegalArgumentException(
+ "Invalid matrix data provided. Wanted "
+ + this.m_rows.getNum()
+ + " x "
+ + this.m_cols.getNum()
+ + " matrix, but got "
+ + data.length
+ + " elements");
+ } else {
+ return new Matrix<>(new SimpleMatrix(this.m_rows.getNum(), this.m_cols.getNum(), true, data));
+ }
+ }
+
+ /**
+ * Creates a new {@link MatBuilder} with the given dimensions.
+ *
+ * @param rows The number of rows of the matrix.
+ * @param cols The number of columns of the matrix.
+ */
+ public MatBuilder(Nat<R> rows, Nat<C> cols) {
+ this.m_rows = Objects.requireNonNull(rows);
+ this.m_cols = Objects.requireNonNull(cols);
+ }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/math/MathShared.java b/wpimath/src/main/java/edu/wpi/first/math/MathShared.java
index 168dbb5..483dad3 100644
--- a/wpimath/src/main/java/edu/wpi/first/math/MathShared.java
+++ b/wpimath/src/main/java/edu/wpi/first/math/MathShared.java
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
package edu.wpi.first.math;
@@ -12,6 +9,7 @@
* Report an error.
*
* @param error the error to set
+ * @param stackTrace array of stacktrace elements
*/
void reportError(String error, StackTraceElement[] stackTrace);
diff --git a/wpimath/src/main/java/edu/wpi/first/math/MathSharedStore.java b/wpimath/src/main/java/edu/wpi/first/math/MathSharedStore.java
index a4c8425..0dbc03d 100644
--- a/wpimath/src/main/java/edu/wpi/first/math/MathSharedStore.java
+++ b/wpimath/src/main/java/edu/wpi/first/math/MathSharedStore.java
@@ -1,38 +1,37 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2018-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
package edu.wpi.first.math;
public final class MathSharedStore {
private static MathShared mathShared;
- private MathSharedStore() {
- }
+ private MathSharedStore() {}
/**
- * get the MathShared object.
+ * Get the MathShared object.
+ *
+ * @return The MathShared object.
*/
public static synchronized MathShared getMathShared() {
if (mathShared == null) {
- mathShared = new MathShared() {
- @Override
- public void reportError(String error, StackTraceElement[] stackTrace) {
- }
+ mathShared =
+ new MathShared() {
+ @Override
+ public void reportError(String error, StackTraceElement[] stackTrace) {}
- @Override
- public void reportUsage(MathUsageId id, int count) {
- }
- };
+ @Override
+ public void reportUsage(MathUsageId id, int count) {}
+ };
}
return mathShared;
}
/**
- * set the MathShared object.
+ * Set the MathShared object.
+ *
+ * @param shared The MathShared object.
*/
public static synchronized void setMathShared(MathShared shared) {
mathShared = shared;
@@ -42,6 +41,7 @@
* Report an error.
*
* @param error the error to set
+ * @param stackTrace array of stacktrace elements
*/
public static void reportError(String error, StackTraceElement[] stackTrace) {
getMathShared().reportError(error, stackTrace);
diff --git a/wpimath/src/main/java/edu/wpi/first/math/MathUsageId.java b/wpimath/src/main/java/edu/wpi/first/math/MathUsageId.java
index ed95e24..a3cc299 100644
--- a/wpimath/src/main/java/edu/wpi/first/math/MathUsageId.java
+++ b/wpimath/src/main/java/edu/wpi/first/math/MathUsageId.java
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
package edu.wpi.first.math;
@@ -15,5 +12,7 @@
kFilter_Linear,
kOdometry_DifferentialDrive,
kOdometry_SwerveDrive,
- kOdometry_MecanumDrive
+ kOdometry_MecanumDrive,
+ kController_PIDController2,
+ kController_ProfiledPIDController,
}
diff --git a/wpimath/src/main/java/edu/wpi/first/math/MathUtil.java b/wpimath/src/main/java/edu/wpi/first/math/MathUtil.java
new file mode 100644
index 0000000..791de8a
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/MathUtil.java
@@ -0,0 +1,87 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math;
+
+public final class MathUtil {
+ private MathUtil() {
+ throw new AssertionError("utility class");
+ }
+
+ /**
+ * Returns value clamped between low and high boundaries.
+ *
+ * @param value Value to clamp.
+ * @param low The lower boundary to which to clamp value.
+ * @param high The higher boundary to which to clamp value.
+ * @return The clamped value.
+ */
+ public static int clamp(int value, int low, int high) {
+ return Math.max(low, Math.min(value, high));
+ }
+
+ /**
+ * Returns value clamped between low and high boundaries.
+ *
+ * @param value Value to clamp.
+ * @param low The lower boundary to which to clamp value.
+ * @param high The higher boundary to which to clamp value.
+ * @return The clamped value.
+ */
+ public static double clamp(double value, double low, double high) {
+ return Math.max(low, Math.min(value, high));
+ }
+
+ /**
+ * Returns 0.0 if the given value is within the specified range around zero. The remaining range
+ * between the deadband and 1.0 is scaled from 0.0 to 1.0.
+ *
+ * @param value Value to clip.
+ * @param deadband Range around zero.
+ * @return The value after the deadband is applied.
+ */
+ public static double applyDeadband(double value, double deadband) {
+ if (Math.abs(value) > deadband) {
+ if (value > 0.0) {
+ return (value - deadband) / (1.0 - deadband);
+ } else {
+ return (value + deadband) / (1.0 - deadband);
+ }
+ } else {
+ return 0.0;
+ }
+ }
+
+ /**
+ * Returns modulus of input.
+ *
+ * @param input Input value to wrap.
+ * @param minimumInput The minimum value expected from the input.
+ * @param maximumInput The maximum value expected from the input.
+ * @return The wrapped value.
+ */
+ public static double inputModulus(double input, double minimumInput, double maximumInput) {
+ double modulus = maximumInput - minimumInput;
+
+ // Wrap input if it's above the maximum input
+ int numMax = (int) ((input - minimumInput) / modulus);
+ input -= numMax * modulus;
+
+ // Wrap input if it's below the minimum input
+ int numMin = (int) ((input - maximumInput) / modulus);
+ input -= numMin * modulus;
+
+ return input;
+ }
+
+ /**
+ * Wraps an angle to the range -pi to pi radians.
+ *
+ * @param angleRadians Angle to wrap in radians.
+ * @return The wrapped angle.
+ */
+ public static double angleModulus(double angleRadians) {
+ return inputModulus(angleRadians, -Math.PI, Math.PI);
+ }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpiutil/math/Matrix.java b/wpimath/src/main/java/edu/wpi/first/math/Matrix.java
similarity index 70%
rename from wpimath/src/main/java/edu/wpi/first/wpiutil/math/Matrix.java
rename to wpimath/src/main/java/edu/wpi/first/math/Matrix.java
index a87b98a..113758b 100644
--- a/wpimath/src/main/java/edu/wpi/first/wpiutil/math/Matrix.java
+++ b/wpimath/src/main/java/edu/wpi/first/math/Matrix.java
@@ -1,14 +1,11 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
-package edu.wpi.first.wpiutil.math;
+package edu.wpi.first.math;
+import edu.wpi.first.math.numbers.N1;
import java.util.Objects;
-
import org.ejml.MatrixDimensionException;
import org.ejml.data.DMatrixRMaj;
import org.ejml.dense.row.CommonOps_DDRM;
@@ -18,9 +15,6 @@
import org.ejml.interfaces.decomposition.CholeskyDecomposition_F64;
import org.ejml.simple.SimpleMatrix;
-import edu.wpi.first.math.WPIMathJNI;
-import edu.wpi.first.wpiutil.math.numbers.N1;
-
/**
* A shape-safe wrapper over Efficient Java Matrix Library (EJML) matrices.
*
@@ -29,31 +23,28 @@
* @param <R> The number of rows in this matrix.
* @param <C> The number of columns in this matrix.
*/
-@SuppressWarnings("PMD.ExcessivePublicCount")
public class Matrix<R extends Num, C extends Num> {
protected final SimpleMatrix m_storage;
/**
* Constructs an empty zero matrix of the given dimensions.
*
- * @param rows The number of rows of the matrix.
+ * @param rows The number of rows of the matrix.
* @param columns The number of columns of the matrix.
*/
public Matrix(Nat<R> rows, Nat<C> columns) {
- this.m_storage = new SimpleMatrix(
- Objects.requireNonNull(rows).getNum(),
- Objects.requireNonNull(columns).getNum()
- );
+ this.m_storage =
+ new SimpleMatrix(
+ Objects.requireNonNull(rows).getNum(), Objects.requireNonNull(columns).getNum());
}
/**
- * Constructs a new {@link Matrix} with the given storage.
- * Caller should make sure that the provided generic bounds match
- * the shape of the provided {@link Matrix}.
+ * Constructs a new {@link Matrix} with the given storage. Caller should make sure that the
+ * provided generic bounds match the shape of the provided {@link Matrix}.
*
- * <p>NOTE:It is not recommend to use this constructor unless the
- * {@link SimpleMatrix} API is absolutely necessary due to the desired
- * function not being accessible through the {@link Matrix} wrapper.
+ * <p>NOTE:It is not recommend to use this constructor unless the {@link SimpleMatrix} API is
+ * absolutely necessary due to the desired function not being accessible through the {@link
+ * Matrix} wrapper.
*
* @param storage The {@link SimpleMatrix} to back this value.
*/
@@ -73,10 +64,9 @@
/**
* Gets the underlying {@link SimpleMatrix} that this {@link Matrix} wraps.
*
- * <p>NOTE:The use of this method is heavily discouraged as this removes any
- * guarantee of type safety. This should only be called if the {@link SimpleMatrix}
- * API is absolutely necessary due to the desired function not being accessible through
- * the {@link Matrix} wrapper.
+ * <p>NOTE:The use of this method is heavily discouraged as this removes any guarantee of type
+ * safety. This should only be called if the {@link SimpleMatrix} API is absolutely necessary due
+ * to the desired function not being accessible through the {@link Matrix} wrapper.
*
* @return The underlying {@link SimpleMatrix} storage.
*/
@@ -116,8 +106,8 @@
/**
* Sets the value at the given indices.
*
- * @param row The row of the element.
- * @param col The column of the element.
+ * @param row The row of the element.
+ * @param col The column of the element.
* @param value The value to insert at the given location.
*/
public final void set(int row, int col, double value) {
@@ -131,22 +121,19 @@
* @param val The row vector to set the given row to.
*/
public final void setRow(int row, Matrix<N1, C> val) {
- this.m_storage.setRow(row, 0,
- Objects.requireNonNull(val).m_storage.getDDRM().getData());
+ this.m_storage.setRow(row, 0, Objects.requireNonNull(val).m_storage.getDDRM().getData());
}
/**
* Sets a column to a given column vector.
*
* @param column The column to set.
- * @param val The column vector to set the given row to.
+ * @param val The column vector to set the given row to.
*/
public final void setColumn(int column, Matrix<R, N1> val) {
- this.m_storage.setColumn(column, 0,
- Objects.requireNonNull(val).m_storage.getDDRM().getData());
+ this.m_storage.setColumn(column, 0, Objects.requireNonNull(val).m_storage.getDDRM().getData());
}
-
/**
* Sets all the elements in "this" matrix equal to the specified value.
*
@@ -159,8 +146,8 @@
/**
* Returns the diagonal elements inside a vector or square matrix.
*
- * <p>If "this" {@link Matrix} is a vector then a square matrix is returned. If a "this"
- * {@link Matrix} is a matrix then a vector of diagonal elements is returned.
+ * <p>If "this" {@link Matrix} is a vector then a square matrix is returned. If a "this" {@link
+ * Matrix} is a matrix then a vector of diagonal elements is returned.
*
* @return The diagonal elements inside a vector or a square matrix.
*/
@@ -186,7 +173,6 @@
return CommonOps_DDRM.elementMaxAbs(this.m_storage.getDDRM());
}
-
/**
* Returns the smallest element of this matrix.
*
@@ -208,12 +194,12 @@
/**
* Multiplies this matrix with another that has C rows.
*
- * <p>As matrix multiplication is only defined if the number of columns
- * in the first matrix matches the number of rows in the second,
- * this operation will fail to compile under any other circumstances.
+ * <p>As matrix multiplication is only defined if the number of columns in the first matrix
+ * matches the number of rows in the second, this operation will fail to compile under any other
+ * circumstances.
*
* @param other The other matrix to multiply by.
- * @param <C2> The number of columns in the second matrix.
+ * @param <C2> The number of columns in the second matrix.
* @return The result of the matrix multiplication between "this" and the given matrix.
*/
public final <C2 extends Num> Matrix<R, C2> times(Matrix<C, C2> other) {
@@ -231,12 +217,11 @@
}
/**
- * Returns a matrix which is the result of an element by element multiplication of
- * "this" and other.
+ * Returns a matrix which is the result of an element by element multiplication of "this" and
+ * other.
*
* <p>c<sub>i,j</sub> = a<sub>i,j</sub>*other<sub>i,j</sub>
*
- *
* @param other The other {@link Matrix} to preform element multiplication on.
* @return The element by element multiplication of "this" and other.
*/
@@ -254,7 +239,6 @@
return new Matrix<>(this.m_storage.minus(value));
}
-
/**
* Subtracts the given matrix from this matrix.
*
@@ -265,7 +249,6 @@
return new Matrix<>(this.m_storage.minus(Objects.requireNonNull(value).m_storage));
}
-
/**
* Adds the given value to all the elements of this matrix.
*
@@ -307,7 +290,7 @@
}
/**
- * Calculates the transpose, M^T of this matrix.
+ * Calculates the transpose, Mᵀ of this matrix.
*
* @return The transpose matrix.
*/
@@ -315,7 +298,6 @@
return new Matrix<>(this.m_storage.transpose());
}
-
/**
* Returns a copy of this matrix.
*
@@ -325,7 +307,6 @@
return new Matrix<>(this.m_storage.copy());
}
-
/**
* Returns the inverse matrix of "this" matrix.
*
@@ -339,9 +320,10 @@
/**
* Returns the solution x to the equation Ax = b, where A is "this" matrix.
*
- * <p>The matrix equation could also be written as x = A<sup>-1</sup>b. Where the
- * pseudo inverse is used if A is not square.
+ * <p>The matrix equation could also be written as x = A<sup>-1</sup>b. Where the pseudo inverse
+ * is used if A is not square.
*
+ * @param <C2> Columns in b.
* @param b The right-hand side of the equation to solve.
* @return The solution to the linear system.
*/
@@ -351,20 +333,50 @@
}
/**
- * Computes the matrix exponential using Eigen's solver.
- * This method only works for square matrices, and will
- * otherwise throw an {@link MatrixDimensionException}.
+ * Computes the matrix exponential using Eigen's solver. This method only works for square
+ * matrices, and will otherwise throw an {@link MatrixDimensionException}.
*
* @return The exponential of A.
*/
public final Matrix<R, C> exp() {
if (this.getNumRows() != this.getNumCols()) {
- throw new MatrixDimensionException("Non-square matrices cannot be exponentiated! "
- + "This matrix is " + this.getNumRows() + " x " + this.getNumCols());
+ throw new MatrixDimensionException(
+ "Non-square matrices cannot be exponentiated! "
+ + "This matrix is "
+ + this.getNumRows()
+ + " x "
+ + this.getNumCols());
}
Matrix<R, C> toReturn = new Matrix<>(new SimpleMatrix(this.getNumRows(), this.getNumCols()));
- WPIMathJNI.exp(this.m_storage.getDDRM().getData(), this.getNumRows(),
- toReturn.m_storage.getDDRM().getData());
+ WPIMathJNI.exp(
+ this.m_storage.getDDRM().getData(),
+ this.getNumRows(),
+ toReturn.m_storage.getDDRM().getData());
+ return toReturn;
+ }
+
+ /**
+ * Computes the matrix power using Eigen's solver. This method only works for square matrices, and
+ * will otherwise throw an {@link MatrixDimensionException}.
+ *
+ * @param exponent The exponent.
+ * @return The exponential of A.
+ */
+ public final Matrix<R, C> pow(double exponent) {
+ if (this.getNumRows() != this.getNumCols()) {
+ throw new MatrixDimensionException(
+ "Non-square matrices cannot be raised to a power! "
+ + "This matrix is "
+ + this.getNumRows()
+ + " x "
+ + this.getNumCols());
+ }
+ Matrix<R, C> toReturn = new Matrix<>(new SimpleMatrix(this.getNumRows(), this.getNumCols()));
+ WPIMathJNI.pow(
+ this.m_storage.getDDRM().getData(),
+ this.getNumRows(),
+ exponent,
+ toReturn.m_storage.getDDRM().getData());
return toReturn;
}
@@ -380,7 +392,7 @@
/**
* Computes the Frobenius normal of the matrix.
*
- * <p>normF = Sqrt{ ∑<sub>i=1:m</sub> ∑<sub>j=1:n</sub> { a<sub>ij</sub><sup>2</sup>} }
+ * <p>normF = Sqrt{ ∑<sub>i=1:m</sub> ∑<sub>j=1:n</sub> { a<sub>ij</sub><sup>2</sup>} }
*
* @return The matrix's Frobenius normal.
*/
@@ -464,69 +476,90 @@
}
/**
- * Extracts a matrix of a given size and start position with new underlying
- * storage.
+ * Extracts a matrix of a given size and start position with new underlying storage.
*
+ * @param <R2> Number of rows to extract.
+ * @param <C2> Number of columns to extract.
* @param height The number of rows of the extracted matrix.
- * @param width The number of columns of the extracted matrix.
+ * @param width The number of columns of the extracted matrix.
* @param startingRow The starting row of the extracted matrix.
* @param startingCol The starting column of the extracted matrix.
* @return The extracted matrix.
*/
public final <R2 extends Num, C2 extends Num> Matrix<R2, C2> block(
Nat<R2> height, Nat<C2> width, int startingRow, int startingCol) {
- return new Matrix<>(this.m_storage.extractMatrix(
- startingRow,
- Objects.requireNonNull(height).getNum() + startingRow,
- startingCol,
- Objects.requireNonNull(width).getNum() + startingCol));
+ return new Matrix<>(
+ this.m_storage.extractMatrix(
+ startingRow,
+ startingRow + Objects.requireNonNull(height).getNum(),
+ startingCol,
+ startingCol + Objects.requireNonNull(width).getNum()));
+ }
+
+ /**
+ * Extracts a matrix of a given size and start position with new underlying storage.
+ *
+ * @param <R2> Number of rows to extract.
+ * @param <C2> Number of columns to extract.
+ * @param height The number of rows of the extracted matrix.
+ * @param width The number of columns of the extracted matrix.
+ * @param startingRow The starting row of the extracted matrix.
+ * @param startingCol The starting column of the extracted matrix.
+ * @return The extracted matrix.
+ */
+ public final <R2 extends Num, C2 extends Num> Matrix<R2, C2> block(
+ int height, int width, int startingRow, int startingCol) {
+ return new Matrix<R2, C2>(
+ this.m_storage.extractMatrix(
+ startingRow, startingRow + height, startingCol, startingCol + width));
}
/**
* Assign a matrix of a given size and start position.
*
+ * @param <R2> Rows in block assignment.
+ * @param <C2> Columns in block assignment.
* @param startingRow The row to start at.
- * @param startingCol The column to start at.
- * @param other The matrix to assign the block to.
+ * @param startingCol The column to start at.
+ * @param other The matrix to assign the block to.
*/
- public <R2 extends Num, C2 extends Num> void assignBlock(int startingRow, int startingCol,
- Matrix<R2, C2> other) {
+ public <R2 extends Num, C2 extends Num> void assignBlock(
+ int startingRow, int startingCol, Matrix<R2, C2> other) {
this.m_storage.insertIntoThis(
- startingRow,
- startingCol,
- Objects.requireNonNull(other).m_storage);
+ startingRow, startingCol, Objects.requireNonNull(other).m_storage);
}
/**
* Extracts a submatrix from the supplied matrix and inserts it in a submatrix in "this". The
* shape of "this" is used to determine the size of the matrix extracted.
*
+ * @param <R2> Number of rows to extract.
+ * @param <C2> Number of columns to extract.
* @param startingRow The starting row in the supplied matrix to extract the submatrix.
* @param startingCol The starting column in the supplied matrix to extract the submatrix.
- * @param other The matrix to extract the submatrix from.
+ * @param other The matrix to extract the submatrix from.
*/
- public <R2 extends Num, C2 extends Num> void extractFrom(int startingRow, int startingCol,
- Matrix<R2, C2> other) {
- CommonOps_DDRM.extract(other.m_storage.getDDRM(), startingRow, startingCol,
- this.m_storage.getDDRM());
+ public <R2 extends Num, C2 extends Num> void extractFrom(
+ int startingRow, int startingCol, Matrix<R2, C2> other) {
+ CommonOps_DDRM.extract(
+ other.m_storage.getDDRM(), startingRow, startingCol, this.m_storage.getDDRM());
}
/**
- * Decompose "this" matrix using Cholesky Decomposition. If the "this" matrix is zeros, it
- * will return the zero matrix.
+ * Decompose "this" matrix using Cholesky Decomposition. If the "this" matrix is zeros, it will
+ * return the zero matrix.
*
- * @param lowerTriangular Whether or not we want to decompose to the lower triangular
- * Cholesky matrix.
+ * @param lowerTriangular Whether or not we want to decompose to the lower triangular Cholesky
+ * matrix.
* @return The decomposed matrix.
* @throws RuntimeException if the matrix could not be decomposed(ie. is not positive
- * semidefinite).
+ * semidefinite).
*/
- @SuppressWarnings("PMD.AvoidThrowingRawExceptionTypes")
public Matrix<R, C> lltDecompose(boolean lowerTriangular) {
SimpleMatrix temp = m_storage.copy();
CholeskyDecomposition_F64<DMatrixRMaj> chol =
- DecompositionFactory_DDRM.chol(temp.numRows(), lowerTriangular);
+ DecompositionFactory_DDRM.chol(temp.numRows(), lowerTriangular);
if (!chol.decompose(temp.getMatrix())) {
// check that the input is not all zeros -- if they are, we special case and return all
// zeros.
@@ -539,8 +572,8 @@
return new Matrix<>(new SimpleMatrix(temp.numRows(), temp.numCols()));
}
- throw new RuntimeException("Cholesky decomposition failed! Input matrix:\n"
- + m_storage.toString());
+ throw new RuntimeException(
+ "Cholesky decomposition failed! Input matrix:\n" + m_storage.toString());
}
return new Matrix<>(SimpleMatrix.wrap(chol.getT(null)));
@@ -578,8 +611,8 @@
}
/**
- * Entrypoint to the {@link MatBuilder} class for creation
- * of custom matrices with the given dimensions and contents.
+ * Entrypoint to the {@link MatBuilder} class for creation of custom matrices with the given
+ * dimensions and contents.
*
* @param rows The number of rows of the desired matrix.
* @param cols The number of columns of the desired matrix.
@@ -592,9 +625,11 @@
}
/**
- * Reassigns dimensions of a {@link Matrix} to allow for operations with
- * other matrices that have wildcard dimensions.
+ * Reassigns dimensions of a {@link Matrix} to allow for operations with other matrices that have
+ * wildcard dimensions.
*
+ * @param <R1> Row dimension to assign.
+ * @param <C1> Column dimension to assign.
* @param mat The {@link Matrix} to remove the dimensions from.
* @return The matrix with reassigned dimensions.
*/
@@ -606,40 +641,40 @@
/**
* Checks if another {@link Matrix} is identical to "this" one within a specified tolerance.
*
- * <p>This will check if each element is in tolerance of the corresponding element
- * from the other {@link Matrix} or if the elements have the same symbolic meaning. For two
- * elements to have the same symbolic meaning they both must be either Double.NaN,
- * Double.POSITIVE_INFINITY, or Double.NEGATIVE_INFINITY.
+ * <p>This will check if each element is in tolerance of the corresponding element from the other
+ * {@link Matrix} or if the elements have the same symbolic meaning. For two elements to have the
+ * same symbolic meaning they both must be either Double.NaN, Double.POSITIVE_INFINITY, or
+ * Double.NEGATIVE_INFINITY.
*
- * <p>NOTE:It is recommend to use {@link Matrix#isEqual(Matrix, double)} over this
- * method when checking if two matrices are equal as {@link Matrix#isEqual(Matrix, double)}
- * will return false if an element is uncountable. This method should only be used when
- * uncountable elements need to compared.
+ * <p>NOTE:It is recommend to use {@link Matrix#isEqual(Matrix, double)} over this method when
+ * checking if two matrices are equal as {@link Matrix#isEqual(Matrix, double)} will return false
+ * if an element is uncountable. This method should only be used when uncountable elements need to
+ * compared.
*
- * @param other The {@link Matrix} to check against this one.
+ * @param other The {@link Matrix} to check against this one.
* @param tolerance The tolerance to check equality with.
* @return true if this matrix is identical to the one supplied.
*/
public boolean isIdentical(Matrix<?, ?> other, double tolerance) {
- return MatrixFeatures_DDRM.isIdentical(this.m_storage.getDDRM(),
- other.m_storage.getDDRM(), tolerance);
+ return MatrixFeatures_DDRM.isIdentical(
+ this.m_storage.getDDRM(), other.m_storage.getDDRM(), tolerance);
}
/**
* Checks if another {@link Matrix} is equal to "this" within a specified tolerance.
*
- * <p>This will check if each element is in tolerance of the corresponding element
- * from the other {@link Matrix}.
+ * <p>This will check if each element is in tolerance of the corresponding element from the other
+ * {@link Matrix}.
*
* <p>tol ≥ |a<sub>ij</sub> - b<sub>ij</sub>|
*
- * @param other The {@link Matrix} to check against this one.
+ * @param other The {@link Matrix} to check against this one.
* @param tolerance The tolerance to check equality with.
* @return true if this matrix is equal to the one supplied.
*/
public boolean isEqual(Matrix<?, ?> other, double tolerance) {
- return MatrixFeatures_DDRM.isEquals(this.m_storage.getDDRM(),
- other.m_storage.getDDRM(), tolerance);
+ return MatrixFeatures_DDRM.isEquals(
+ this.m_storage.getDDRM(), other.m_storage.getDDRM(), tolerance);
}
@Override
diff --git a/wpimath/src/main/java/edu/wpi/first/wpiutil/math/MatrixUtils.java b/wpimath/src/main/java/edu/wpi/first/math/MatrixUtils.java
similarity index 70%
rename from wpimath/src/main/java/edu/wpi/first/wpiutil/math/MatrixUtils.java
rename to wpimath/src/main/java/edu/wpi/first/math/MatrixUtils.java
index b3e4724..7600e31 100644
--- a/wpimath/src/main/java/edu/wpi/first/wpiutil/math/MatrixUtils.java
+++ b/wpimath/src/main/java/edu/wpi/first/math/MatrixUtils.java
@@ -1,18 +1,13 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
-package edu.wpi.first.wpiutil.math;
+package edu.wpi.first.math;
+import edu.wpi.first.math.numbers.N1;
import java.util.Objects;
-
import org.ejml.simple.SimpleMatrix;
-import edu.wpi.first.wpiutil.math.numbers.N1;
-
@Deprecated
public final class MatrixUtils {
private MatrixUtils() {
@@ -31,7 +26,8 @@
@SuppressWarnings("LineLength")
public static <R extends Num, C extends Num> Matrix<R, C> zeros(Nat<R> rows, Nat<C> cols) {
return new Matrix<>(
- new SimpleMatrix(Objects.requireNonNull(rows).getNum(), Objects.requireNonNull(cols).getNum()));
+ new SimpleMatrix(
+ Objects.requireNonNull(rows).getNum(), Objects.requireNonNull(cols).getNum()));
}
/**
@@ -57,8 +53,8 @@
}
/**
- * Entrypoint to the MatBuilder class for creation
- * of custom matrices with the given dimensions and contents.
+ * Entrypoint to the MatBuilder class for creation of custom matrices with the given dimensions
+ * and contents.
*
* @param rows The number of rows of the desired matrix.
* @param cols The number of columns of the desired matrix.
@@ -71,8 +67,8 @@
}
/**
- * Entrypoint to the VecBuilder class for creation
- * of custom vectors with the given size and contents.
+ * Entrypoint to the VecBuilder class for creation of custom vectors with the given size and
+ * contents.
*
* @param dim The dimension of the vector.
* @param <D> The dimension of the vector as a generic.
diff --git a/wpimath/src/main/java/edu/wpi/first/math/Num.java b/wpimath/src/main/java/edu/wpi/first/math/Num.java
new file mode 100644
index 0000000..ef0fd2d
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/Num.java
@@ -0,0 +1,15 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math;
+
+/** A number expressed as a java class. */
+public abstract class Num {
+ /**
+ * The number this is backing.
+ *
+ * @return The number represented by this class.
+ */
+ public abstract int getNum();
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/math/Pair.java b/wpimath/src/main/java/edu/wpi/first/math/Pair.java
new file mode 100644
index 0000000..d1a68c7
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/Pair.java
@@ -0,0 +1,28 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math;
+
+public class Pair<A, B> {
+ private final A m_first;
+ private final B m_second;
+
+ public Pair(A first, B second) {
+ m_first = first;
+ m_second = second;
+ }
+
+ public A getFirst() {
+ return m_first;
+ }
+
+ public B getSecond() {
+ return m_second;
+ }
+
+ @SuppressWarnings("ParameterName")
+ public static <A, B> Pair<A, B> of(A a, B b) {
+ return new Pair<>(a, b);
+ }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpiutil/math/SimpleMatrixUtils.java b/wpimath/src/main/java/edu/wpi/first/math/SimpleMatrixUtils.java
similarity index 73%
rename from wpimath/src/main/java/edu/wpi/first/wpiutil/math/SimpleMatrixUtils.java
rename to wpimath/src/main/java/edu/wpi/first/math/SimpleMatrixUtils.java
index 3f281d2..fc78dd1 100644
--- a/wpimath/src/main/java/edu/wpi/first/wpiutil/math/SimpleMatrixUtils.java
+++ b/wpimath/src/main/java/edu/wpi/first/math/SimpleMatrixUtils.java
@@ -1,14 +1,10 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
-package edu.wpi.first.wpiutil.math;
+package edu.wpi.first.math;
import java.util.function.BiFunction;
-
import org.ejml.data.DMatrixRMaj;
import org.ejml.dense.row.NormOps_DDRM;
import org.ejml.dense.row.factory.DecompositionFactory_DDRM;
@@ -16,11 +12,8 @@
import org.ejml.simple.SimpleBase;
import org.ejml.simple.SimpleMatrix;
-import edu.wpi.first.math.WPIMathJNI;
-
public final class SimpleMatrixUtils {
- private SimpleMatrixUtils() {
- }
+ private SimpleMatrixUtils() {}
/**
* Compute the matrix exponential, e^M of the given matrix.
@@ -58,8 +51,11 @@
}
@SuppressWarnings({"LocalVariableName", "ParameterName", "LineLength"})
- private static SimpleMatrix dispatchPade(SimpleMatrix U, SimpleMatrix V,
- int nSquarings, BiFunction<SimpleMatrix, SimpleMatrix, SimpleMatrix> solveProvider) {
+ private static SimpleMatrix dispatchPade(
+ SimpleMatrix U,
+ SimpleMatrix V,
+ int nSquarings,
+ BiFunction<SimpleMatrix, SimpleMatrix, SimpleMatrix> solveProvider) {
SimpleMatrix P = U.plus(V);
SimpleMatrix Q = U.negative().plus(V);
@@ -74,7 +70,7 @@
@SuppressWarnings({"MethodName", "LocalVariableName", "ParameterName"})
private static Pair<SimpleMatrix, SimpleMatrix> _pade3(SimpleMatrix A) {
- double[] b = new double[]{120, 60, 12, 1};
+ double[] b = new double[] {120, 60, 12, 1};
SimpleMatrix ident = eye(A.numRows(), A.numCols());
SimpleMatrix A2 = A.mult(A);
@@ -85,7 +81,7 @@
@SuppressWarnings({"MethodName", "LocalVariableName", "ParameterName"})
private static Pair<SimpleMatrix, SimpleMatrix> _pade5(SimpleMatrix A) {
- double[] b = new double[]{30240, 15120, 3360, 420, 30, 1};
+ double[] b = new double[] {30240, 15120, 3360, 420, 30, 1};
SimpleMatrix ident = eye(A.numRows(), A.numCols());
SimpleMatrix A2 = A.mult(A);
SimpleMatrix A4 = A2.mult(A2);
@@ -98,24 +94,26 @@
@SuppressWarnings({"MethodName", "LocalVariableName", "LineLength", "ParameterName"})
private static Pair<SimpleMatrix, SimpleMatrix> _pade7(SimpleMatrix A) {
- double[] b = new double[]{17297280, 8648640, 1995840, 277200, 25200, 1512, 56, 1};
+ double[] b = new double[] {17297280, 8648640, 1995840, 277200, 25200, 1512, 56, 1};
SimpleMatrix ident = eye(A.numRows(), A.numCols());
SimpleMatrix A2 = A.mult(A);
SimpleMatrix A4 = A2.mult(A2);
SimpleMatrix A6 = A4.mult(A2);
SimpleMatrix U =
- A.mult(A6.scale(b[7]).plus(A4.scale(b[5])).plus(A2.scale(b[3])).plus(ident.scale(b[1])));
+ A.mult(A6.scale(b[7]).plus(A4.scale(b[5])).plus(A2.scale(b[3])).plus(ident.scale(b[1])));
SimpleMatrix V =
- A6.scale(b[6]).plus(A4.scale(b[4])).plus(A2.scale(b[2])).plus(ident.scale(b[0]));
+ A6.scale(b[6]).plus(A4.scale(b[4])).plus(A2.scale(b[2])).plus(ident.scale(b[0]));
return new Pair<>(U, V);
}
@SuppressWarnings({"MethodName", "LocalVariableName", "ParameterName", "LineLength"})
private static Pair<SimpleMatrix, SimpleMatrix> _pade9(SimpleMatrix A) {
- double[] b = new double[]{17643225600.0, 8821612800.0, 2075673600, 302702400, 30270240,
- 2162160, 110880, 3960, 90, 1};
+ double[] b =
+ new double[] {
+ 17643225600.0, 8821612800.0, 2075673600, 302702400, 30270240, 2162160, 110880, 3960, 90, 1
+ };
SimpleMatrix ident = eye(A.numRows(), A.numCols());
SimpleMatrix A2 = A.mult(A);
SimpleMatrix A4 = A2.mult(A2);
@@ -123,18 +121,41 @@
SimpleMatrix A8 = A6.mult(A2);
SimpleMatrix U =
- A.mult(A8.scale(b[9]).plus(A6.scale(b[7])).plus(A4.scale(b[5])).plus(A2.scale(b[3])).plus(ident.scale(b[1])));
+ A.mult(
+ A8.scale(b[9])
+ .plus(A6.scale(b[7]))
+ .plus(A4.scale(b[5]))
+ .plus(A2.scale(b[3]))
+ .plus(ident.scale(b[1])));
SimpleMatrix V =
- A8.scale(b[8]).plus(A6.scale(b[6])).plus(A4.scale(b[4])).plus(A2.scale(b[2])).plus(ident.scale(b[0]));
+ A8.scale(b[8])
+ .plus(A6.scale(b[6]))
+ .plus(A4.scale(b[4]))
+ .plus(A2.scale(b[2]))
+ .plus(ident.scale(b[0]));
return new Pair<>(U, V);
}
@SuppressWarnings({"MethodName", "LocalVariableName", "LineLength", "ParameterName"})
private static Pair<SimpleMatrix, SimpleMatrix> _pade13(SimpleMatrix A) {
- double[] b = new double[]{64764752532480000.0, 32382376266240000.0, 7771770303897600.0,
- 1187353796428800.0, 129060195264000.0, 10559470521600.0, 670442572800.0,
- 33522128640.0, 1323241920, 40840800, 960960, 16380, 182, 1};
+ double[] b =
+ new double[] {
+ 64764752532480000.0,
+ 32382376266240000.0,
+ 7771770303897600.0,
+ 1187353796428800.0,
+ 129060195264000.0,
+ 10559470521600.0,
+ 670442572800.0,
+ 33522128640.0,
+ 1323241920,
+ 40840800,
+ 960960,
+ 16380,
+ 182,
+ 1
+ };
SimpleMatrix ident = eye(A.numRows(), A.numCols());
SimpleMatrix A2 = A.mult(A);
@@ -142,9 +163,17 @@
SimpleMatrix A6 = A4.mult(A2);
SimpleMatrix U =
- A.mult(A6.scale(b[13]).plus(A4.scale(b[11])).plus(A2.scale(b[9])).plus(A6.scale(b[7])).plus(A4.scale(b[5])).plus(A2.scale(b[3])).plus(ident.scale(b[1])));
+ A.mult(
+ A6.scale(b[13])
+ .plus(A4.scale(b[11]))
+ .plus(A2.scale(b[9]))
+ .plus(A6.scale(b[7]))
+ .plus(A4.scale(b[5]))
+ .plus(A2.scale(b[3]))
+ .plus(ident.scale(b[1])));
SimpleMatrix V =
- A6.mult(A6.scale(b[12]).plus(A4.scale(b[10])).plus(A2.scale(b[8]))).plus(A6.scale(b[6]).plus(A4.scale(b[4])).plus(A2.scale(b[2])).plus(ident.scale(b[0])));
+ A6.mult(A6.scale(b[12]).plus(A4.scale(b[10])).plus(A2.scale(b[8])))
+ .plus(A6.scale(b[6]).plus(A4.scale(b[4])).plus(A2.scale(b[2])).plus(ident.scale(b[0])));
return new Pair<>(U, V);
}
@@ -171,7 +200,7 @@
* @param src The matrix to decompose.
* @return The decomposed matrix.
* @throws RuntimeException if the matrix could not be decomposed (ie. is not positive
- * semidefinite).
+ * semidefinite).
*/
public static SimpleMatrix lltDecompose(SimpleMatrix src) {
return lltDecompose(src, false);
@@ -185,14 +214,13 @@
* @param lowerTriangular if we want to decompose to the lower triangular Cholesky matrix.
* @return The decomposed matrix.
* @throws RuntimeException if the matrix could not be decomposed (ie. is not positive
- * semidefinite).
+ * semidefinite).
*/
- @SuppressWarnings("PMD.AvoidThrowingRawExceptionTypes")
public static SimpleMatrix lltDecompose(SimpleMatrix src, boolean lowerTriangular) {
SimpleMatrix temp = src.copy();
CholeskyDecomposition_F64<DMatrixRMaj> chol =
- DecompositionFactory_DDRM.chol(temp.numRows(), lowerTriangular);
+ DecompositionFactory_DDRM.chol(temp.numRows(), lowerTriangular);
if (!chol.decompose(temp.getMatrix())) {
// check that the input is not all zeros -- if they are, we special case and return all
// zeros.
@@ -218,11 +246,9 @@
* @return the exponential of A.
*/
@SuppressWarnings("ParameterName")
- public static SimpleMatrix exp(
- SimpleMatrix A) {
+ public static SimpleMatrix exp(SimpleMatrix A) {
SimpleMatrix toReturn = new SimpleMatrix(A.numRows(), A.numRows());
WPIMathJNI.exp(A.getDDRM().getData(), A.numRows(), toReturn.getDDRM().getData());
return toReturn;
}
-
}
diff --git a/wpimath/src/main/java/edu/wpi/first/math/StateSpaceUtil.java b/wpimath/src/main/java/edu/wpi/first/math/StateSpaceUtil.java
new file mode 100644
index 0000000..8baf401
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/StateSpaceUtil.java
@@ -0,0 +1,202 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math;
+
+import edu.wpi.first.math.geometry.Pose2d;
+import edu.wpi.first.math.numbers.N1;
+import edu.wpi.first.math.numbers.N3;
+import edu.wpi.first.math.numbers.N4;
+import java.util.Random;
+import org.ejml.simple.SimpleMatrix;
+
+@SuppressWarnings("ParameterName")
+public final class StateSpaceUtil {
+ private static Random rand = new Random();
+
+ private StateSpaceUtil() {
+ // Utility class
+ }
+
+ /**
+ * Creates a covariance matrix from the given vector for use with Kalman filters.
+ *
+ * <p>Each element is squared and placed on the covariance matrix diagonal.
+ *
+ * @param <States> Num representing the states of the system.
+ * @param states A Nat representing the states of the system.
+ * @param stdDevs For a Q matrix, its elements are the standard deviations of each state from how
+ * the model behaves. For an R matrix, its elements are the standard deviations for each
+ * output measurement.
+ * @return Process noise or measurement noise covariance matrix.
+ */
+ @SuppressWarnings("MethodTypeParameterName")
+ public static <States extends Num> Matrix<States, States> makeCovarianceMatrix(
+ Nat<States> states, Matrix<States, N1> stdDevs) {
+ var result = new Matrix<>(states, states);
+ for (int i = 0; i < states.getNum(); i++) {
+ result.set(i, i, Math.pow(stdDevs.get(i, 0), 2));
+ }
+ return result;
+ }
+
+ /**
+ * Creates a vector of normally distributed white noise with the given noise intensities for each
+ * element.
+ *
+ * @param <N> Num representing the dimensionality of the noise vector to create.
+ * @param stdDevs A matrix whose elements are the standard deviations of each element of the noise
+ * vector.
+ * @return White noise vector.
+ */
+ public static <N extends Num> Matrix<N, N1> makeWhiteNoiseVector(Matrix<N, N1> stdDevs) {
+ Matrix<N, N1> result = new Matrix<>(new SimpleMatrix(stdDevs.getNumRows(), 1));
+ for (int i = 0; i < stdDevs.getNumRows(); i++) {
+ result.set(i, 0, rand.nextGaussian() * stdDevs.get(i, 0));
+ }
+ return result;
+ }
+
+ /**
+ * Creates a cost matrix from the given vector for use with LQR.
+ *
+ * <p>The cost matrix is constructed using Bryson's rule. The inverse square of each element in
+ * the input is taken and placed on the cost matrix diagonal.
+ *
+ * @param <States> Num representing the states of the system.
+ * @param costs An array. For a Q matrix, its elements are the maximum allowed excursions of the
+ * states from the reference. For an R matrix, its elements are the maximum allowed excursions
+ * of the control inputs from no actuation.
+ * @return State excursion or control effort cost matrix.
+ */
+ @SuppressWarnings("MethodTypeParameterName")
+ public static <States extends Num> Matrix<States, States> makeCostMatrix(
+ Matrix<States, N1> costs) {
+ Matrix<States, States> result =
+ new Matrix<>(new SimpleMatrix(costs.getNumRows(), costs.getNumRows()));
+ result.fill(0.0);
+
+ for (int i = 0; i < costs.getNumRows(); i++) {
+ result.set(i, i, 1.0 / (Math.pow(costs.get(i, 0), 2)));
+ }
+
+ return result;
+ }
+
+ /**
+ * Returns true if (A, B) is a stabilizable pair.
+ *
+ * <p>(A, B) is stabilizable if and only if the uncontrollable eigenvalues of A, if any, have
+ * absolute values less than one, where an eigenvalue is uncontrollable if rank(λI - A, B) &lt; n
+ * where n is the number of states.
+ *
+ * @param <States> Num representing the size of A.
+ * @param <Inputs> Num representing the columns of B.
+ * @param A System matrix.
+ * @param B Input matrix.
+ * @return If the system is stabilizable.
+ */
+ @SuppressWarnings("MethodTypeParameterName")
+ public static <States extends Num, Inputs extends Num> boolean isStabilizable(
+ Matrix<States, States> A, Matrix<States, Inputs> B) {
+ return WPIMathJNI.isStabilizable(A.getNumRows(), B.getNumCols(), A.getData(), B.getData());
+ }
+
+ /**
+ * Returns true if (A, C) is a detectable pair.
+ *
+ * <p>(A, C) is detectable if and only if the unobservable eigenvalues of A, if any, have absolute
+ * values less than one, where an eigenvalue is unobservable if rank(λI - A; C) &lt; n where n is
+ * the number of states.
+ *
+ * @param <States> Num representing the size of A.
+ * @param <Outputs> Num representing the rows of C.
+ * @param A System matrix.
+ * @param C Output matrix.
+ * @return If the system is detectable.
+ */
+ @SuppressWarnings("MethodTypeParameterName")
+ public static <States extends Num, Outputs extends Num> boolean isDetectable(
+ Matrix<States, States> A, Matrix<Outputs, States> C) {
+ return WPIMathJNI.isStabilizable(
+ A.getNumRows(), C.getNumRows(), A.transpose().getData(), C.transpose().getData());
+ }
+
+ /**
+ * Convert a {@link Pose2d} to a vector of [x, y, theta], where theta is in radians.
+ *
+ * @param pose A pose to convert to a vector.
+ * @return The given pose in vector form, with the third element, theta, in radians.
+ */
+ public static Matrix<N3, N1> poseToVector(Pose2d pose) {
+ return VecBuilder.fill(pose.getX(), pose.getY(), pose.getRotation().getRadians());
+ }
+
+ /**
+ * Clamp the input u to the min and max.
+ *
+ * @param u The input to clamp.
+ * @param umin The minimum value of each input element.
+ * @param umax The maximum value of each input element.
+ * @param <I> The number of inputs.
+ * @return The clamped input.
+ */
+ @SuppressWarnings({"ParameterName", "LocalVariableName"})
+ public static <I extends Num> Matrix<I, N1> clampInputMaxMagnitude(
+ Matrix<I, N1> u, Matrix<I, N1> umin, Matrix<I, N1> umax) {
+ var result = new Matrix<I, N1>(new SimpleMatrix(u.getNumRows(), 1));
+ for (int i = 0; i < u.getNumRows(); i++) {
+ result.set(i, 0, MathUtil.clamp(u.get(i, 0), umin.get(i, 0), umax.get(i, 0)));
+ }
+ return result;
+ }
+
+ /**
+ * Normalize all inputs if any exceeds the maximum magnitude. Useful for systems such as
+ * differential drivetrains.
+ *
+ * @param u The input vector.
+ * @param maxMagnitude The maximum magnitude any input can have.
+ * @param <I> The number of inputs.
+ * @return The normalized input vector.
+ */
+ public static <I extends Num> Matrix<I, N1> normalizeInputVector(
+ Matrix<I, N1> u, double maxMagnitude) {
+ double maxValue = u.maxAbs();
+ boolean isCapped = maxValue > maxMagnitude;
+
+ if (isCapped) {
+ return u.times(maxMagnitude / maxValue);
+ }
+ return u;
+ }
+
+ /**
+ * Convert a {@link Pose2d} to a vector of [x, y, cos(theta), sin(theta)], where theta is in
+ * radians.
+ *
+ * @param pose A pose to convert to a vector.
+ * @return The given pose as a 4x1 vector of x, y, cos(theta), and sin(theta).
+ */
+ public static Matrix<N4, N1> poseTo4dVector(Pose2d pose) {
+ return VecBuilder.fill(
+ pose.getTranslation().getX(),
+ pose.getTranslation().getY(),
+ pose.getRotation().getCos(),
+ pose.getRotation().getSin());
+ }
+
+ /**
+ * Convert a {@link Pose2d} to a vector of [x, y, theta], where theta is in radians.
+ *
+ * @param pose A pose to convert to a vector.
+ * @return The given pose in vector form, with the third element, theta, in radians.
+ */
+ public static Matrix<N3, N1> poseTo3dVector(Pose2d pose) {
+ return VecBuilder.fill(
+ pose.getTranslation().getX(),
+ pose.getTranslation().getY(),
+ pose.getRotation().getRadians());
+ }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpiutil/math/VecBuilder.java b/wpimath/src/main/java/edu/wpi/first/math/VecBuilder.java
similarity index 70%
rename from wpimath/src/main/java/edu/wpi/first/wpiutil/math/VecBuilder.java
rename to wpimath/src/main/java/edu/wpi/first/math/VecBuilder.java
index 98200b7..670611a 100644
--- a/wpimath/src/main/java/edu/wpi/first/wpiutil/math/VecBuilder.java
+++ b/wpimath/src/main/java/edu/wpi/first/math/VecBuilder.java
@@ -1,24 +1,19 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
-package edu.wpi.first.wpiutil.math;
+package edu.wpi.first.math;
-import edu.wpi.first.wpiutil.math.numbers.N1;
-import edu.wpi.first.wpiutil.math.numbers.N10;
-import edu.wpi.first.wpiutil.math.numbers.N2;
-import edu.wpi.first.wpiutil.math.numbers.N3;
-import edu.wpi.first.wpiutil.math.numbers.N4;
-import edu.wpi.first.wpiutil.math.numbers.N5;
-import edu.wpi.first.wpiutil.math.numbers.N6;
-import edu.wpi.first.wpiutil.math.numbers.N7;
-import edu.wpi.first.wpiutil.math.numbers.N8;
-import edu.wpi.first.wpiutil.math.numbers.N9;
-
-
+import edu.wpi.first.math.numbers.N1;
+import edu.wpi.first.math.numbers.N10;
+import edu.wpi.first.math.numbers.N2;
+import edu.wpi.first.math.numbers.N3;
+import edu.wpi.first.math.numbers.N4;
+import edu.wpi.first.math.numbers.N5;
+import edu.wpi.first.math.numbers.N6;
+import edu.wpi.first.math.numbers.N7;
+import edu.wpi.first.math.numbers.N8;
+import edu.wpi.first.math.numbers.N9;
/**
* A specialization of {@link MatBuilder} for constructing vectors (Nx1 matrices).
@@ -38,6 +33,7 @@
* Returns a 1x1 vector containing the given elements.
*
* @param n1 the first element.
+ * @return 1x1 vector
*/
public static Vector<N1> fill(double n1) {
return new VecBuilder<>(Nat.N1()).fillVec(n1);
@@ -48,6 +44,7 @@
*
* @param n1 the first element.
* @param n2 the second element.
+ * @return 2x1 vector
*/
public static Vector<N2> fill(double n1, double n2) {
return new VecBuilder<>(Nat.N2()).fillVec(n1, n2);
@@ -59,6 +56,7 @@
* @param n1 the first element.
* @param n2 the second element.
* @param n3 the third element.
+ * @return 3x1 vector
*/
public static Vector<N3> fill(double n1, double n2, double n3) {
return new VecBuilder<>(Nat.N3()).fillVec(n1, n2, n3);
@@ -71,6 +69,7 @@
* @param n2 the second element.
* @param n3 the third element.
* @param n4 the fourth element.
+ * @return 4x1 vector
*/
public static Vector<N4> fill(double n1, double n2, double n3, double n4) {
return new VecBuilder<>(Nat.N4()).fillVec(n1, n2, n3, n4);
@@ -84,6 +83,7 @@
* @param n3 the third element.
* @param n4 the fourth element.
* @param n5 the fifth element.
+ * @return 5x1 vector
*/
public static Vector<N5> fill(double n1, double n2, double n3, double n4, double n5) {
return new VecBuilder<>(Nat.N5()).fillVec(n1, n2, n3, n4, n5);
@@ -98,9 +98,9 @@
* @param n4 the fourth element.
* @param n5 the fifth element.
* @param n6 the sixth element.
+ * @return 6x1 vector
*/
- public static Vector<N6> fill(double n1, double n2, double n3, double n4, double n5,
- double n6) {
+ public static Vector<N6> fill(double n1, double n2, double n3, double n4, double n5, double n6) {
return new VecBuilder<>(Nat.N6()).fillVec(n1, n2, n3, n4, n5, n6);
}
@@ -114,9 +114,10 @@
* @param n5 the fifth element.
* @param n6 the sixth element.
* @param n7 the seventh element.
+ * @return 7x1 vector
*/
- public static Vector<N7> fill(double n1, double n2, double n3, double n4, double n5,
- double n6, double n7) {
+ public static Vector<N7> fill(
+ double n1, double n2, double n3, double n4, double n5, double n6, double n7) {
return new VecBuilder<>(Nat.N7()).fillVec(n1, n2, n3, n4, n5, n6, n7);
}
@@ -131,9 +132,10 @@
* @param n6 the sixth element.
* @param n7 the seventh element.
* @param n8 the eighth element.
+ * @return 8x1 vector
*/
- public static Vector<N8> fill(double n1, double n2, double n3, double n4, double n5,
- double n6, double n7, double n8) {
+ public static Vector<N8> fill(
+ double n1, double n2, double n3, double n4, double n5, double n6, double n7, double n8) {
return new VecBuilder<>(Nat.N8()).fillVec(n1, n2, n3, n4, n5, n6, n7, n8);
}
@@ -149,9 +151,18 @@
* @param n7 the seventh element.
* @param n8 the eighth element.
* @param n9 the ninth element.
+ * @return 9x1 vector
*/
- public static Vector<N9> fill(double n1, double n2, double n3, double n4, double n5,
- double n6, double n7, double n8, double n9) {
+ public static Vector<N9> fill(
+ double n1,
+ double n2,
+ double n3,
+ double n4,
+ double n5,
+ double n6,
+ double n7,
+ double n8,
+ double n9) {
return new VecBuilder<>(Nat.N9()).fillVec(n1, n2, n3, n4, n5, n6, n7, n8, n9);
}
@@ -168,10 +179,19 @@
* @param n8 the eighth element.
* @param n9 the ninth element.
* @param n10 the tenth element.
+ * @return 10x1 vector
*/
- @SuppressWarnings("PMD.ExcessiveParameterList")
- public static Vector<N10> fill(double n1, double n2, double n3, double n4, double n5,
- double n6, double n7, double n8, double n9, double n10) {
+ public static Vector<N10> fill(
+ double n1,
+ double n2,
+ double n3,
+ double n4,
+ double n5,
+ double n6,
+ double n7,
+ double n8,
+ double n9,
+ double n10) {
return new VecBuilder<>(Nat.N10()).fillVec(n1, n2, n3, n4, n5, n6, n7, n8, n9, n10);
}
}
diff --git a/wpimath/src/main/java/edu/wpi/first/math/Vector.java b/wpimath/src/main/java/edu/wpi/first/math/Vector.java
new file mode 100644
index 0000000..9b06e71
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/Vector.java
@@ -0,0 +1,64 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math;
+
+import edu.wpi.first.math.numbers.N1;
+import org.ejml.simple.SimpleMatrix;
+
+/**
+ * A shape-safe wrapper over Efficient Java Matrix Library (EJML) matrices.
+ *
+ * <p>This class is intended to be used alongside the state space library.
+ *
+ * @param <R> The number of rows in this matrix.
+ */
+public class Vector<R extends Num> extends Matrix<R, N1> {
+ /**
+ * Constructs an empty zero vector of the given dimensions.
+ *
+ * @param rows The number of rows of the vector.
+ */
+ public Vector(Nat<R> rows) {
+ super(rows, Nat.N1());
+ }
+
+ /**
+ * Constructs a new {@link Vector} with the given storage. Caller should make sure that the
+ * provided generic bounds match the shape of the provided {@link Vector}.
+ *
+ * <p>NOTE: It is not recommended to use this constructor unless the {@link SimpleMatrix} API is
+ * absolutely necessary due to the desired function not being accessible through the {@link
+ * Vector} wrapper.
+ *
+ * @param storage The {@link SimpleMatrix} to back this vector.
+ */
+ public Vector(SimpleMatrix storage) {
+ super(storage);
+ }
+
+ /**
+ * Constructs a new vector with the storage of the supplied matrix.
+ *
+ * @param other The {@link Vector} to copy the storage of.
+ */
+ public Vector(Matrix<R, N1> other) {
+ super(other);
+ }
+
+ @Override
+ public Vector<R> times(double value) {
+ return new Vector<>(this.m_storage.scale(value));
+ }
+
+ @Override
+ public Vector<R> div(int value) {
+ return new Vector<>(this.m_storage.divide(value));
+ }
+
+ @Override
+ public Vector<R> div(double value) {
+ return new Vector<>(this.m_storage.divide(value));
+ }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/math/WPIMathJNI.java b/wpimath/src/main/java/edu/wpi/first/math/WPIMathJNI.java
index 30984ac..54445d3 100644
--- a/wpimath/src/main/java/edu/wpi/first/math/WPIMathJNI.java
+++ b/wpimath/src/main/java/edu/wpi/first/math/WPIMathJNI.java
@@ -1,17 +1,13 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
package edu.wpi.first.math;
+import edu.wpi.first.util.RuntimeLoader;
import java.io.IOException;
import java.util.concurrent.atomic.AtomicBoolean;
-import edu.wpi.first.wpiutil.RuntimeLoader;
-
public final class WPIMathJNI {
static boolean libraryLoaded = false;
static RuntimeLoader<WPIMathJNI> loader = null;
@@ -19,8 +15,9 @@
static {
if (Helper.getExtractOnStaticLoad()) {
try {
- loader = new RuntimeLoader<>("wpimathjni", RuntimeLoader.getDefaultExtractionRoot(),
- WPIMathJNI.class);
+ loader =
+ new RuntimeLoader<>(
+ "wpimathjni", RuntimeLoader.getDefaultExtractionRoot(), WPIMathJNI.class);
loader.loadLibrary();
} catch (IOException ex) {
ex.printStackTrace();
@@ -39,8 +36,9 @@
if (libraryLoaded) {
return;
}
- loader = new RuntimeLoader<>("wpimathjni", RuntimeLoader.getDefaultExtractionRoot(),
- WPIMathJNI.class);
+ loader =
+ new RuntimeLoader<>(
+ "wpimathjni", RuntimeLoader.getDefaultExtractionRoot(), WPIMathJNI.class);
loader.loadLibrary();
libraryLoaded = true;
}
@@ -48,47 +46,85 @@
/**
* Solves the discrete alegebraic Riccati equation.
*
- * @param A Array containing elements of A in row-major order.
- * @param B Array containing elements of B in row-major order.
- * @param Q Array containing elements of Q in row-major order.
- * @param R Array containing elements of R in row-major order.
+ * @param A Array containing elements of A in row-major order.
+ * @param B Array containing elements of B in row-major order.
+ * @param Q Array containing elements of Q in row-major order.
+ * @param R Array containing elements of R in row-major order.
* @param states Number of states in A matrix.
* @param inputs Number of inputs in B matrix.
- * @param S Array storage for DARE solution.
+ * @param S Array storage for DARE solution.
*/
public static native void discreteAlgebraicRiccatiEquation(
- double[] A,
- double[] B,
- double[] Q,
- double[] R,
- int states,
- int inputs,
- double[] S);
+ double[] A, double[] B, double[] Q, double[] R, int states, int inputs, double[] S);
/**
* Computes the matrix exp.
*
- * @param src Array of elements of the matrix to be exponentiated.
- * @param rows how many rows there are.
- * @param dst Array where the result will be stored.
+ * @param src Array of elements of the matrix to be exponentiated.
+ * @param rows How many rows there are.
+ * @param dst Array where the result will be stored.
*/
public static native void exp(double[] src, int rows, double[] dst);
/**
+ * Computes the matrix pow.
+ *
+ * @param src Array of elements of the matrix to be raised to a power.
+ * @param rows How many rows there are.
+ * @param exponent The exponent.
+ * @param dst Array where the result will be stored.
+ */
+ public static native void pow(double[] src, int rows, double exponent, double[] dst);
+
+ /**
* Returns true if (A, B) is a stabilizable pair.
*
- * <p>(A,B) is stabilizable if and only if the uncontrollable eigenvalues of A, if
- * any, have absolute values less than one, where an eigenvalue is
- * uncontrollable if rank(lambda * I - A, B) < n where n is number of states.
+ * <p>(A, B) is stabilizable if and only if the uncontrollable eigenvalues of A, if any, have
+ * absolute values less than one, where an eigenvalue is uncontrollable if rank(lambda * I - A, B)
+ * &lt; n where n is the number of states.
*
* @param states the number of states of the system.
* @param inputs the number of inputs to the system.
- * @param A System matrix.
- * @param B Input matrix.
+ * @param A System matrix.
+ * @param B Input matrix.
* @return If the system is stabilizable.
*/
public static native boolean isStabilizable(int states, int inputs, double[] A, double[] B);
+ /**
+ * Loads a Pathweaver JSON.
+ *
+ * @param path The path to the JSON.
+ * @return A double array with the trajectory states from the JSON.
+ * @throws IOException if the JSON could not be read.
+ */
+ public static native double[] fromPathweaverJson(String path) throws IOException;
+
+ /**
+ * Converts a trajectory into a Pathweaver JSON and saves it.
+ *
+ * @param elements The elements of the trajectory.
+ * @param path The location to save the JSON to.
+ * @throws IOException if the JSON could not be written.
+ */
+ public static native void toPathweaverJson(double[] elements, String path) throws IOException;
+
+ /**
+ * Deserializes a trajectory JSON into a double[] of trajectory elements.
+ *
+ * @param json The JSON containing the serialized trajectory.
+ * @return A double array with the trajectory states.
+ */
+ public static native double[] deserializeTrajectory(String json);
+
+ /**
+ * Serializes the trajectory into a JSON string.
+ *
+ * @param elements The elements of the trajectory.
+ * @return A JSON containing the serialized trajectory.
+ */
+ public static native String serializeTrajectory(double[] elements);
+
public static class Helper {
private static AtomicBoolean extractOnStaticLoad = new AtomicBoolean(true);
diff --git a/wpimath/src/main/java/edu/wpi/first/math/controller/ArmFeedforward.java b/wpimath/src/main/java/edu/wpi/first/math/controller/ArmFeedforward.java
new file mode 100644
index 0000000..2991511
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/controller/ArmFeedforward.java
@@ -0,0 +1,138 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.controller;
+
+/**
+ * A helper class that computes feedforward outputs for a simple arm (modeled as a motor acting
+ * against the force of gravity on a beam suspended at an angle).
+ */
+@SuppressWarnings("MemberName")
+public class ArmFeedforward {
+  public final double ks; // Static gain, applied with the sign of the velocity.
+  public final double kcos; // Gravity gain, multiplied by cos(angle).
+  public final double kv; // Velocity gain.
+  public final double ka; // Acceleration gain.
+
+  /**
+   * Creates a new ArmFeedforward with the specified gains. Units of the gain values will dictate
+   * units of the computed feedforward.
+   *
+   * @param ks The static gain.
+   * @param kcos The gravity gain.
+   * @param kv The velocity gain.
+   * @param ka The acceleration gain.
+   */
+  public ArmFeedforward(double ks, double kcos, double kv, double ka) {
+    this.ks = ks;
+    this.kcos = kcos;
+    this.kv = kv;
+    this.ka = ka;
+  }
+
+  /**
+   * Creates a new ArmFeedforward with the specified gains. Acceleration gain is defaulted to zero.
+   * Units of the gain values will dictate units of the computed feedforward.
+   *
+   * @param ks The static gain.
+   * @param kcos The gravity gain.
+   * @param kv The velocity gain.
+   */
+  public ArmFeedforward(double ks, double kcos, double kv) {
+    this(ks, kcos, kv, 0);
+  }
+
+  /**
+   * Calculates the feedforward as ks * signum(v) + kcos * cos(angle) + kv * v + ka * a.
+   *
+   * @param positionRadians The position (angle) setpoint, in radians.
+   * @param velocityRadPerSec The velocity setpoint, in radians per second.
+   * @param accelRadPerSecSquared The acceleration setpoint, in radians per second squared.
+   * @return The computed feedforward.
+   */
+  public double calculate(
+      double positionRadians, double velocityRadPerSec, double accelRadPerSecSquared) {
+    return ks * Math.signum(velocityRadPerSec)
+        + kcos * Math.cos(positionRadians)
+        + kv * velocityRadPerSec
+        + ka * accelRadPerSecSquared;
+  }
+
+  /**
+   * Calculates the feedforward from the gains and velocity setpoint (acceleration is assumed to be
+   * zero).
+   *
+   * @param positionRadians The position (angle) setpoint, in radians.
+   * @param velocity The velocity setpoint, in radians per second.
+   * @return The computed feedforward.
+   */
+  public double calculate(double positionRadians, double velocity) {
+    return calculate(positionRadians, velocity, 0);
+  }
+
+  // Rearranging the main equation from calculate() and solving for velocity or acceleration
+  // yields the formulas for the methods below:
+
+  /**
+   * Calculates the maximum achievable velocity given a maximum voltage supply, a position, and an
+   * acceleration. Useful for ensuring that velocity and acceleration constraints for a trapezoidal
+   * profile are simultaneously achievable - enter the acceleration constraint, and this will give
+   * you a simultaneously-achievable velocity constraint. The result is non-finite if kv is zero.
+   *
+   * @param maxVoltage The maximum voltage that can be supplied to the arm.
+   * @param angle The angle of the arm, in radians.
+   * @param acceleration The acceleration of the arm.
+   * @return The maximum possible velocity at the given acceleration and angle.
+   */
+  public double maxAchievableVelocity(double maxVoltage, double angle, double acceleration) {
+    // Assume max velocity is positive, i.e. signum(velocity) = 1
+    return (maxVoltage - ks - Math.cos(angle) * kcos - acceleration * ka) / kv;
+  }
+
+  /**
+   * Calculates the minimum achievable velocity given a maximum voltage supply, a position, and an
+   * acceleration. Useful for ensuring that velocity and acceleration constraints for a trapezoidal
+   * profile are simultaneously achievable - enter the acceleration constraint, and this will give
+   * you a simultaneously-achievable velocity constraint. The result is non-finite if kv is zero.
+   *
+   * @param maxVoltage The maximum voltage that can be supplied to the arm.
+   * @param angle The angle of the arm, in radians.
+   * @param acceleration The acceleration of the arm.
+   * @return The minimum possible velocity at the given acceleration and angle.
+   */
+  public double minAchievableVelocity(double maxVoltage, double angle, double acceleration) {
+    // Assume min velocity is negative, i.e. signum(velocity) = -1, so ks flips sign
+    return (-maxVoltage + ks - Math.cos(angle) * kcos - acceleration * ka) / kv;
+  }
+
+  /**
+   * Calculates the maximum achievable acceleration given a maximum voltage supply, a position, and
+   * a velocity. Useful for ensuring that velocity and acceleration constraints for a trapezoidal
+   * profile are simultaneously achievable - enter the velocity constraint, and this will give you a
+   * simultaneously-achievable acceleration constraint. The result is non-finite if ka is zero.
+   *
+   * @param maxVoltage The maximum voltage that can be supplied to the arm.
+   * @param angle The angle of the arm, in radians.
+   * @param velocity The velocity of the arm.
+   * @return The maximum possible acceleration at the given velocity and angle.
+   */
+  public double maxAchievableAcceleration(double maxVoltage, double angle, double velocity) {
+    return (maxVoltage - ks * Math.signum(velocity) - Math.cos(angle) * kcos - velocity * kv) / ka;
+  }
+
+  /**
+   * Calculates the minimum achievable acceleration given a maximum voltage supply, a position, and
+   * a velocity. Useful for ensuring that velocity and acceleration constraints for a trapezoidal
+   * profile are simultaneously achievable - enter the velocity constraint, and this will give you a
+   * simultaneously-achievable acceleration constraint. The result is non-finite if ka is zero.
+   *
+   * @param maxVoltage The maximum voltage that can be supplied to the arm.
+   * @param angle The angle of the arm, in radians.
+   * @param velocity The velocity of the arm.
+   * @return The minimum possible acceleration at the given velocity and angle.
+   */
+  public double minAchievableAcceleration(double maxVoltage, double angle, double velocity) {
+    return maxAchievableAcceleration(-maxVoltage, angle, velocity);
+  }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/math/controller/ControlAffinePlantInversionFeedforward.java b/wpimath/src/main/java/edu/wpi/first/math/controller/ControlAffinePlantInversionFeedforward.java
new file mode 100644
index 0000000..d4beabd
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/controller/ControlAffinePlantInversionFeedforward.java
@@ -0,0 +1,193 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.controller;
+
+import edu.wpi.first.math.Matrix;
+import edu.wpi.first.math.Nat;
+import edu.wpi.first.math.Num;
+import edu.wpi.first.math.numbers.N1;
+import edu.wpi.first.math.system.NumericalJacobian;
+import java.util.function.BiFunction;
+import java.util.function.Function;
+
+/**
+ * Constructs a control-affine plant inversion model-based feedforward from given model dynamics.
+ *
+ * <p>If given the vector valued function as f(x, u) where x is the state vector and u is the input
+ * vector, the B matrix (continuous input matrix) is calculated through a {@link
+ * edu.wpi.first.math.system.NumericalJacobian} evaluated at x = 0, u = 0. In this case f must
+ * be control-affine (of the form f(x) + Bu).
+ *
+ * <p>The feedforward is calculated as <strong> u_ff = B<sup>+</sup> (rDot - f(r))</strong>, where
+ * <strong> B<sup>+</sup> </strong> is the pseudoinverse of B and r is the current reference.
+ *
+ * <p>This feedforward does not account for a dynamic B matrix, B is either determined or supplied
+ * when the feedforward is created and remains constant.
+ *
+ * <p>For more on the underlying math, read
+ * https://file.tavsys.net/control/controls-engineering-in-frc.pdf.
+ */
+@SuppressWarnings({"ParameterName", "LocalVariableName", "MemberName", "ClassTypeParameterName"})
+public class ControlAffinePlantInversionFeedforward<States extends Num, Inputs extends Num> {
+  /** The current reference state. */
+  @SuppressWarnings("MemberName")
+  private Matrix<States, N1> m_r;
+
+  /** The computed feedforward. */
+  private Matrix<Inputs, N1> m_uff;
+
+  @SuppressWarnings("MemberName")
+  private final Matrix<States, Inputs> m_B; // Continuous input matrix, fixed at construction.
+
+  private final Nat<Inputs> m_inputs; // Input dimension; sizes the zero input passed to f.
+
+  private final double m_dt; // Timestep between calls of calculate(), in seconds.
+
+  /** The model dynamics. */
+  private final BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> m_f;
+
+  /**
+   * Constructs a feedforward from model dynamics f(x, u); B is computed numerically from f.
+   *
+   * @param states A {@link Nat} representing the number of states.
+   * @param inputs A {@link Nat} representing the number of inputs.
+   * @param f A vector-valued function of x, the state, and u, the input, that returns the
+   *     derivative of the state vector. HAS to be control-affine (of the form f(x) + Bu).
+   * @param dtSeconds The timestep between calls of calculate(), in seconds.
+   */
+  public ControlAffinePlantInversionFeedforward(
+      Nat<States> states,
+      Nat<Inputs> inputs,
+      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
+      double dtSeconds) {
+    this.m_dt = dtSeconds;
+    this.m_f = f;
+    this.m_inputs = inputs;
+
+    this.m_B =
+        NumericalJacobian.numericalJacobianU(
+            states, inputs, m_f, new Matrix<>(states, Nat.N1()), new Matrix<>(inputs, Nat.N1()));
+
+    m_r = new Matrix<>(states, Nat.N1());
+    m_uff = new Matrix<>(inputs, Nat.N1());
+
+    reset();
+  }
+
+  /**
+   * Constructs a feedforward with given model dynamics as a function of state, and the plant's
+   * B (continuous input matrix) matrix.
+   *
+   * @param states A {@link Nat} representing the number of states.
+   * @param inputs A {@link Nat} representing the number of inputs.
+   * @param f A vector-valued function of x, the state, that returns the derivative of the state
+   *     vector.
+   * @param B Continuous input matrix of the plant being controlled.
+   * @param dtSeconds The timestep between calls of calculate(), in seconds.
+   */
+  public ControlAffinePlantInversionFeedforward(
+      Nat<States> states,
+      Nat<Inputs> inputs,
+      Function<Matrix<States, N1>, Matrix<States, N1>> f,
+      Matrix<States, Inputs> B,
+      double dtSeconds) {
+    this.m_dt = dtSeconds;
+    this.m_inputs = inputs;
+
+    this.m_f = (x, u) -> f.apply(x);
+    this.m_B = B;
+
+    m_r = new Matrix<>(states, Nat.N1());
+    m_uff = new Matrix<>(inputs, Nat.N1());
+
+    reset();
+  }
+
+  /**
+   * Returns the previously calculated feedforward as an input vector.
+   *
+   * @return The calculated feedforward.
+   */
+  public Matrix<Inputs, N1> getUff() {
+    return m_uff;
+  }
+
+  /**
+   * Returns an element of the previously calculated feedforward.
+   *
+   * @param row Row of uff.
+   * @return The row of the calculated feedforward.
+   */
+  public double getUff(int row) {
+    return m_uff.get(row, 0);
+  }
+
+  /**
+   * Returns the current reference vector r.
+   *
+   * @return The current reference vector.
+   */
+  public Matrix<States, N1> getR() {
+    return m_r;
+  }
+
+  /**
+   * Returns an element of the current reference vector r.
+   *
+   * @param row Row of r.
+   * @return The row of the current reference vector.
+   */
+  public double getR(int row) {
+    return m_r.get(row, 0);
+  }
+
+  /**
+   * Resets the feedforward with a specified initial state vector.
+   *
+   * @param initialState The initial state vector (stored by reference, not copied).
+   */
+  public void reset(Matrix<States, N1> initialState) {
+    m_r = initialState;
+    m_uff.fill(0.0);
+  }
+
+  /** Resets the feedforward with a zero initial state vector (zeroes the stored r in place). */
+  public void reset() {
+    m_r.fill(0.0);
+    m_uff.fill(0.0);
+  }
+
+  /**
+   * Calculate the feedforward with only the desired future reference. This uses the internally
+   * stored "current" reference.
+   *
+   * <p>If this method is used, the initial state of the system is the one set using {@link
+   * #reset(Matrix)}. If the initial state is not set, it defaults to a zero vector. The stored
+   * reference is then replaced by {@code nextR}.
+   *
+   * @param nextR The reference state of the future timestep (k + dt).
+   * @return The calculated feedforward.
+   */
+  public Matrix<Inputs, N1> calculate(Matrix<States, N1> nextR) {
+    return calculate(m_r, nextR);
+  }
+
+  /**
+   * Calculate the feedforward with current and future reference vectors.
+   *
+   * @param r The reference state of the current timestep (k).
+   * @param nextR The reference state of the future timestep (k + dt).
+   * @return The calculated feedforward.
+   */
+  @SuppressWarnings({"ParameterName", "LocalVariableName"})
+  public Matrix<Inputs, N1> calculate(Matrix<States, N1> r, Matrix<States, N1> nextR) {
+    var rDot = (nextR.minus(r)).div(m_dt); // Finite-difference estimate of dr/dt.
+    // Solve B * u_ff = rDot - f(r, 0) for u_ff; f is evaluated with a zero input vector.
+    m_uff = m_B.solve(rDot.minus(m_f.apply(r, new Matrix<>(m_inputs, Nat.N1()))));
+
+    m_r = nextR; // Becomes the "current" reference for the next single-argument call.
+    return m_uff;
+  }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/controller/ElevatorFeedforward.java b/wpimath/src/main/java/edu/wpi/first/math/controller/ElevatorFeedforward.java
similarity index 62%
rename from wpimath/src/main/java/edu/wpi/first/wpilibj/controller/ElevatorFeedforward.java
rename to wpimath/src/main/java/edu/wpi/first/math/controller/ElevatorFeedforward.java
index 0b52c14..248015f 100644
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/controller/ElevatorFeedforward.java
+++ b/wpimath/src/main/java/edu/wpi/first/math/controller/ElevatorFeedforward.java
@@ -1,15 +1,12 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
-package edu.wpi.first.wpilibj.controller;
+package edu.wpi.first.math.controller;
/**
- * A helper class that computes feedforward outputs for a simple elevator (modeled as a motor
- * acting against the force of gravity).
+ * A helper class that computes feedforward outputs for a simple elevator (modeled as a motor acting
+ * against the force of gravity).
*/
@SuppressWarnings("MemberName")
public class ElevatorFeedforward {
@@ -19,8 +16,8 @@
public final double ka;
/**
- * Creates a new ElevatorFeedforward with the specified gains. Units of the gain values
- * will dictate units of the computed feedforward.
+ * Creates a new ElevatorFeedforward with the specified gains. Units of the gain values will
+ * dictate units of the computed feedforward.
*
* @param ks The static gain.
* @param kg The gravity gain.
@@ -35,8 +32,8 @@
}
/**
- * Creates a new ElevatorFeedforward with the specified gains. Acceleration gain is
- * defaulted to zero. Units of the gain values will dictate units of the computed feedforward.
+ * Creates a new ElevatorFeedforward with the specified gains. Acceleration gain is defaulted to
+ * zero. Units of the gain values will dictate units of the computed feedforward.
*
* @param ks The static gain.
* @param kg The gravity gain.
@@ -49,7 +46,7 @@
/**
* Calculates the feedforward from the gains and setpoints.
*
- * @param velocity The velocity setpoint.
+ * @param velocity The velocity setpoint.
* @param acceleration The acceleration setpoint.
* @return The computed feedforward.
*/
@@ -58,8 +55,8 @@
}
/**
- * Calculates the feedforward from the gains and velocity setpoint (acceleration is assumed to
- * be zero).
+ * Calculates the feedforward from the gains and velocity setpoint (acceleration is assumed to be
+ * zero).
*
* @param velocity The velocity setpoint.
* @return The computed feedforward.
@@ -72,11 +69,10 @@
// formulas for the methods below:
/**
- * Calculates the maximum achievable velocity given a maximum voltage supply
- * and an acceleration. Useful for ensuring that velocity and
- * acceleration constraints for a trapezoidal profile are simultaneously
- * achievable - enter the acceleration constraint, and this will give you
- * a simultaneously-achievable velocity constraint.
+ * Calculates the maximum achievable velocity given a maximum voltage supply and an acceleration.
+ * Useful for ensuring that velocity and acceleration constraints for a trapezoidal profile are
+ * simultaneously achievable - enter the acceleration constraint, and this will give you a
+ * simultaneously-achievable velocity constraint.
*
* @param maxVoltage The maximum voltage that can be supplied to the elevator.
* @param acceleration The acceleration of the elevator.
@@ -88,11 +84,10 @@
}
/**
- * Calculates the minimum achievable velocity given a maximum voltage supply
- * and an acceleration. Useful for ensuring that velocity and
- * acceleration constraints for a trapezoidal profile are simultaneously
- * achievable - enter the acceleration constraint, and this will give you
- * a simultaneously-achievable velocity constraint.
+ * Calculates the minimum achievable velocity given a maximum voltage supply and an acceleration.
+ * Useful for ensuring that velocity and acceleration constraints for a trapezoidal profile are
+ * simultaneously achievable - enter the acceleration constraint, and this will give you a
+ * simultaneously-achievable velocity constraint.
*
* @param maxVoltage The maximum voltage that can be supplied to the elevator.
* @param acceleration The acceleration of the elevator.
@@ -104,11 +99,10 @@
}
/**
- * Calculates the maximum achievable acceleration given a maximum voltage
- * supply and a velocity. Useful for ensuring that velocity and
- * acceleration constraints for a trapezoidal profile are simultaneously
- * achievable - enter the velocity constraint, and this will give you
- * a simultaneously-achievable acceleration constraint.
+ * Calculates the maximum achievable acceleration given a maximum voltage supply and a velocity.
+ * Useful for ensuring that velocity and acceleration constraints for a trapezoidal profile are
+ * simultaneously achievable - enter the velocity constraint, and this will give you a
+ * simultaneously-achievable acceleration constraint.
*
* @param maxVoltage The maximum voltage that can be supplied to the elevator.
* @param velocity The velocity of the elevator.
@@ -119,11 +113,10 @@
}
/**
- * Calculates the minimum achievable acceleration given a maximum voltage
- * supply and a velocity. Useful for ensuring that velocity and
- * acceleration constraints for a trapezoidal profile are simultaneously
- * achievable - enter the velocity constraint, and this will give you
- * a simultaneously-achievable acceleration constraint.
+ * Calculates the minimum achievable acceleration given a maximum voltage supply and a velocity.
+ * Useful for ensuring that velocity and acceleration constraints for a trapezoidal profile are
+ * simultaneously achievable - enter the velocity constraint, and this will give you a
+ * simultaneously-achievable acceleration constraint.
*
* @param maxVoltage The maximum voltage that can be supplied to the elevator.
* @param velocity The velocity of the elevator.
diff --git a/wpimath/src/main/java/edu/wpi/first/math/controller/HolonomicDriveController.java b/wpimath/src/main/java/edu/wpi/first/math/controller/HolonomicDriveController.java
new file mode 100644
index 0000000..be813cc
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/controller/HolonomicDriveController.java
@@ -0,0 +1,139 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.controller;
+
+import edu.wpi.first.math.geometry.Pose2d;
+import edu.wpi.first.math.geometry.Rotation2d;
+import edu.wpi.first.math.kinematics.ChassisSpeeds;
+import edu.wpi.first.math.trajectory.Trajectory;
+
+/**
+ * This holonomic drive controller can be used to follow trajectories using a holonomic drivetrain
+ * (i.e. swerve or mecanum). Holonomic trajectory following is a much simpler problem to solve
+ * compared to skid-steer style drivetrains because it is possible to individually control forward,
+ * sideways, and angular velocity.
+ *
+ * <p>The holonomic drive controller takes in one PID controller for each field-relative
+ * translation axis (x and y), and one profiled PID controller for the angular direction. Because
+ * the heading dynamics are decoupled from translations, users can specify a custom heading that
+ * the drivetrain should point toward. This heading reference is profiled for smoothness.
+ */
+@SuppressWarnings("MemberName")
+public class HolonomicDriveController {
+  private Pose2d m_poseError = new Pose2d(); // Last error: poseRef relative to currentPose.
+  private Rotation2d m_rotationError = new Rotation2d(); // Last error: angleRef - heading.
+  private Pose2d m_poseTolerance = new Pose2d(); // Per-axis tolerance for atReference().
+  private boolean m_enabled = true; // When false, x/y PID feedback is skipped.
+
+  private final PIDController m_xController;
+  private final PIDController m_yController;
+  private final ProfiledPIDController m_thetaController;
+
+  private boolean m_firstRun = true; // Theta controller is reset on the first calculate().
+
+  /**
+   * Constructs a holonomic drive controller.
+   *
+   * @param xController A PID Controller to respond to error in the field-relative x direction.
+   * @param yController A PID Controller to respond to error in the field-relative y direction.
+   * @param thetaController A profiled PID controller to respond to error in angle.
+   */
+  @SuppressWarnings("ParameterName")
+  public HolonomicDriveController(
+      PIDController xController, PIDController yController, ProfiledPIDController thetaController) {
+    m_xController = xController;
+    m_yController = yController;
+    m_thetaController = thetaController;
+  }
+
+  /**
+   * Returns true if the errors from the last calculate() call are within tolerance.
+   *
+   * @return True if the x, y, and rotation errors are all strictly less than the tolerance.
+   */
+  public boolean atReference() {
+    final var eTranslate = m_poseError.getTranslation();
+    final var eRotate = m_rotationError;
+    final var tolTranslate = m_poseTolerance.getTranslation();
+    final var tolRotate = m_poseTolerance.getRotation();
+    return Math.abs(eTranslate.getX()) < tolTranslate.getX()
+        && Math.abs(eTranslate.getY()) < tolTranslate.getY()
+        && Math.abs(eRotate.getRadians()) < tolRotate.getRadians();
+  }
+
+  /**
+   * Sets the per-axis error tolerance (x, y, and rotation) used by atReference().
+   *
+   * @param tolerance The pose error which is tolerable.
+   */
+  public void setTolerance(Pose2d tolerance) {
+    m_poseTolerance = tolerance;
+  }
+
+  /**
+   * Returns the next output of the holonomic drive controller.
+   *
+   * @param currentPose The current pose.
+   * @param poseRef The desired pose.
+   * @param linearVelocityRefMeters The linear velocity reference along poseRef's heading.
+   * @param angleRef The desired heading (the theta controller's goal).
+   * @return The next output, converted from field-relative to robot-relative chassis speeds.
+   */
+  @SuppressWarnings("LocalVariableName")
+  public ChassisSpeeds calculate(
+      Pose2d currentPose, Pose2d poseRef, double linearVelocityRefMeters, Rotation2d angleRef) {
+    // If this is the first run, then we need to reset the theta controller to the current pose's
+    // heading.
+    if (m_firstRun) {
+      m_thetaController.reset(currentPose.getRotation().getRadians());
+      m_firstRun = false;
+    }
+
+    // Calculate feedforward velocities (field-relative); thetaFF is the theta controller output.
+    double xFF = linearVelocityRefMeters * poseRef.getRotation().getCos();
+    double yFF = linearVelocityRefMeters * poseRef.getRotation().getSin();
+    double thetaFF =
+        m_thetaController.calculate(currentPose.getRotation().getRadians(), angleRef.getRadians());
+
+    m_poseError = poseRef.relativeTo(currentPose);
+    m_rotationError = angleRef.minus(currentPose.getRotation());
+
+    if (!m_enabled) {
+      return ChassisSpeeds.fromFieldRelativeSpeeds(xFF, yFF, thetaFF, currentPose.getRotation());
+    }
+
+    // Calculate feedback velocities (based on position error).
+    double xFeedback = m_xController.calculate(currentPose.getX(), poseRef.getX());
+    double yFeedback = m_yController.calculate(currentPose.getY(), poseRef.getY());
+
+    // Return next output.
+    return ChassisSpeeds.fromFieldRelativeSpeeds(
+        xFF + xFeedback, yFF + yFeedback, thetaFF, currentPose.getRotation());
+  }
+
+  /**
+   * Returns the next output of the holonomic drive controller.
+   *
+   * @param currentPose The current pose.
+   * @param desiredState The desired trajectory state.
+   * @param angleRef The desired heading (the theta controller's goal).
+   * @return The next output of the holonomic drive controller.
+   */
+  public ChassisSpeeds calculate(
+      Pose2d currentPose, Trajectory.State desiredState, Rotation2d angleRef) {
+    return calculate(
+        currentPose, desiredState.poseMeters, desiredState.velocityMetersPerSecond, angleRef);
+  }
+
+  /**
+   * Enables and disables the controller for troubleshooting problems. When calculate() is called on
+   * a disabled controller, x/y PID feedback is skipped; the theta controller output is still used.
+   *
+   * @param enabled If the controller is enabled or not.
+   */
+  public void setEnabled(boolean enabled) {
+    m_enabled = enabled;
+  }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/controller/LinearPlantInversionFeedforward.java b/wpimath/src/main/java/edu/wpi/first/math/controller/LinearPlantInversionFeedforward.java
similarity index 64%
rename from wpimath/src/main/java/edu/wpi/first/wpilibj/controller/LinearPlantInversionFeedforward.java
rename to wpimath/src/main/java/edu/wpi/first/math/controller/LinearPlantInversionFeedforward.java
index a4a00ee..627c272 100644
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/controller/LinearPlantInversionFeedforward.java
+++ b/wpimath/src/main/java/edu/wpi/first/math/controller/LinearPlantInversionFeedforward.java
@@ -1,20 +1,16 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
-package edu.wpi.first.wpilibj.controller;
+package edu.wpi.first.math.controller;
+import edu.wpi.first.math.Matrix;
+import edu.wpi.first.math.Num;
+import edu.wpi.first.math.numbers.N1;
+import edu.wpi.first.math.system.Discretization;
+import edu.wpi.first.math.system.LinearSystem;
import org.ejml.simple.SimpleMatrix;
-import edu.wpi.first.wpilibj.math.Discretization;
-import edu.wpi.first.wpilibj.system.LinearSystem;
-import edu.wpi.first.wpiutil.math.Matrix;
-import edu.wpi.first.wpiutil.math.Num;
-import edu.wpi.first.wpiutil.math.numbers.N1;
-
/**
* Constructs a plant inversion model-based feedforward from a {@link LinearSystem}.
*
@@ -25,17 +21,13 @@
* https://file.tavsys.net/control/controls-engineering-in-frc.pdf.
*/
@SuppressWarnings({"ParameterName", "LocalVariableName", "MemberName", "ClassTypeParameterName"})
-public class LinearPlantInversionFeedforward<States extends Num, Inputs extends Num,
- Outputs extends Num> {
- /**
- * The current reference state.
- */
+public class LinearPlantInversionFeedforward<
+ States extends Num, Inputs extends Num, Outputs extends Num> {
+ /** The current reference state. */
@SuppressWarnings("MemberName")
private Matrix<States, N1> m_r;
- /**
- * The computed feedforward.
- */
+ /** The computed feedforward. */
private Matrix<Inputs, N1> m_uff;
@SuppressWarnings("MemberName")
@@ -47,26 +39,24 @@
/**
* Constructs a feedforward with the given plant.
*
- * @param plant The plant being controlled.
+ * @param plant The plant being controlled.
* @param dtSeconds Discretization timestep.
*/
public LinearPlantInversionFeedforward(
- LinearSystem<States, Inputs, Outputs> plant,
- double dtSeconds
- ) {
+ LinearSystem<States, Inputs, Outputs> plant, double dtSeconds) {
this(plant.getA(), plant.getB(), dtSeconds);
}
/**
* Constructs a feedforward with the given coefficients.
*
- * @param A Continuous system matrix of the plant being controlled.
- * @param B Continuous input matrix of the plant being controlled.
+ * @param A Continuous system matrix of the plant being controlled.
+ * @param B Continuous input matrix of the plant being controlled.
* @param dtSeconds Discretization timestep.
*/
@SuppressWarnings({"ParameterName", "LocalVariableName"})
- public LinearPlantInversionFeedforward(Matrix<States, States> A, Matrix<States, Inputs> B,
- double dtSeconds) {
+ public LinearPlantInversionFeedforward(
+ Matrix<States, States> A, Matrix<States, Inputs> B, double dtSeconds) {
var discABPair = Discretization.discretizeAB(A, B, dtSeconds);
this.m_A = discABPair.getFirst();
this.m_B = discABPair.getSecond();
@@ -74,7 +64,7 @@
m_r = new Matrix<>(new SimpleMatrix(B.getNumRows(), 1));
m_uff = new Matrix<>(new SimpleMatrix(B.getNumCols(), 1));
- reset(m_r);
+ reset();
}
/**
@@ -90,7 +80,6 @@
* Returns an element of the previously calculated feedforward.
*
* @param row Row of uff.
- *
* @return The row of the calculated feedforward.
*/
public double getUff(int row) {
@@ -110,7 +99,6 @@
* Returns an element of the current reference vector r.
*
* @param row Row of r.
- *
* @return The row of the current reference vector.
*/
public double getR(int row) {
@@ -127,25 +115,21 @@
m_uff.fill(0.0);
}
- /**
- * Resets the feedforward with a zero initial state vector.
- */
+ /** Resets the feedforward with a zero initial state vector. */
public void reset() {
m_r.fill(0.0);
m_uff.fill(0.0);
}
/**
- * Calculate the feedforward with only the desired
- * future reference. This uses the internally stored "current"
- * reference.
+ * Calculate the feedforward with only the desired future reference. This uses the internally
+ * stored "current" reference.
*
- * <p>If this method is used the initial state of the system is the one
- * set using {@link LinearPlantInversionFeedforward#reset(Matrix)}.
- * If the initial state is not set it defaults to a zero vector.
+ * <p>If this method is used the initial state of the system is the one set using {@link
+ * LinearPlantInversionFeedforward#reset(Matrix)}. If the initial state is not set it defaults to
+ * a zero vector.
*
* @param nextR The reference state of the future timestep (k + dt).
- *
* @return The calculated feedforward.
*/
public Matrix<Inputs, N1> calculate(Matrix<States, N1> nextR) {
@@ -155,9 +139,8 @@
/**
* Calculate the feedforward with current and future reference vectors.
*
- * @param r The reference state of the current timestep (k).
+ * @param r The reference state of the current timestep (k).
* @param nextR The reference state of the future timestep (k + dt).
- *
* @return The calculated feedforward.
*/
@SuppressWarnings({"ParameterName", "LocalVariableName"})
diff --git a/wpimath/src/main/java/edu/wpi/first/math/controller/LinearQuadraticRegulator.java b/wpimath/src/main/java/edu/wpi/first/math/controller/LinearQuadraticRegulator.java
new file mode 100644
index 0000000..cf4b26d
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/controller/LinearQuadraticRegulator.java
@@ -0,0 +1,281 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.controller;
+
+import edu.wpi.first.math.Drake;
+import edu.wpi.first.math.MathSharedStore;
+import edu.wpi.first.math.Matrix;
+import edu.wpi.first.math.Nat;
+import edu.wpi.first.math.Num;
+import edu.wpi.first.math.StateSpaceUtil;
+import edu.wpi.first.math.Vector;
+import edu.wpi.first.math.numbers.N1;
+import edu.wpi.first.math.system.Discretization;
+import edu.wpi.first.math.system.LinearSystem;
+import org.ejml.simple.SimpleMatrix;
+
+/**
+ * Contains the controller coefficients and logic for a linear-quadratic regulator (LQR). LQRs use
+ * the control law u = K(r - x).
+ *
+ * <p>For more on the underlying math, read
+ * https://file.tavsys.net/control/controls-engineering-in-frc.pdf.
+ */
+@SuppressWarnings("ClassTypeParameterName")
+public class LinearQuadraticRegulator<States extends Num, Inputs extends Num, Outputs extends Num> {
+  /** The current reference state. */
+  @SuppressWarnings("MemberName")
+  private Matrix<States, N1> m_r;
+
+  /** The most recently computed controller output u = K(r - x). It is not clamped. */
+  @SuppressWarnings("MemberName")
+  private Matrix<Inputs, N1> m_u;
+
+  // Controller gain.
+  @SuppressWarnings("MemberName")
+  private Matrix<Inputs, States> m_K;
+
+  /**
+   * Constructs a controller with the given coefficients and plant.
+   *
+   * @param plant The plant being controlled.
+   * @param qelms The maximum desired error tolerance for each state.
+   * @param relms The maximum desired control effort for each input.
+   * @param dtSeconds Discretization timestep.
+   */
+  public LinearQuadraticRegulator(
+      LinearSystem<States, Inputs, Outputs> plant,
+      Vector<States> qelms,
+      Vector<Inputs> relms,
+      double dtSeconds) {
+    this(plant.getA(), plant.getB(), qelms, relms, dtSeconds);
+  }
+
+  /**
+   * Constructs a controller with the given coefficients and plant.
+   *
+   * @param A Continuous system matrix of the plant being controlled.
+   * @param B Continuous input matrix of the plant being controlled.
+   * @param qelms The maximum desired error tolerance for each state.
+   * @param relms The maximum desired control effort for each input.
+   * @param dtSeconds Discretization timestep.
+   */
+  @SuppressWarnings({"ParameterName", "LocalVariableName"})
+  public LinearQuadraticRegulator(
+      Matrix<States, States> A,
+      Matrix<States, Inputs> B,
+      Vector<States> qelms,
+      Vector<Inputs> relms,
+      double dtSeconds) {
+    this(
+        A,
+        B,
+        StateSpaceUtil.makeCostMatrix(qelms),
+        StateSpaceUtil.makeCostMatrix(relms),
+        dtSeconds);
+  }
+
+  /**
+   * Constructs a controller with the given coefficients and plant.
+   *
+   * @param A Continuous system matrix of the plant being controlled.
+   * @param B Continuous input matrix of the plant being controlled.
+   * @param Q The state cost matrix.
+   * @param R The input cost matrix.
+   * @param dtSeconds Discretization timestep.
+   * @throws IllegalArgumentException if the discretized (A, B) pair is not stabilizable.
+   */
+  @SuppressWarnings({"LocalVariableName", "ParameterName"})
+  public LinearQuadraticRegulator(
+      Matrix<States, States> A,
+      Matrix<States, Inputs> B,
+      Matrix<States, States> Q,
+      Matrix<Inputs, Inputs> R,
+      double dtSeconds) {
+    var discABPair = Discretization.discretizeAB(A, B, dtSeconds);
+    var discA = discABPair.getFirst();
+    var discB = discABPair.getSecond();
+    checkStabilizable(discA, discB);
+
+    var S = Drake.discreteAlgebraicRiccatiEquation(discA, discB, Q, R);
+
+    // K = (BᵀSB + R)⁻¹BᵀSA
+    var temp = discB.transpose().times(S).times(discB).plus(R);
+    m_K = temp.solve(discB.transpose().times(S).times(discA));
+
+    m_r = new Matrix<>(new SimpleMatrix(B.getNumRows(), 1));
+    m_u = new Matrix<>(new SimpleMatrix(B.getNumCols(), 1));
+
+    reset();
+  }
+
+  /**
+   * Constructs a controller with the given coefficients and plant.
+   *
+   * @param A Continuous system matrix of the plant being controlled.
+   * @param B Continuous input matrix of the plant being controlled.
+   * @param Q The state cost matrix.
+   * @param R The input cost matrix.
+   * @param N The state-input cross-term cost matrix.
+   * @param dtSeconds Discretization timestep.
+   * @throws IllegalArgumentException if the discretized (A, B) pair is not stabilizable.
+   */
+  @SuppressWarnings({"ParameterName", "LocalVariableName"})
+  public LinearQuadraticRegulator(
+      Matrix<States, States> A,
+      Matrix<States, Inputs> B,
+      Matrix<States, States> Q,
+      Matrix<Inputs, Inputs> R,
+      Matrix<States, Inputs> N,
+      double dtSeconds) {
+    var discABPair = Discretization.discretizeAB(A, B, dtSeconds);
+    var discA = discABPair.getFirst();
+    var discB = discABPair.getSecond();
+    checkStabilizable(discA, discB);
+
+    var S = Drake.discreteAlgebraicRiccatiEquation(discA, discB, Q, R, N);
+
+    // K = (BᵀSB + R)⁻¹(BᵀSA + Nᵀ)
+    var temp = discB.transpose().times(S).times(discB).plus(R);
+    m_K = temp.solve(discB.transpose().times(S).times(discA).plus(N.transpose()));
+
+    m_r = new Matrix<>(new SimpleMatrix(B.getNumRows(), 1));
+    m_u = new Matrix<>(new SimpleMatrix(B.getNumCols(), 1));
+
+    reset();
+  }
+
+  /** Reports and throws an IllegalArgumentException if (discA, discB) is not stabilizable. */
+  private void checkStabilizable(Matrix<States, States> discA, Matrix<States, Inputs> discB) {
+    if (!StateSpaceUtil.isStabilizable(discA, discB)) {
+      var builder = new StringBuilder("The system passed to the LQR is uncontrollable!\n\nA =\n");
+      builder.append(discA.getStorage().toString());
+      builder.append("\nB =\n");
+      builder.append(discB.getStorage().toString());
+      builder.append("\n");
+      var msg = builder.toString();
+      MathSharedStore.reportError(msg, Thread.currentThread().getStackTrace());
+      throw new IllegalArgumentException(msg);
+    }
+  }
+
+  /**
+   * Constructs a controller with the given coefficients and plant.
+   *
+   * @param states The number of states.
+   * @param inputs The number of inputs.
+   * @param k The gain matrix.
+   */
+  @SuppressWarnings("ParameterName")
+  public LinearQuadraticRegulator(
+      Nat<States> states, Nat<Inputs> inputs, Matrix<Inputs, States> k) {
+    m_K = k;
+    m_r = new Matrix<>(states, Nat.N1());
+    m_u = new Matrix<>(inputs, Nat.N1());
+
+    reset();
+  }
+
+  /**
+   * Returns the control input vector u.
+   *
+   * @return The control input.
+   */
+  public Matrix<Inputs, N1> getU() {
+    return m_u;
+  }
+
+  /**
+   * Returns an element of the control input vector u.
+   *
+   * @param row Row of u.
+   * @return The row of the control input vector.
+   */
+  public double getU(int row) {
+    return m_u.get(row, 0);
+  }
+
+  /**
+   * Returns the reference vector r.
+   *
+   * @return The reference vector.
+   */
+  public Matrix<States, N1> getR() {
+    return m_r;
+  }
+
+  /**
+   * Returns an element of the reference vector r.
+   *
+   * @param row Row of r.
+   * @return The row of the reference vector.
+   */
+  public double getR(int row) {
+    return m_r.get(row, 0);
+  }
+
+  /**
+   * Returns the controller matrix K.
+   *
+   * @return the controller matrix K.
+   */
+  public Matrix<Inputs, States> getK() {
+    return m_K;
+  }
+
+  /** Resets the controller by zero-filling the reference r and the output u in place. */
+  public void reset() {
+    m_r.fill(0.0);
+    m_u.fill(0.0);
+  }
+
+  /**
+   * Returns the next output of the controller.
+   *
+   * @param x The current state x.
+   * @return The next controller output.
+   */
+  @SuppressWarnings("ParameterName")
+  public Matrix<Inputs, N1> calculate(Matrix<States, N1> x) {
+    m_u = m_K.times(m_r.minus(x));
+    return m_u;
+  }
+
+  /**
+   * Returns the next output of the controller.
+   *
+   * @param x The current state x.
+   * @param nextR The next reference vector r. It is copied; the caller's matrix is not retained.
+   * @return The next controller output.
+   */
+  @SuppressWarnings("ParameterName")
+  public Matrix<Inputs, N1> calculate(Matrix<States, N1> x, Matrix<States, N1> nextR) {
+    // Copy so a later reset() (which zero-fills m_r) cannot clobber the caller's matrix.
+    m_r = new Matrix<>(nextR.getStorage().copy());
+    return calculate(x);
+  }
+
+  /**
+   * Adjusts LQR controller gain to compensate for a pure time delay in the input.
+   *
+   * <p>Linear-Quadratic regulator controller gains tend to be aggressive. If sensor measurements
+   * are time-delayed too long, the LQR may be unstable. However, if we know the amount of delay, we
+   * can compute the control based on where the system will be after the time delay.
+   *
+   * <p>See https://file.tavsys.net/control/controls-engineering-in-frc.pdf appendix C.4 for a
+   * derivation.
+   *
+   * @param plant The plant being controlled.
+   * @param dtSeconds Discretization timestep in seconds.
+   * @param inputDelaySeconds Input time delay in seconds.
+   */
+  public void latencyCompensate(
+      LinearSystem<States, Inputs, Outputs> plant, double dtSeconds, double inputDelaySeconds) {
+    var discABPair = Discretization.discretizeAB(plant.getA(), plant.getB(), dtSeconds);
+    var discA = discABPair.getFirst();
+    var discB = discABPair.getSecond();
+    m_K = m_K.times((discA.minus(discB.times(m_K))).pow(inputDelaySeconds / dtSeconds));
+  }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/math/controller/PIDController.java b/wpimath/src/main/java/edu/wpi/first/math/controller/PIDController.java
new file mode 100644
index 0000000..12c4175
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/controller/PIDController.java
@@ -0,0 +1,355 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.controller;
+
+import edu.wpi.first.math.MathSharedStore;
+import edu.wpi.first.math.MathUsageId;
+import edu.wpi.first.math.MathUtil;
+import edu.wpi.first.util.sendable.Sendable;
+import edu.wpi.first.util.sendable.SendableBuilder;
+import edu.wpi.first.util.sendable.SendableRegistry;
+
+/** Implements a PID control loop. */
+public class PIDController implements Sendable, AutoCloseable {
+  private static int instances;
+
+  // Proportional, integral, and derivative gains
+  private double m_kp;
+  private double m_ki;
+  private double m_kd;
+
+  // The period (in seconds) of the loop that calls the controller
+  private final double m_period;
+
+  // Bounds on the integral term's contribution to the output (Ki times the accumulated error)
+  private double m_maximumIntegral = 1.0;
+  private double m_minimumIntegral = -1.0;
+
+  // Input range, and whether its endpoints wrap around (e.g. an absolute encoder)
+  private double m_maximumInput;
+  private double m_minimumInput;
+  private boolean m_continuous;
+
+  // The error at the time of the most recent call to calculate()
+  private double m_positionError;
+  private double m_velocityError;
+
+  // The error at the time of the second-most-recent call to calculate() (used to compute velocity)
+  private double m_prevError;
+
+  // The time-integrated error (sum of error * period) used for the integral term
+  private double m_totalError;
+
+  // The error that is considered at setpoint.
+  private double m_positionTolerance = 0.05;
+  private double m_velocityTolerance = Double.POSITIVE_INFINITY;
+
+  private double m_setpoint;
+  private double m_measurement;
+  private boolean m_haveMeasurement; // Set by calculate(), cleared by reset()
+
+  /**
+   * Allocates a PIDController with the given constants for kp, ki, and kd and a default period of
+   * 0.02 seconds.
+   *
+   * @param kp The proportional coefficient.
+   * @param ki The integral coefficient.
+   * @param kd The derivative coefficient.
+   */
+  public PIDController(double kp, double ki, double kd) {
+    this(kp, ki, kd, 0.02);
+  }
+
+  /**
+   * Allocates a PIDController with the given constants for kp, ki, and kd.
+   *
+   * @param kp The proportional coefficient.
+   * @param ki The integral coefficient.
+   * @param kd The derivative coefficient.
+   * @param period The period between controller updates in seconds. Must be non-zero and positive.
+   */
+  public PIDController(double kp, double ki, double kd, double period) {
+    m_kp = kp;
+    m_ki = ki;
+    m_kd = kd;
+
+    if (!(period > 0)) { // Also rejects NaN, which "period <= 0" would let through
+      throw new IllegalArgumentException("Controller period must be a non-zero positive number!");
+    }
+    m_period = period;
+
+    instances++;
+    SendableRegistry.addLW(this, "PIDController", instances);
+
+    MathSharedStore.reportUsage(MathUsageId.kController_PIDController2, instances);
+  }
+
+  @Override
+  public void close() {
+    SendableRegistry.remove(this);
+  }
+
+  /**
+   * Sets the PID Controller gain parameters.
+   *
+   * <p>Set the proportional, integral, and differential coefficients.
+   *
+   * @param kp The proportional coefficient.
+   * @param ki The integral coefficient.
+   * @param kd The derivative coefficient.
+   */
+  public void setPID(double kp, double ki, double kd) {
+    m_kp = kp;
+    m_ki = ki;
+    m_kd = kd;
+  }
+
+  /**
+   * Sets the Proportional coefficient of the PID controller gain.
+   *
+   * @param kp proportional coefficient
+   */
+  public void setP(double kp) {
+    m_kp = kp;
+  }
+
+  /**
+   * Sets the Integral coefficient of the PID controller gain.
+   *
+   * @param ki integral coefficient
+   */
+  public void setI(double ki) {
+    m_ki = ki;
+  }
+
+  /**
+   * Sets the Differential coefficient of the PID controller gain.
+   *
+   * @param kd differential coefficient
+   */
+  public void setD(double kd) {
+    m_kd = kd;
+  }
+
+  /**
+   * Get the Proportional coefficient.
+   *
+   * @return proportional coefficient
+   */
+  public double getP() {
+    return m_kp;
+  }
+
+  /**
+   * Get the Integral coefficient.
+   *
+   * @return integral coefficient
+   */
+  public double getI() {
+    return m_ki;
+  }
+
+  /**
+   * Get the Differential coefficient.
+   *
+   * @return differential coefficient
+   */
+  public double getD() {
+    return m_kd;
+  }
+
+  /**
+   * Returns the period of this controller.
+   *
+   * @return the period of the controller.
+   */
+  public double getPeriod() {
+    return m_period;
+  }
+
+  /**
+   * Sets the setpoint for the PIDController.
+   *
+   * @param setpoint The desired setpoint.
+   */
+  public void setSetpoint(double setpoint) {
+    m_setpoint = setpoint;
+  }
+
+  /**
+   * Returns the current setpoint of the PIDController.
+   *
+   * @return The current setpoint.
+   */
+  public double getSetpoint() {
+    return m_setpoint;
+  }
+
+  /**
+   * Returns true if the error is within the tolerance of the setpoint.
+   *
+   * <p>Returns false until calculate() has been called since construction or the last reset().
+   *
+   * @return Whether the error is within the acceptable bounds.
+   */
+  public boolean atSetpoint() {
+    double positionError;
+    if (m_continuous) {
+      double errorBound = (m_maximumInput - m_minimumInput) / 2.0;
+      positionError = MathUtil.inputModulus(m_setpoint - m_measurement, -errorBound, errorBound);
+    } else {
+      positionError = m_setpoint - m_measurement;
+    }
+
+    double velocityError = (positionError - m_prevError) / m_period;
+
+    return m_haveMeasurement
+        && Math.abs(positionError) < m_positionTolerance
+        && Math.abs(velocityError) < m_velocityTolerance;
+  }
+
+  /**
+   * Enables continuous input.
+   *
+   * <p>Rather than using the max and min input range as constraints, it considers them to be the
+   * same point and automatically calculates the shortest route to the setpoint.
+   *
+   * @param minimumInput The minimum value expected from the input.
+   * @param maximumInput The maximum value expected from the input.
+   */
+  public void enableContinuousInput(double minimumInput, double maximumInput) {
+    m_continuous = true;
+    m_minimumInput = minimumInput;
+    m_maximumInput = maximumInput;
+  }
+
+  /** Disables continuous input. */
+  public void disableContinuousInput() {
+    m_continuous = false;
+  }
+
+  /**
+   * Returns true if continuous input is enabled.
+   *
+   * @return True if continuous input is enabled.
+   */
+  public boolean isContinuousInputEnabled() {
+    return m_continuous;
+  }
+
+  /**
+   * Sets the minimum and maximum values for the integrator.
+   *
+   * <p>The integral term's contribution to the controller output (Ki times the accumulated error)
+   * is clamped to this range.
+   *
+   * @param minimumIntegral The minimum contribution of the integral term.
+   * @param maximumIntegral The maximum contribution of the integral term.
+   */
+  public void setIntegratorRange(double minimumIntegral, double maximumIntegral) {
+    m_minimumIntegral = minimumIntegral;
+    m_maximumIntegral = maximumIntegral;
+  }
+
+  /**
+   * Sets the error which is considered tolerable for use with atSetpoint().
+   *
+   * @param positionTolerance Position error which is tolerable.
+   */
+  public void setTolerance(double positionTolerance) {
+    setTolerance(positionTolerance, Double.POSITIVE_INFINITY);
+  }
+
+  /**
+   * Sets the error which is considered tolerable for use with atSetpoint().
+   *
+   * @param positionTolerance Position error which is tolerable.
+   * @param velocityTolerance Velocity error which is tolerable.
+   */
+  public void setTolerance(double positionTolerance, double velocityTolerance) {
+    m_positionTolerance = positionTolerance;
+    m_velocityTolerance = velocityTolerance;
+  }
+
+  /**
+   * Returns the difference between the setpoint and the measurement.
+   *
+   * @return The error.
+   */
+  public double getPositionError() {
+    return m_positionError;
+  }
+
+  /**
+   * Returns the velocity error.
+   *
+   * @return The velocity error.
+   */
+  public double getVelocityError() {
+    return m_velocityError;
+  }
+
+  /**
+   * Returns the next output of the PID controller.
+   *
+   * @param measurement The current measurement of the process variable.
+   * @param setpoint The new setpoint of the controller.
+   * @return The next controller output.
+   */
+  public double calculate(double measurement, double setpoint) {
+    setSetpoint(setpoint);
+    return calculate(measurement);
+  }
+
+  /**
+   * Returns the next output of the PID controller.
+   *
+   * @param measurement The current measurement of the process variable.
+   * @return The next controller output.
+   */
+  public double calculate(double measurement) {
+    m_measurement = measurement;
+    m_prevError = m_positionError;
+    m_haveMeasurement = true;
+
+    if (m_continuous) {
+      double errorBound = (m_maximumInput - m_minimumInput) / 2.0;
+      m_positionError = MathUtil.inputModulus(m_setpoint - m_measurement, -errorBound, errorBound);
+    } else {
+      m_positionError = m_setpoint - measurement;
+    }
+
+    m_velocityError = (m_positionError - m_prevError) / m_period;
+
+    if (m_ki != 0) {
+      // Limit Ki * totalError to the integrator range; min/max keep the bounds ordered if Ki < 0.
+      m_totalError =
+          MathUtil.clamp(
+              m_totalError + m_positionError * m_period,
+              Math.min(m_minimumIntegral / m_ki, m_maximumIntegral / m_ki),
+              Math.max(m_minimumIntegral / m_ki, m_maximumIntegral / m_ki));
+    }
+
+    return m_kp * m_positionError + m_ki * m_totalError + m_kd * m_velocityError;
+  }
+
+  /** Resets all error terms; atSetpoint() then returns false until calculate() is called. */
+  public void reset() {
+    m_positionError = 0;
+    m_prevError = 0;
+    m_totalError = 0;
+    m_velocityError = 0;
+    m_haveMeasurement = false;
+  }
+
+  @Override
+  public void initSendable(SendableBuilder builder) {
+    builder.setSmartDashboardType("PIDController");
+    builder.addDoubleProperty("p", this::getP, this::setP);
+    builder.addDoubleProperty("i", this::getI, this::setI);
+    builder.addDoubleProperty("d", this::getD, this::setD);
+    builder.addDoubleProperty("setpoint", this::getSetpoint, this::setSetpoint);
+  }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/math/controller/ProfiledPIDController.java b/wpimath/src/main/java/edu/wpi/first/math/controller/ProfiledPIDController.java
new file mode 100644
index 0000000..3ebcbd8
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/controller/ProfiledPIDController.java
@@ -0,0 +1,381 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.controller;
+
+import edu.wpi.first.math.MathSharedStore;
+import edu.wpi.first.math.MathUsageId;
+import edu.wpi.first.math.MathUtil;
+import edu.wpi.first.math.trajectory.TrapezoidProfile;
+import edu.wpi.first.util.sendable.Sendable;
+import edu.wpi.first.util.sendable.SendableBuilder;
+
+/**
+ * Implements a PID control loop whose setpoint is constrained by a trapezoid profile. Users should
+ * call reset() when they first start running the controller to avoid unwanted behavior.
+ */
+public class ProfiledPIDController implements Sendable {
+  private static int instances;
+
+  private final PIDController m_controller;
+  private double m_minimumInput;
+  private double m_maximumInput;
+  private TrapezoidProfile.State m_goal = new TrapezoidProfile.State();
+  private TrapezoidProfile.State m_setpoint = new TrapezoidProfile.State();
+  private TrapezoidProfile.Constraints m_constraints;
+
+  /**
+   * Allocates a ProfiledPIDController with the given constants for Kp, Ki, and Kd.
+   *
+   * @param Kp The proportional coefficient.
+   * @param Ki The integral coefficient.
+   * @param Kd The derivative coefficient.
+   * @param constraints Velocity and acceleration constraints for goal.
+   */
+  @SuppressWarnings("ParameterName")
+  public ProfiledPIDController(
+      double Kp, double Ki, double Kd, TrapezoidProfile.Constraints constraints) {
+    this(Kp, Ki, Kd, constraints, 0.02);
+  }
+
+  /**
+   * Allocates a ProfiledPIDController with the given constants for Kp, Ki, and Kd.
+   *
+   * @param Kp The proportional coefficient.
+   * @param Ki The integral coefficient.
+   * @param Kd The derivative coefficient.
+   * @param constraints Velocity and acceleration constraints for goal.
+   * @param period The period between controller updates in seconds. The default is 0.02 seconds.
+   */
+  @SuppressWarnings("ParameterName")
+  public ProfiledPIDController(
+      double Kp, double Ki, double Kd, TrapezoidProfile.Constraints constraints, double period) {
+    m_controller = new PIDController(Kp, Ki, Kd, period);
+    m_constraints = constraints;
+    instances++;
+    MathSharedStore.reportUsage(MathUsageId.kController_ProfiledPIDController, instances);
+  }
+
+  /**
+   * Sets the PID Controller gain parameters.
+   *
+   * <p>Sets the proportional, integral, and differential coefficients.
+   *
+   * @param Kp Proportional coefficient
+   * @param Ki Integral coefficient
+   * @param Kd Differential coefficient
+   */
+  @SuppressWarnings("ParameterName")
+  public void setPID(double Kp, double Ki, double Kd) {
+    m_controller.setPID(Kp, Ki, Kd);
+  }
+
+  /**
+   * Sets the proportional coefficient of the PID controller gain.
+   *
+   * @param Kp proportional coefficient
+   */
+  @SuppressWarnings("ParameterName")
+  public void setP(double Kp) {
+    m_controller.setP(Kp);
+  }
+
+  /**
+   * Sets the integral coefficient of the PID controller gain.
+   *
+   * @param Ki integral coefficient
+   */
+  @SuppressWarnings("ParameterName")
+  public void setI(double Ki) {
+    m_controller.setI(Ki);
+  }
+
+  /**
+   * Sets the differential coefficient of the PID controller gain.
+   *
+   * @param Kd differential coefficient
+   */
+  @SuppressWarnings("ParameterName")
+  public void setD(double Kd) {
+    m_controller.setD(Kd);
+  }
+
+  /**
+   * Gets the proportional coefficient.
+   *
+   * @return proportional coefficient
+   */
+  public double getP() {
+    return m_controller.getP();
+  }
+
+  /**
+   * Gets the integral coefficient.
+   *
+   * @return integral coefficient
+   */
+  public double getI() {
+    return m_controller.getI();
+  }
+
+  /**
+   * Gets the differential coefficient.
+   *
+   * @return differential coefficient
+   */
+  public double getD() {
+    return m_controller.getD();
+  }
+
+  /**
+   * Gets the period of this controller.
+   *
+   * @return The period of the controller.
+   */
+  public double getPeriod() {
+    return m_controller.getPeriod();
+  }
+
+  /**
+   * Sets the goal for the ProfiledPIDController.
+   *
+   * @param goal The desired goal state. It is copied, so later changes to it have no effect.
+   */
+  public void setGoal(TrapezoidProfile.State goal) {
+    m_goal = new TrapezoidProfile.State(goal.position, goal.velocity);
+  }
+
+  /**
+   * Sets the goal for the ProfiledPIDController.
+   *
+   * @param goal The desired goal position.
+   */
+  public void setGoal(double goal) {
+    m_goal = new TrapezoidProfile.State(goal, 0);
+  }
+
+  /**
+   * Gets the goal for the ProfiledPIDController.
+   *
+   * @return The goal.
+   */
+  public TrapezoidProfile.State getGoal() {
+    return m_goal;
+  }
+
+  /**
+   * Returns true if the profile has reached the goal and the error is within tolerance.
+   *
+   * <p>This will return false until at least one input value has been computed.
+   *
+   * @return True if the goal has been reached and the error is within tolerance.
+   */
+  public boolean atGoal() {
+    return atSetpoint() && m_goal.equals(m_setpoint);
+  }
+
+  /**
+   * Set velocity and acceleration constraints for goal.
+   *
+   * @param constraints Velocity and acceleration constraints for goal.
+   */
+  public void setConstraints(TrapezoidProfile.Constraints constraints) {
+    m_constraints = constraints;
+  }
+
+  /**
+   * Returns the current setpoint of the ProfiledPIDController.
+   *
+   * @return The current setpoint.
+   */
+  public TrapezoidProfile.State getSetpoint() {
+    return m_setpoint;
+  }
+
+  /**
+   * Returns true if the error is within the tolerance of the current profile setpoint.
+   *
+   * <p>This will return false until at least one input value has been computed.
+   *
+   * @return True if the error is within the tolerance of the setpoint.
+   */
+  public boolean atSetpoint() {
+    return m_controller.atSetpoint();
+  }
+
+  /**
+   * Enables continuous input.
+   *
+   * <p>Rather than using the max and min input range as constraints, it considers them to be the
+   * same point and automatically calculates the shortest route to the setpoint.
+   *
+   * @param minimumInput The minimum value expected from the input.
+   * @param maximumInput The maximum value expected from the input.
+   */
+  public void enableContinuousInput(double minimumInput, double maximumInput) {
+    m_controller.enableContinuousInput(minimumInput, maximumInput);
+    m_minimumInput = minimumInput;
+    m_maximumInput = maximumInput;
+  }
+
+  /** Disables continuous input. */
+  public void disableContinuousInput() {
+    m_controller.disableContinuousInput();
+  }
+
+  /**
+   * Sets the minimum and maximum values for the integrator.
+   *
+   * <p>The integral term's contribution to the controller output (Ki times the accumulated error)
+   * is clamped to this range.
+   *
+   * @param minimumIntegral The minimum contribution of the integral term.
+   * @param maximumIntegral The maximum contribution of the integral term.
+   */
+  public void setIntegratorRange(double minimumIntegral, double maximumIntegral) {
+    m_controller.setIntegratorRange(minimumIntegral, maximumIntegral);
+  }
+
+  /**
+   * Sets the error which is considered tolerable for use with atSetpoint().
+   *
+   * @param positionTolerance Position error which is tolerable.
+   */
+  public void setTolerance(double positionTolerance) {
+    setTolerance(positionTolerance, Double.POSITIVE_INFINITY);
+  }
+
+  /**
+   * Sets the error which is considered tolerable for use with atSetpoint().
+   *
+   * @param positionTolerance Position error which is tolerable.
+   * @param velocityTolerance Velocity error which is tolerable.
+   */
+  public void setTolerance(double positionTolerance, double velocityTolerance) {
+    m_controller.setTolerance(positionTolerance, velocityTolerance);
+  }
+
+  /**
+   * Returns the difference between the setpoint and the measurement.
+   *
+   * @return The error.
+   */
+  public double getPositionError() {
+    return m_controller.getPositionError();
+  }
+
+  /**
+   * Returns the change in error per second.
+   *
+   * @return The change in error per second.
+   */
+  public double getVelocityError() {
+    return m_controller.getVelocityError();
+  }
+
+  /**
+   * Returns the next output of the PID controller.
+   *
+   * @param measurement The current measurement of the process variable.
+   * @return The controller's next output.
+   */
+  public double calculate(double measurement) {
+    if (m_controller.isContinuousInputEnabled()) {
+      // Get error which is smallest distance between goal and measurement
+      double errorBound = (m_maximumInput - m_minimumInput) / 2.0;
+      double goalMinDistance =
+          MathUtil.inputModulus(m_goal.position - measurement, -errorBound, errorBound);
+      double setpointMinDistance =
+          MathUtil.inputModulus(m_setpoint.position - measurement, -errorBound, errorBound);
+
+      // Recompute the profile goal with the smallest error, thus giving the shortest path. The goal
+      // may be outside the input range after this operation, but that's OK because the controller
+      // will still go there and report an error of zero. In other words, the setpoint only needs to
+      // be offset from the measurement by the input range modulus; they don't need to be equal.
+      m_goal.position = goalMinDistance + measurement;
+      m_setpoint.position = setpointMinDistance + measurement;
+    }
+
+    var profile = new TrapezoidProfile(m_constraints, m_goal, m_setpoint);
+    m_setpoint = profile.calculate(getPeriod());
+    return m_controller.calculate(measurement, m_setpoint.position);
+  }
+
+  /**
+   * Returns the next output of the PID controller.
+   *
+   * @param measurement The current measurement of the process variable.
+   * @param goal The new goal of the controller.
+   * @return The controller's next output.
+   */
+  public double calculate(double measurement, TrapezoidProfile.State goal) {
+    setGoal(goal);
+    return calculate(measurement);
+  }
+
+  /**
+   * Returns the next output of the PIDController.
+   *
+   * @param measurement The current measurement of the process variable.
+   * @param goal The new goal of the controller.
+   * @return The controller's next output.
+   */
+  public double calculate(double measurement, double goal) {
+    setGoal(goal);
+    return calculate(measurement);
+  }
+
+  /**
+   * Returns the next output of the PID controller.
+   *
+   * @param measurement The current measurement of the process variable.
+   * @param goal The new goal of the controller.
+   * @param constraints Velocity and acceleration constraints for goal.
+   * @return The controller's next output.
+   */
+  public double calculate(
+      double measurement, TrapezoidProfile.State goal, TrapezoidProfile.Constraints constraints) {
+    setConstraints(constraints);
+    return calculate(measurement, goal);
+  }
+
+  /**
+   * Resets the previous error and integral term, and restarts the profile from the given state.
+   *
+   * @param measurement The current measured State of the system. It is copied, not retained.
+   */
+  public void reset(TrapezoidProfile.State measurement) {
+    m_controller.reset();
+    m_setpoint = new TrapezoidProfile.State(measurement.position, measurement.velocity);
+  }
+
+  /**
+   * Resets the previous error and integral term, and restarts the profile from the given state.
+   *
+   * @param measuredPosition The current measured position of the system.
+   * @param measuredVelocity The current measured velocity of the system.
+   */
+  public void reset(double measuredPosition, double measuredVelocity) {
+    reset(new TrapezoidProfile.State(measuredPosition, measuredVelocity));
+  }
+
+  /**
+   * Resets the previous error and integral term, and restarts the profile from the given position.
+   *
+   * @param measuredPosition The current measured position of the system. The velocity is assumed to
+   *     be zero.
+   */
+  public void reset(double measuredPosition) {
+    reset(measuredPosition, 0.0);
+  }
+
+  @Override
+  public void initSendable(SendableBuilder builder) {
+    builder.setSmartDashboardType("ProfiledPIDController");
+    builder.addDoubleProperty("p", this::getP, this::setP);
+    builder.addDoubleProperty("i", this::getI, this::setI);
+    builder.addDoubleProperty("d", this::getD, this::setD);
+    builder.addDoubleProperty("goal", () -> getGoal().position, this::setGoal);
+  }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/math/controller/RamseteController.java b/wpimath/src/main/java/edu/wpi/first/math/controller/RamseteController.java
new file mode 100644
index 0000000..0be592e
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/controller/RamseteController.java
@@ -0,0 +1,172 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.controller;
+
+import edu.wpi.first.math.geometry.Pose2d;
+import edu.wpi.first.math.kinematics.ChassisSpeeds;
+import edu.wpi.first.math.trajectory.Trajectory;
+
+/**
+ * Ramsete is a nonlinear time-varying feedback controller for unicycle models that drives the model
+ * to a desired pose along a two-dimensional trajectory. Why would we need a nonlinear control law
+ * in addition to the linear ones we have used so far like PID? If we use the original approach with
+ * PID controllers for left and right position and velocity states, the controllers only deal with
+ * the local pose. If the robot deviates from the path, there is no way for the controllers to
+ * correct and the robot may not reach the desired global pose. This is due to multiple endpoints
+ * existing for the robot which have the same encoder path arc lengths.
+ *
+ * <p>Instead of using wheel path arc lengths (which are in the robot's local coordinate frame),
+ * nonlinear controllers like pure pursuit and Ramsete use global pose. The controller uses this
+ * extra information to guide a linear reference tracker like the PID controllers back in by
+ * adjusting the references of the PID controllers.
+ *
+ * <p>The paper "Control of Wheeled Mobile Robots: An Experimental Overview" describes a nonlinear
+ * controller for a wheeled vehicle with unicycle-like kinematics; a global pose consisting of x, y,
+ * and theta; and a desired pose consisting of x_d, y_d, and theta_d. We call it Ramsete because
+ * that's the acronym for the title of the book it came from in Italian ("Robotica Articolata e
+ * Mobile per i SErvizi e le TEcnologie").
+ *
+ * <p>See <a href="https://file.tavsys.net/control/controls-engineering-in-frc.pdf">Controls
+ * Engineering in the FIRST Robotics Competition</a> section on Ramsete unicycle controller for a
+ * derivation and analysis.
+ */
+public class RamseteController {
+ @SuppressWarnings("MemberName")
+ private final double m_b;
+
+ @SuppressWarnings("MemberName")
+ private final double m_zeta;
+
+ private Pose2d m_poseError = new Pose2d();
+ private Pose2d m_poseTolerance = new Pose2d();
+ private boolean m_enabled = true;
+
+ /**
+ * Construct a Ramsete unicycle controller.
+ *
+ * @param b Tuning parameter (b > 0) for which larger values make convergence more aggressive
+ * like a proportional term.
+ * @param zeta Tuning parameter (0 < zeta < 1) for which larger values provide more damping
+ * in response.
+ */
+ @SuppressWarnings("ParameterName")
+ public RamseteController(double b, double zeta) {
+ m_b = b;
+ m_zeta = zeta;
+ }
+
+ /**
+ * Construct a Ramsete unicycle controller. The default arguments for b and zeta of 2.0 and 0.7
+ * have been well-tested to produce desirable results.
+ */
+ public RamseteController() {
+ this(2.0, 0.7);
+ }
+
+ /**
+ * Returns true if the pose error is within tolerance of the reference.
+ *
+ * @return True if the pose error is within tolerance of the reference.
+ */
+ public boolean atReference() {
+ final var eTranslate = m_poseError.getTranslation();
+ final var eRotate = m_poseError.getRotation();
+ final var tolTranslate = m_poseTolerance.getTranslation();
+ final var tolRotate = m_poseTolerance.getRotation();
+ return Math.abs(eTranslate.getX()) < tolTranslate.getX()
+ && Math.abs(eTranslate.getY()) < tolTranslate.getY()
+ && Math.abs(eRotate.getRadians()) < tolRotate.getRadians();
+ }
+
+ /**
+ * Sets the pose error which is considered tolerable for use with atReference().
+ *
+ * @param poseTolerance Pose error which is tolerable.
+ */
+ public void setTolerance(Pose2d poseTolerance) {
+ m_poseTolerance = poseTolerance;
+ }
+
+ /**
+ * Returns the next output of the Ramsete controller.
+ *
+ * <p>The reference pose, linear velocity, and angular velocity should come from a drivetrain
+ * trajectory.
+ *
+ * @param currentPose The current pose.
+ * @param poseRef The desired pose.
+ * @param linearVelocityRefMeters The desired linear velocity in meters per second.
+ * @param angularVelocityRefRadiansPerSecond The desired angular velocity in radians per second.
+ * @return The next controller output.
+ */
+ @SuppressWarnings("LocalVariableName")
+ public ChassisSpeeds calculate(
+ Pose2d currentPose,
+ Pose2d poseRef,
+ double linearVelocityRefMeters,
+ double angularVelocityRefRadiansPerSecond) {
+ if (!m_enabled) {
+ return new ChassisSpeeds(linearVelocityRefMeters, 0.0, angularVelocityRefRadiansPerSecond);
+ }
+
+ m_poseError = poseRef.relativeTo(currentPose);
+
+ // Aliases for equation readability
+ final double eX = m_poseError.getX();
+ final double eY = m_poseError.getY();
+ final double eTheta = m_poseError.getRotation().getRadians();
+ final double vRef = linearVelocityRefMeters;
+ final double omegaRef = angularVelocityRefRadiansPerSecond;
+
+ double k = 2.0 * m_zeta * Math.sqrt(Math.pow(omegaRef, 2) + m_b * Math.pow(vRef, 2));
+
+ return new ChassisSpeeds(
+ vRef * m_poseError.getRotation().getCos() + k * eX,
+ 0.0,
+ omegaRef + k * eTheta + m_b * vRef * sinc(eTheta) * eY);
+ }
+
+ /**
+ * Returns the next output of the Ramsete controller.
+ *
+ * <p>The reference pose, linear velocity, and angular velocity should come from a drivetrain
+ * trajectory.
+ *
+ * @param currentPose The current pose.
+ * @param desiredState The desired pose, linear velocity, and angular velocity from a trajectory.
+ * @return The next controller output.
+ */
+ @SuppressWarnings("LocalVariableName")
+ public ChassisSpeeds calculate(Pose2d currentPose, Trajectory.State desiredState) {
+ return calculate(
+ currentPose,
+ desiredState.poseMeters,
+ desiredState.velocityMetersPerSecond,
+ desiredState.velocityMetersPerSecond * desiredState.curvatureRadPerMeter);
+ }
+
+ /**
+ * Enables and disables the controller for troubleshooting purposes.
+ *
+ * @param enabled If the controller is enabled or not.
+ */
+ public void setEnabled(boolean enabled) {
+ m_enabled = enabled;
+ }
+
+ /**
+ * Returns sin(x) / x.
+ *
+ * @param x Value of which to take sinc(x).
+ */
+ @SuppressWarnings("ParameterName")
+ private static double sinc(double x) {
+ if (Math.abs(x) < 1e-9) {
+ return 1.0 - 1.0 / 6.0 * x * x;
+ } else {
+ return Math.sin(x) / x;
+ }
+ }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/math/controller/SimpleMotorFeedforward.java b/wpimath/src/main/java/edu/wpi/first/math/controller/SimpleMotorFeedforward.java
new file mode 100644
index 0000000..f985960
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/controller/SimpleMotorFeedforward.java
@@ -0,0 +1,143 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.controller;
+
+import edu.wpi.first.math.Matrix;
+import edu.wpi.first.math.Nat;
+import edu.wpi.first.math.system.plant.LinearSystemId;
+
+/**
+ * A helper class that computes feedforward outputs for a simple permanent-magnet DC motor.
+ *
+ * <p>The basic feedforward is {@code ks * signum(velocity) + kv * velocity + ka * acceleration}.
+ * Units of the gain values dictate the units of the computed feedforward.
+ */
+@SuppressWarnings("MemberName")
+public class SimpleMotorFeedforward {
+  /** The static gain. */
+  public final double ks;
+
+  /** The velocity gain. */
+  public final double kv;
+
+  /** The acceleration gain. */
+  public final double ka;
+
+  /**
+   * Creates a new SimpleMotorFeedforward with the specified gains. Units of the gain values will
+   * dictate units of the computed feedforward.
+   *
+   * @param ks The static gain.
+   * @param kv The velocity gain.
+   * @param ka The acceleration gain.
+   */
+  public SimpleMotorFeedforward(double ks, double kv, double ka) {
+    this.ks = ks;
+    this.kv = kv;
+    this.ka = ka;
+  }
+
+  /**
+   * Creates a new SimpleMotorFeedforward with the specified gains. Acceleration gain is defaulted
+   * to zero. Units of the gain values will dictate units of the computed feedforward.
+   *
+   * @param ks The static gain.
+   * @param kv The velocity gain.
+   */
+  public SimpleMotorFeedforward(double ks, double kv) {
+    this(ks, kv, 0);
+  }
+
+  /**
+   * Calculates the feedforward from the gains and setpoints.
+   *
+   * @param velocity The velocity setpoint.
+   * @param acceleration The acceleration setpoint.
+   * @return The computed feedforward.
+   */
+  public double calculate(double velocity, double acceleration) {
+    return ks * Math.signum(velocity) + kv * velocity + ka * acceleration;
+  }
+
+  /**
+   * Calculates the feedforward from the gains and two consecutive velocity setpoints.
+   *
+   * <p>Rather than taking an acceleration, this method does the following:
+   *
+   * <ul>
+   *   <li>identifies a velocity plant from kv and ka;
+   *   <li>evaluates a {@code LinearPlantInversionFeedforward} over {@code dtSeconds} with {@code
+   *       currentVelocity} as the current reference and {@code nextVelocity} as the next one;
+   *   <li>adds the static term using the sign of {@code currentVelocity}.
+   * </ul>
+   *
+   * <p>NOTE(review): the plant and its inversion are rebuilt on every call. Identifying a velocity
+   * plant needs a usable ka, so with ka = 0 (the two-argument constructor) this call may be
+   * rejected by LinearSystemId — confirm.
+   *
+   * @param currentVelocity The current velocity setpoint.
+   * @param nextVelocity The next velocity setpoint.
+   * @param dtSeconds Time between velocity setpoints in seconds.
+   * @return The computed feedforward.
+   */
+  public double calculate(double currentVelocity, double nextVelocity, double dtSeconds) {
+    var plant = LinearSystemId.identifyVelocitySystem(this.kv, this.ka);
+    var feedforward = new LinearPlantInversionFeedforward<>(plant, dtSeconds);
+
+    var r = Matrix.mat(Nat.N1(), Nat.N1()).fill(currentVelocity);
+    var nextR = Matrix.mat(Nat.N1(), Nat.N1()).fill(nextVelocity);
+
+    return ks * Math.signum(currentVelocity) + feedforward.calculate(r, nextR).get(0, 0);
+  }
+
+  /**
+   * Calculates the feedforward from the gains and velocity setpoint (acceleration is assumed to be
+   * zero).
+   *
+   * @param velocity The velocity setpoint.
+   * @return The computed feedforward.
+   */
+  public double calculate(double velocity) {
+    return calculate(velocity, 0);
+  }
+
+  // Rearranging the main equation from calculate(velocity, acceleration) yields the
+  // formulas for the methods below:
+
+  /**
+   * Calculates the maximum achievable velocity given a maximum voltage supply and an acceleration.
+   * Useful for ensuring that velocity and acceleration constraints for a trapezoidal profile are
+   * simultaneously achievable - enter the acceleration constraint, and this will give you a
+   * simultaneously-achievable velocity constraint.
+   *
+   * @param maxVoltage The maximum voltage that can be supplied to the motor.
+   * @param acceleration The acceleration of the motor.
+   * @return The maximum possible velocity at the given acceleration.
+   */
+  public double maxAchievableVelocity(double maxVoltage, double acceleration) {
+    // Assume max velocity is positive
+    return (maxVoltage - ks - acceleration * ka) / kv;
+  }
+
+  /**
+   * Calculates the minimum achievable velocity given a maximum voltage supply and an acceleration.
+   * Useful for ensuring that velocity and acceleration constraints for a trapezoidal profile are
+   * simultaneously achievable - enter the acceleration constraint, and this will give you a
+   * simultaneously-achievable velocity constraint.
+   *
+   * @param maxVoltage The maximum voltage that can be supplied to the motor.
+   * @param acceleration The acceleration of the motor.
+   * @return The minimum possible velocity at the given acceleration.
+   */
+  public double minAchievableVelocity(double maxVoltage, double acceleration) {
+    // Assume min velocity is negative, ks flips sign
+    return (-maxVoltage + ks - acceleration * ka) / kv;
+  }
+
+  /**
+   * Calculates the maximum achievable acceleration given a maximum voltage supply and a velocity.
+   * Useful for ensuring that velocity and acceleration constraints for a trapezoidal profile are
+   * simultaneously achievable - enter the velocity constraint, and this will give you a
+   * simultaneously-achievable acceleration constraint.
+   *
+   * @param maxVoltage The maximum voltage that can be supplied to the motor.
+   * @param velocity The velocity of the motor.
+   * @return The maximum possible acceleration at the given velocity.
+   */
+  public double maxAchievableAcceleration(double maxVoltage, double velocity) {
+    return (maxVoltage - ks * Math.signum(velocity) - velocity * kv) / ka;
+  }
+
+  /**
+   * Calculates the minimum achievable acceleration given a maximum voltage supply and a velocity.
+   * Useful for ensuring that velocity and acceleration constraints for a trapezoidal profile are
+   * simultaneously achievable - enter the velocity constraint, and this will give you a
+   * simultaneously-achievable acceleration constraint.
+   *
+   * <p>This is the maximum achievable acceleration with the supply voltage negated.
+   *
+   * @param maxVoltage The maximum voltage that can be supplied to the motor.
+   * @param velocity The velocity of the motor.
+   * @return The minimum possible acceleration at the given velocity.
+   */
+  public double minAchievableAcceleration(double maxVoltage, double velocity) {
+    return maxAchievableAcceleration(-maxVoltage, velocity);
+  }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/math/estimator/AngleStatistics.java b/wpimath/src/main/java/edu/wpi/first/math/estimator/AngleStatistics.java
new file mode 100644
index 0000000..a6d8b44
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/estimator/AngleStatistics.java
@@ -0,0 +1,123 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.estimator;
+
+import edu.wpi.first.math.MathUtil;
+import edu.wpi.first.math.Matrix;
+import edu.wpi.first.math.Num;
+import edu.wpi.first.math.numbers.N1;
+import java.util.function.BiFunction;
+import org.ejml.simple.SimpleMatrix;
+
+/**
+ * Helpers for state vectors and sigma points that hold an angle (in radians) in one row.
+ *
+ * <p>Plain vector addition, subtraction, and weighted means treat every row as an ordinary number,
+ * which gives wrong results for angles that wrap around. These helpers do two things for the
+ * selected row:
+ *
+ * <ul>
+ *   <li>after adding or subtracting, they normalize it with {@link MathUtil#angleModulus};
+ *   <li>they compute its mean from the weighted sums of the sines and cosines of the angles.
+ * </ul>
+ *
+ * <p>The overloads that take only the row index return {@link BiFunction}s with that row bound, so
+ * they can be handed to a filter as its mean, residual, or add function.
+ */
+public final class AngleStatistics {
+  private AngleStatistics() {
+    // Utility class
+  }
+
+  /**
+   * Subtracts b from a, then normalizes the resulting value in the selected row as if it were an
+   * angle.
+   *
+   * @param <S> Number of rows in vector.
+   * @param a A vector to subtract from.
+   * @param b A vector to subtract with.
+   * @param angleStateIdx The row containing angles to be normalized.
+   * @return Difference of two vectors with angle at the given index normalized.
+   */
+  public static <S extends Num> Matrix<S, N1> angleResidual(
+      Matrix<S, N1> a, Matrix<S, N1> b, int angleStateIdx) {
+    Matrix<S, N1> ret = a.minus(b);
+    ret.set(angleStateIdx, 0, MathUtil.angleModulus(ret.get(angleStateIdx, 0)));
+
+    return ret;
+  }
+
+  /**
+   * Returns a function that subtracts two vectors while normalizing the resulting value in the
+   * selected row as if it were an angle.
+   *
+   * @param <S> Number of rows in vector.
+   * @param angleStateIdx The row containing angles to be normalized.
+   * @return Function returning difference of two vectors with angle at the given index normalized.
+   */
+  public static <S extends Num>
+      BiFunction<Matrix<S, N1>, Matrix<S, N1>, Matrix<S, N1>> angleResidual(int angleStateIdx) {
+    return (a, b) -> angleResidual(a, b, angleStateIdx);
+  }
+
+  /**
+   * Adds a and b, then normalizes the resulting value in the selected row as an angle.
+   *
+   * @param <S> Number of rows in vector.
+   * @param a A vector to add with.
+   * @param b A vector to add with.
+   * @param angleStateIdx The row containing angles to be normalized.
+   * @return Sum of two vectors with angle at the given index normalized.
+   */
+  public static <S extends Num> Matrix<S, N1> angleAdd(
+      Matrix<S, N1> a, Matrix<S, N1> b, int angleStateIdx) {
+    Matrix<S, N1> ret = a.plus(b);
+    ret.set(angleStateIdx, 0, MathUtil.angleModulus(ret.get(angleStateIdx, 0)));
+
+    return ret;
+  }
+
+  /**
+   * Returns a function that adds two vectors while normalizing the resulting value in the selected
+   * row as an angle.
+   *
+   * @param <S> Number of rows in vector.
+   * @param angleStateIdx The row containing angles to be normalized.
+   * @return Function returning the sum of two vectors with angle at the given index normalized.
+   */
+  public static <S extends Num> BiFunction<Matrix<S, N1>, Matrix<S, N1>, Matrix<S, N1>> angleAdd(
+      int angleStateIdx) {
+    return (a, b) -> angleAdd(a, b, angleStateIdx);
+  }
+
+  /**
+   * Computes the mean of sigmas with the weights Wm while computing a special angle mean for a
+   * select row.
+   *
+   * <p>Every row except the angle row is the plain weighted sum {@code sigmas * Wm}. The angle row
+   * is replaced with {@code atan2(sum(Wm_i * sin(angle_i)), sum(Wm_i * cos(angle_i)))}.
+   *
+   * @param <S> Number of rows in sigma point matrix.
+   * @param sigmas Sigma points, one per column.
+   * @param Wm Weights for the mean, one row per sigma point (column of sigmas).
+   * @param angleStateIdx The row containing the angles.
+   * @return Mean of sigma points.
+   */
+  @SuppressWarnings("checkstyle:ParameterName")
+  public static <S extends Num> Matrix<S, N1> angleMean(
+      Matrix<S, ?> sigmas, Matrix<?, N1> Wm, int angleStateIdx) {
+    double[] angleSigmas = sigmas.extractRowVector(angleStateIdx).getData();
+    Matrix<N1, ?> sinAngleSigmas = new Matrix<>(new SimpleMatrix(1, sigmas.getNumCols()));
+    Matrix<N1, ?> cosAngleSigmas = new Matrix<>(new SimpleMatrix(1, sigmas.getNumCols()));
+    for (int i = 0; i < angleSigmas.length; i++) {
+      sinAngleSigmas.set(0, i, Math.sin(angleSigmas[i]));
+      cosAngleSigmas.set(0, i, Math.cos(angleSigmas[i]));
+    }
+
+    // (1 x N) * (N x 1) products: the weighted sums of sines and cosines.
+    double sumSin = sinAngleSigmas.times(Matrix.changeBoundsUnchecked(Wm)).elementSum();
+    double sumCos = cosAngleSigmas.times(Matrix.changeBoundsUnchecked(Wm)).elementSum();
+
+    Matrix<S, N1> ret = sigmas.times(Matrix.changeBoundsUnchecked(Wm));
+    ret.set(angleStateIdx, 0, Math.atan2(sumSin, sumCos));
+
+    return ret;
+  }
+
+  /**
+   * Returns a function that computes the mean of sigmas with the weights Wm while computing a
+   * special angle mean for a select row.
+   *
+   * @param <S> Number of rows in sigma point matrix.
+   * @param angleStateIdx The row containing the angles.
+   * @return Function returning mean of sigma points.
+   */
+  @SuppressWarnings("LambdaParameterName")
+  public static <S extends Num> BiFunction<Matrix<S, ?>, Matrix<?, N1>, Matrix<S, N1>> angleMean(
+      int angleStateIdx) {
+    return (sigmas, Wm) -> angleMean(sigmas, Wm, angleStateIdx);
+  }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/math/estimator/DifferentialDrivePoseEstimator.java b/wpimath/src/main/java/edu/wpi/first/math/estimator/DifferentialDrivePoseEstimator.java
new file mode 100644
index 0000000..074362e
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/estimator/DifferentialDrivePoseEstimator.java
@@ -0,0 +1,372 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.estimator;
+
+import edu.wpi.first.math.MatBuilder;
+import edu.wpi.first.math.Matrix;
+import edu.wpi.first.math.Nat;
+import edu.wpi.first.math.StateSpaceUtil;
+import edu.wpi.first.math.VecBuilder;
+import edu.wpi.first.math.geometry.Pose2d;
+import edu.wpi.first.math.geometry.Rotation2d;
+import edu.wpi.first.math.kinematics.DifferentialDriveWheelSpeeds;
+import edu.wpi.first.math.numbers.N1;
+import edu.wpi.first.math.numbers.N3;
+import edu.wpi.first.math.numbers.N5;
+import edu.wpi.first.util.WPIUtilJNI;
+import java.util.function.BiConsumer;
+
+/**
+ * This class wraps an {@link edu.wpi.first.math.estimator.UnscentedKalmanFilter Unscented Kalman
+ * Filter} to fuse latency-compensated vision measurements with differential drive encoder
+ * measurements. It will correct for noisy vision measurements and encoder drift. It is intended to
+ * be an easy drop-in for {@link edu.wpi.first.math.kinematics.DifferentialDriveOdometry}; in fact,
+ * if you never call {@link DifferentialDrivePoseEstimator#addVisionMeasurement} and only call
+ * {@link DifferentialDrivePoseEstimator#update} then this will behave exactly the same as
+ * DifferentialDriveOdometry.
+ *
+ * <p>{@link DifferentialDrivePoseEstimator#update} should be called every robot loop (if your robot
+ * loops are faster than the default then you should change the {@link
+ * DifferentialDrivePoseEstimator#DifferentialDrivePoseEstimator(Rotation2d, Pose2d, Matrix, Matrix,
+ * Matrix, double) nominal delta time}.) {@link DifferentialDrivePoseEstimator#addVisionMeasurement}
+ * can be called as infrequently as you want; if you never call it then this class will behave
+ * exactly like regular encoder odometry.
+ *
+ * <p>The state-space system used internally has the following states (x), inputs (u), and outputs
+ * (y):
+ *
+ * <p><strong> x = [x, y, theta, dist_l, dist_r]ᵀ </strong> in the field coordinate system
+ * containing x position, y position, heading, left encoder distance, and right encoder distance.
+ *
+ * <p><strong> u = [v_x, v_y, omega]ᵀ </strong> in the robot coordinate system, containing:
+ *
+ * <ul>
+ *   <li>forward velocity: the average of the left and right wheel velocities;
+ *   <li>sideways velocity: always zero;
+ *   <li>angular velocity: the change in gyro heading since the previous update, divided by the
+ *       elapsed time.
+ * </ul>
+ *
+ * <p>NB: Using velocities makes things considerably easier, because it means that teams don't have
+ * to worry about getting an accurate model. Basically, we suspect that it's easier for teams to get
+ * good encoder data than it is for them to perform system identification well enough to get a good
+ * model.
+ *
+ * <p><strong> y = [x, y, theta]ᵀ </strong> from vision containing x position, y position, and
+ * heading; or <strong> y = [dist_l, dist_r, theta]ᵀ </strong> containing left encoder distance,
+ * right encoder distance, and gyro heading.
+ */
+public class DifferentialDrivePoseEstimator {
+  final UnscentedKalmanFilter<N5, N3, N3> m_observer; // Package-private to allow for unit testing
+  private final BiConsumer<Matrix<N3, N1>, Matrix<N3, N1>> m_visionCorrect;
+  private final KalmanFilterLatencyCompensator<N5, N3, N3> m_latencyCompensator;
+
+  private final double m_nominalDt; // Seconds
+  private double m_prevTimeSeconds = -1.0; // Negative until the first update
+
+  // Added to raw gyro readings so that they agree with the estimated heading.
+  private Rotation2d m_gyroOffset;
+  // Offset-corrected heading from the previous update; used to derive the angular velocity input.
+  private Rotation2d m_previousAngle;
+
+  // Vision measurement noise covariance, built from the vision standard deviations.
+  private Matrix<N3, N3> m_visionContR;
+
+  /**
+   * Constructs a DifferentialDrivePoseEstimator with a nominal loop period of 0.02 seconds.
+   *
+   * @param gyroAngle The current gyro angle.
+   * @param initialPoseMeters The starting pose estimate.
+   * @param stateStdDevs Standard deviations of model states. Increase these numbers to trust your
+   *     model's state estimates less. This matrix is in the form [x, y, theta, dist_l, dist_r]ᵀ,
+   *     with units in meters and radians.
+   * @param localMeasurementStdDevs Standard deviations of the encoder and gyro measurements.
+   *     Increase these numbers to trust sensor readings from encoders and gyros less. This matrix
+   *     is in the form [dist_l, dist_r, theta]ᵀ, with units in meters and radians.
+   * @param visionMeasurementStdDevs Standard deviations of the vision measurements. Increase these
+   *     numbers to trust global measurements from vision less. This matrix is in the form [x, y,
+   *     theta]ᵀ, with units in meters and radians.
+   */
+  public DifferentialDrivePoseEstimator(
+      Rotation2d gyroAngle,
+      Pose2d initialPoseMeters,
+      Matrix<N5, N1> stateStdDevs,
+      Matrix<N3, N1> localMeasurementStdDevs,
+      Matrix<N3, N1> visionMeasurementStdDevs) {
+    this(
+        gyroAngle,
+        initialPoseMeters,
+        stateStdDevs,
+        localMeasurementStdDevs,
+        visionMeasurementStdDevs,
+        0.02);
+  }
+
+  /**
+   * Constructs a DifferentialDrivePoseEstimator.
+   *
+   * @param gyroAngle The current gyro angle.
+   * @param initialPoseMeters The starting pose estimate.
+   * @param stateStdDevs Standard deviations of model states. Increase these numbers to trust your
+   *     model's state estimates less. This matrix is in the form [x, y, theta, dist_l, dist_r]ᵀ,
+   *     with units in meters and radians.
+   * @param localMeasurementStdDevs Standard deviations of the encoder and gyro measurements.
+   *     Increase these numbers to trust sensor readings from encoders and gyros less. This matrix
+   *     is in the form [dist_l, dist_r, theta]ᵀ, with units in meters and radians.
+   * @param visionMeasurementStdDevs Standard deviations of the vision measurements. Increase these
+   *     numbers to trust global measurements from vision less. This matrix is in the form [x, y,
+   *     theta]ᵀ, with units in meters and radians.
+   * @param nominalDtSeconds The time in seconds between each robot loop.
+   */
+  @SuppressWarnings("ParameterName")
+  public DifferentialDrivePoseEstimator(
+      Rotation2d gyroAngle,
+      Pose2d initialPoseMeters,
+      Matrix<N5, N1> stateStdDevs,
+      Matrix<N3, N1> localMeasurementStdDevs,
+      Matrix<N3, N1> visionMeasurementStdDevs,
+      double nominalDtSeconds) {
+    m_nominalDt = nominalDtSeconds;
+
+    // Local measurement model: y = [dist_l, dist_r, theta] taken from states 3, 4, and 2. Row 2
+    // (theta) is treated as an angle for means, residuals, and additions.
+    m_observer =
+        new UnscentedKalmanFilter<>(
+            Nat.N5(),
+            Nat.N3(),
+            this::f,
+            (x, u) -> VecBuilder.fill(x.get(3, 0), x.get(4, 0), x.get(2, 0)),
+            stateStdDevs,
+            localMeasurementStdDevs,
+            AngleStatistics.angleMean(2),
+            AngleStatistics.angleMean(2),
+            AngleStatistics.angleResidual(2),
+            AngleStatistics.angleResidual(2),
+            AngleStatistics.angleAdd(2),
+            m_nominalDt);
+    m_latencyCompensator = new KalmanFilterLatencyCompensator<>();
+
+    // Initialize vision R
+    setVisionMeasurementStdDevs(visionMeasurementStdDevs);
+
+    // Vision measurement model: y = [x, y, theta], the first three states. The lambda reads the
+    // m_visionContR field each time it runs, so later setVisionMeasurementStdDevs calls apply.
+    m_visionCorrect =
+        (u, y) ->
+            m_observer.correct(
+                Nat.N3(),
+                u,
+                y,
+                (x, u1) -> new Matrix<>(x.getStorage().extractMatrix(0, 3, 0, 1)),
+                m_visionContR,
+                AngleStatistics.angleMean(2),
+                AngleStatistics.angleResidual(2),
+                AngleStatistics.angleResidual(2),
+                AngleStatistics.angleAdd(2));
+
+    m_gyroOffset = initialPoseMeters.getRotation().minus(gyroAngle);
+    m_previousAngle = initialPoseMeters.getRotation();
+    m_observer.setXhat(fillStateVector(initialPoseMeters, 0.0, 0.0));
+  }
+
+  /**
+   * Sets the pose estimator's trust of global measurements. This might be used to change trust in
+   * vision measurements after the autonomous period, or to change trust as distance to a vision
+   * target increases.
+   *
+   * @param visionMeasurementStdDevs Standard deviations of the vision measurements. Increase these
+   *     numbers to trust global measurements from vision less. This matrix is in the form [x, y,
+   *     theta]ᵀ, with units in meters and radians.
+   */
+  public void setVisionMeasurementStdDevs(Matrix<N3, N1> visionMeasurementStdDevs) {
+    m_visionContR = StateSpaceUtil.makeCovarianceMatrix(Nat.N3(), visionMeasurementStdDevs);
+  }
+
+  /**
+   * Continuous-time process model. It rotates the robot-relative velocities in u into the field
+   * frame using the current heading x[2].
+   *
+   * @param x The state vector [x, y, theta, dist_l, dist_r]ᵀ.
+   * @param u The input vector [v_x, v_y, omega]ᵀ.
+   * @return The time derivative of the state.
+   */
+  @SuppressWarnings({"ParameterName", "MethodName"})
+  private Matrix<N5, N1> f(Matrix<N5, N1> x, Matrix<N3, N1> u) {
+    // Apply a rotation matrix. Note that we do *not* add x--Runge-Kutta does that for us.
+    var theta = x.get(2, 0);
+    var toFieldRotation =
+        new MatBuilder<>(Nat.N5(), Nat.N5())
+            .fill(
+                Math.cos(theta),
+                -Math.sin(theta),
+                0,
+                0,
+                0,
+                Math.sin(theta),
+                Math.cos(theta),
+                0,
+                0,
+                0,
+                0,
+                0,
+                1,
+                0,
+                0,
+                0,
+                0,
+                0,
+                1,
+                0,
+                0,
+                0,
+                0,
+                0,
+                1);
+    // NOTE(review): updateWithTime() supplies u = [v_x, 0, omega], so the last two entries make the
+    // model's dist_l' = v_x and dist_r' = 0. That does not match the [dist_l, dist_r] encoder
+    // measurements; confirm whether this is intended.
+    return toFieldRotation.times(
+        VecBuilder.fill(u.get(0, 0), u.get(1, 0), u.get(2, 0), u.get(0, 0), u.get(1, 0)));
+  }
+
+  /**
+   * Resets the robot's position on the field.
+   *
+   * <p>You NEED to reset your encoders (to zero) when calling this method.
+   *
+   * <p>The gyroscope angle does not need to be reset here on the user's robot code. The library
+   * automatically takes care of offsetting the gyro angle.
+   *
+   * @param poseMeters The position on the field that your robot is at.
+   * @param gyroAngle The angle reported by the gyroscope.
+   */
+  public void resetPosition(Pose2d poseMeters, Rotation2d gyroAngle) {
+    // Reset state estimate and error covariance
+    m_observer.reset();
+    m_latencyCompensator.reset();
+
+    m_observer.setXhat(fillStateVector(poseMeters, 0.0, 0.0));
+
+    m_gyroOffset = getEstimatedPosition().getRotation().minus(gyroAngle);
+    m_previousAngle = poseMeters.getRotation();
+  }
+
+  /**
+   * Gets the pose of the robot at the current time as estimated by the Unscented Kalman Filter.
+   *
+   * @return The estimated robot pose in meters.
+   */
+  public Pose2d getEstimatedPosition() {
+    return new Pose2d(
+        m_observer.getXhat(0), m_observer.getXhat(1), new Rotation2d(m_observer.getXhat(2)));
+  }
+
+  /**
+   * Add a vision measurement to the Unscented Kalman Filter. This will correct the odometry pose
+   * estimate while still accounting for measurement noise.
+   *
+   * <p>This method can be called as infrequently as you want, as long as you are calling {@link
+   * DifferentialDrivePoseEstimator#update} every loop.
+   *
+   * @param visionRobotPoseMeters The pose of the robot as measured by the vision camera.
+   * @param timestampSeconds The timestamp of the vision measurement in seconds. Note that if you
+   *     don't use your own time source by calling {@link
+   *     DifferentialDrivePoseEstimator#updateWithTime} then you must use a timestamp with an epoch
+   *     since FPGA startup (i.e. the epoch of this timestamp is the same epoch as
+   *     Timer.getFPGATimestamp.) This means that you should use Timer.getFPGATimestamp as your time
+   *     source in this case.
+   */
+  public void addVisionMeasurement(Pose2d visionRobotPoseMeters, double timestampSeconds) {
+    m_latencyCompensator.applyPastGlobalMeasurement(
+        Nat.N3(),
+        m_observer,
+        m_nominalDt,
+        StateSpaceUtil.poseTo3dVector(visionRobotPoseMeters),
+        m_visionCorrect,
+        timestampSeconds);
+  }
+
+  /**
+   * Add a vision measurement to the Unscented Kalman Filter. This will correct the odometry pose
+   * estimate while still accounting for measurement noise.
+   *
+   * <p>This method can be called as infrequently as you want, as long as you are calling {@link
+   * DifferentialDrivePoseEstimator#update} every loop.
+   *
+   * <p>Note that the vision measurement standard deviations passed into this method will continue
+   * to apply to future measurements until a subsequent call to {@link
+   * DifferentialDrivePoseEstimator#setVisionMeasurementStdDevs(Matrix)} or this method.
+   *
+   * @param visionRobotPoseMeters The pose of the robot as measured by the vision camera.
+   * @param timestampSeconds The timestamp of the vision measurement in seconds. Note that if you
+   *     don't use your own time source by calling {@link
+   *     DifferentialDrivePoseEstimator#updateWithTime} then you must use a timestamp with an epoch
+   *     since FPGA startup (i.e. the epoch of this timestamp is the same epoch as
+   *     Timer.getFPGATimestamp.) This means that you should use Timer.getFPGATimestamp as your time
+   *     source in this case.
+   * @param visionMeasurementStdDevs Standard deviations of the vision measurements. Increase these
+   *     numbers to trust global measurements from vision less. This matrix is in the form [x, y,
+   *     theta]ᵀ, with units in meters and radians.
+   */
+  public void addVisionMeasurement(
+      Pose2d visionRobotPoseMeters,
+      double timestampSeconds,
+      Matrix<N3, N1> visionMeasurementStdDevs) {
+    setVisionMeasurementStdDevs(visionMeasurementStdDevs);
+    addVisionMeasurement(visionRobotPoseMeters, timestampSeconds);
+  }
+
+  /**
+   * Updates the Unscented Kalman Filter using only wheel encoder information. Note that this
+   * should be called every loop.
+   *
+   * <p>The current time is taken from {@code WPIUtilJNI.now()}, converted from microseconds to
+   * seconds.
+   *
+   * @param gyroAngle The current gyro angle.
+   * @param wheelVelocitiesMetersPerSecond The velocities of the wheels in meters per second.
+   * @param distanceLeftMeters The total distance travelled by the left wheel in meters since the
+   *     last time you called {@link DifferentialDrivePoseEstimator#resetPosition}.
+   * @param distanceRightMeters The total distance travelled by the right wheel in meters since the
+   *     last time you called {@link DifferentialDrivePoseEstimator#resetPosition}.
+   * @return The estimated pose of the robot in meters.
+   */
+  public Pose2d update(
+      Rotation2d gyroAngle,
+      DifferentialDriveWheelSpeeds wheelVelocitiesMetersPerSecond,
+      double distanceLeftMeters,
+      double distanceRightMeters) {
+    return updateWithTime(
+        WPIUtilJNI.now() * 1.0e-6,
+        gyroAngle,
+        wheelVelocitiesMetersPerSecond,
+        distanceLeftMeters,
+        distanceRightMeters);
+  }
+
+  /**
+   * Updates the Unscented Kalman Filter using only wheel encoder information. Note that this
+   * should be called every loop.
+   *
+   * <p>If {@code currentTimeSeconds} is not later than the time passed to the previous call, this
+   * call does not update the filter. It records the new time and returns the current estimate.
+   *
+   * @param currentTimeSeconds Time at which this method was called, in seconds.
+   * @param gyroAngle The current gyro angle.
+   * @param wheelVelocitiesMetersPerSecond The velocities of the wheels in meters per second.
+   * @param distanceLeftMeters The total distance travelled by the left wheel in meters since the
+   *     last time you called {@link DifferentialDrivePoseEstimator#resetPosition}.
+   * @param distanceRightMeters The total distance travelled by the right wheel in meters since the
+   *     last time you called {@link DifferentialDrivePoseEstimator#resetPosition}.
+   * @return The estimated pose of the robot in meters.
+   */
+  @SuppressWarnings({"LocalVariableName", "ParameterName"})
+  public Pose2d updateWithTime(
+      double currentTimeSeconds,
+      Rotation2d gyroAngle,
+      DifferentialDriveWheelSpeeds wheelVelocitiesMetersPerSecond,
+      double distanceLeftMeters,
+      double distanceRightMeters) {
+    double dt = m_prevTimeSeconds >= 0 ? currentTimeSeconds - m_prevTimeSeconds : m_nominalDt;
+    m_prevTimeSeconds = currentTimeSeconds;
+
+    if (dt <= 0.0) {
+      // Repeated or out-of-order timestamp. A zero step would make the angular-velocity division
+      // below produce NaN or Infinity, which would then propagate through every later estimate.
+      // A negative step is not a meaningful forward prediction either, so skip this update.
+      return getEstimatedPosition();
+    }
+
+    var angle = gyroAngle.plus(m_gyroOffset);
+    // Diff drive forward kinematics:
+    // v_c = (v_l + v_r) / 2
+    var wheelVels = wheelVelocitiesMetersPerSecond;
+    var u =
+        VecBuilder.fill(
+            (wheelVels.leftMetersPerSecond + wheelVels.rightMetersPerSecond) / 2,
+            0,
+            angle.minus(m_previousAngle).getRadians() / dt);
+    m_previousAngle = angle;
+
+    var localY = VecBuilder.fill(distanceLeftMeters, distanceRightMeters, angle.getRadians());
+    m_latencyCompensator.addObserverState(m_observer, u, localY, currentTimeSeconds);
+    m_observer.predict(u, dt);
+    m_observer.correct(u, localY);
+
+    return getEstimatedPosition();
+  }
+
+  /**
+   * Builds a state vector [x, y, theta, dist_l, dist_r]ᵀ from a pose and two wheel distances.
+   *
+   * @param pose The pose supplying x, y, and theta.
+   * @param leftDist The left wheel distance.
+   * @param rightDist The right wheel distance.
+   * @return The state vector.
+   */
+  private static Matrix<N5, N1> fillStateVector(Pose2d pose, double leftDist, double rightDist) {
+    return VecBuilder.fill(
+        pose.getTranslation().getX(),
+        pose.getTranslation().getY(),
+        pose.getRotation().getRadians(),
+        leftDist,
+        rightDist);
+  }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/math/estimator/ExtendedKalmanFilter.java b/wpimath/src/main/java/edu/wpi/first/math/estimator/ExtendedKalmanFilter.java
new file mode 100644
index 0000000..d4f9e56
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/estimator/ExtendedKalmanFilter.java
@@ -0,0 +1,372 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.estimator;
+
+import edu.wpi.first.math.Drake;
+import edu.wpi.first.math.Matrix;
+import edu.wpi.first.math.Nat;
+import edu.wpi.first.math.Num;
+import edu.wpi.first.math.StateSpaceUtil;
+import edu.wpi.first.math.numbers.N1;
+import edu.wpi.first.math.system.Discretization;
+import edu.wpi.first.math.system.NumericalIntegration;
+import edu.wpi.first.math.system.NumericalJacobian;
+import java.util.function.BiFunction;
+
+/**
+ * A Kalman filter combines predictions from a model and measurements to give an estimate of the
+ * true system state. This is useful because many states cannot be measured directly as a result of
+ * sensor noise, or because the state is "hidden".
+ *
+ * <p>Kalman filters use a K gain matrix to determine whether to trust the model or measurements
+ * more. Kalman filter theory uses statistics to compute an optimal K gain which minimizes the sum
+ * of squares error in the state estimate. This K gain is used to correct the state estimate by some
+ * amount of the difference between the actual measurements and the measurements predicted by the
+ * model.
+ *
+ * <p>An extended Kalman filter supports nonlinear state and measurement models. It propagates the
+ * error covariance by linearizing the models around the state estimate, then applying the linear
+ * Kalman filter equations.
+ *
+ * <p>For more on the underlying math, read
+ * https://file.tavsys.net/control/controls-engineering-in-frc.pdf chapter 9 "Stochastic control
+ * theory".
+ */
+@SuppressWarnings("ClassTypeParameterName")
+public class ExtendedKalmanFilter<States extends Num, Inputs extends Num, Outputs extends Num>
+ implements KalmanTypeFilter<States, Inputs, Outputs> {
+ private final Nat<States> m_states;
+ private final Nat<Outputs> m_outputs;
+
+ @SuppressWarnings("MemberName")
+ private final BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> m_f;
+
+ @SuppressWarnings("MemberName")
+ private final BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> m_h;
+
+ private BiFunction<Matrix<Outputs, N1>, Matrix<Outputs, N1>, Matrix<Outputs, N1>> m_residualFuncY;
+ private BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> m_addFuncX;
+
+ private final Matrix<States, States> m_contQ;
+ private final Matrix<States, States> m_initP;
+ private final Matrix<Outputs, Outputs> m_contR;
+
+ @SuppressWarnings("MemberName")
+ private Matrix<States, N1> m_xHat;
+
+ @SuppressWarnings("MemberName")
+ private Matrix<States, States> m_P;
+
+ private double m_dtSeconds;
+
+ /**
+ * Constructs an extended Kalman filter.
+ *
+ * @param states a Nat representing the number of states.
+ * @param inputs a Nat representing the number of inputs.
+ * @param outputs a Nat representing the number of outputs.
+ * @param f A vector-valued function of x and u that returns the derivative of the state vector.
+ * @param h A vector-valued function of x and u that returns the measurement vector.
+ * @param stateStdDevs Standard deviations of model states.
+ * @param measurementStdDevs Standard deviations of measurements.
+ * @param dtSeconds Nominal discretization timestep.
+ */
+ @SuppressWarnings("ParameterName")
+ public ExtendedKalmanFilter(
+ Nat<States> states,
+ Nat<Inputs> inputs,
+ Nat<Outputs> outputs,
+ BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
+ BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> h,
+ Matrix<States, N1> stateStdDevs,
+ Matrix<Outputs, N1> measurementStdDevs,
+ double dtSeconds) {
+ this(
+ states,
+ inputs,
+ outputs,
+ f,
+ h,
+ stateStdDevs,
+ measurementStdDevs,
+ Matrix::minus,
+ Matrix::plus,
+ dtSeconds);
+ }
+
+ /**
+ * Constructs an extended Kalman filter.
+ *
+ * @param states a Nat representing the number of states.
+ * @param inputs a Nat representing the number of inputs.
+ * @param outputs a Nat representing the number of outputs.
+ * @param f A vector-valued function of x and u that returns the derivative of the state vector.
+ * @param h A vector-valued function of x and u that returns the measurement vector.
+ * @param stateStdDevs Standard deviations of model states.
+ * @param measurementStdDevs Standard deviations of measurements.
+ * @param residualFuncY A function that computes the residual of two measurement vectors (i.e. it
+ * subtracts them.)
+ * @param addFuncX A function that adds two state vectors.
+ * @param dtSeconds Nominal discretization timestep.
+ */
+ @SuppressWarnings("ParameterName")
+ public ExtendedKalmanFilter(
+ Nat<States> states,
+ Nat<Inputs> inputs,
+ Nat<Outputs> outputs,
+ BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
+ BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> h,
+ Matrix<States, N1> stateStdDevs,
+ Matrix<Outputs, N1> measurementStdDevs,
+ BiFunction<Matrix<Outputs, N1>, Matrix<Outputs, N1>, Matrix<Outputs, N1>> residualFuncY,
+ BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> addFuncX,
+ double dtSeconds) {
+ m_states = states;
+ m_outputs = outputs;
+
+ m_f = f;
+ m_h = h;
+
+ m_residualFuncY = residualFuncY;
+ m_addFuncX = addFuncX;
+
+ m_contQ = StateSpaceUtil.makeCovarianceMatrix(states, stateStdDevs);
+ this.m_contR = StateSpaceUtil.makeCovarianceMatrix(outputs, measurementStdDevs);
+ m_dtSeconds = dtSeconds;
+
+ reset();
+ // Linearize f and h about the initial x-hat set by reset(), with a zero input.
+ final var contA =
+ NumericalJacobian.numericalJacobianX(
+ states, states, f, m_xHat, new Matrix<>(inputs, Nat.N1()));
+ final var C =
+ NumericalJacobian.numericalJacobianX(
+ outputs, states, h, m_xHat, new Matrix<>(inputs, Nat.N1()));
+
+ final var discPair = Discretization.discretizeAQTaylor(contA, m_contQ, dtSeconds);
+ final var discA = discPair.getFirst();
+ final var discQ = discPair.getSecond();
+
+ final var discR = Discretization.discretizeR(m_contR, dtSeconds);
+ // Seed P from the DARE when (A, C) is detectable and outputs <= states; else zero.
+ if (StateSpaceUtil.isDetectable(discA, C) && outputs.getNum() <= states.getNum()) {
+ m_initP =
+ Drake.discreteAlgebraicRiccatiEquation(discA.transpose(), C.transpose(), discQ, discR);
+ } else {
+ m_initP = new Matrix<>(states, states);
+ }
+
+ m_P = m_initP;
+ }
+
+ /**
+ * Returns the error covariance matrix P.
+ *
+ * @return the error covariance matrix P.
+ */
+ @Override
+ public Matrix<States, States> getP() {
+ return m_P;
+ }
+
+ /**
+ * Returns an element of the error covariance matrix P.
+ *
+ * @param row Row of P.
+ * @param col Column of P.
+ * @return the value of the error covariance matrix P at (row, col).
+ */
+ @Override
+ public double getP(int row, int col) {
+ return m_P.get(row, col);
+ }
+
+ /**
+ * Sets the entire error covariance matrix P.
+ *
+ * @param newP The new value of P to use.
+ */
+ @Override
+ public void setP(Matrix<States, States> newP) {
+ m_P = newP;
+ }
+
+ /**
+ * Returns the state estimate x-hat.
+ *
+ * @return the state estimate x-hat.
+ */
+ @Override
+ public Matrix<States, N1> getXhat() {
+ return m_xHat;
+ }
+
+ /**
+ * Returns an element of the state estimate x-hat.
+ *
+ * @param row Row of x-hat.
+ * @return the value of the state estimate x-hat at row.
+ */
+ @Override
+ public double getXhat(int row) {
+ return m_xHat.get(row, 0);
+ }
+
+ /**
+ * Set initial state estimate x-hat.
+ *
+ * @param xHat The state estimate x-hat.
+ */
+ @SuppressWarnings("ParameterName")
+ @Override
+ public void setXhat(Matrix<States, N1> xHat) {
+ m_xHat = xHat;
+ }
+
+ /**
+ * Set an element of the initial state estimate x-hat.
+ *
+ * @param row Row of x-hat.
+ * @param value Value for element of x-hat.
+ */
+ @Override
+ public void setXhat(int row, double value) {
+ m_xHat.set(row, 0, value);
+ }
+
+ @Override
+ public void reset() {
+ m_xHat = new Matrix<>(m_states, Nat.N1());
+ m_P = m_initP;
+ }
+
+ /**
+ * Project the model into the future with a new control input u.
+ *
+ * @param u New control input from controller.
+ * @param dtSeconds Timestep for prediction.
+ */
+ @SuppressWarnings("ParameterName")
+ @Override
+ public void predict(Matrix<Inputs, N1> u, double dtSeconds) {
+ predict(u, m_f, dtSeconds);
+ }
+
+ /**
+ * Project the model into the future with a new control input u.
+ *
+ * @param u New control input from controller.
+ * @param f The function used to linearize the model.
+ * @param dtSeconds Timestep for prediction.
+ */
+ @SuppressWarnings("ParameterName")
+ public void predict(
+ Matrix<Inputs, N1> u,
+ BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
+ double dtSeconds) {
+ // Find continuous A
+ final var contA = NumericalJacobian.numericalJacobianX(m_states, m_states, f, m_xHat, u);
+
+ // Find discrete A and Q
+ final var discPair = Discretization.discretizeAQTaylor(contA, m_contQ, dtSeconds);
+ final var discA = discPair.getFirst();
+ final var discQ = discPair.getSecond();
+
+ m_xHat = NumericalIntegration.rk4(f, m_xHat, u, dtSeconds);
+
+ // Pₖ₊₁⁻ = APₖ⁻Aᵀ + Q
+ m_P = discA.times(m_P).times(discA.transpose()).plus(discQ);
+
+ m_dtSeconds = dtSeconds;
+ }
+
+ /**
+ * Correct the state estimate x-hat using the measurements in y.
+ *
+ * @param u Same control input used in the predict step.
+ * @param y Measurement vector.
+ */
+ @SuppressWarnings("ParameterName")
+ @Override
+ public void correct(Matrix<Inputs, N1> u, Matrix<Outputs, N1> y) {
+ correct(m_outputs, u, y, m_h, m_contR, m_residualFuncY, m_addFuncX);
+ }
+
+ /**
+ * Correct the state estimate x-hat using the measurements in y.
+ *
+ * <p>This is useful for when the measurements available during a timestep's Correct() call vary.
+ * The h(x, u) passed to the constructor is used if one is not provided (the two-argument version
+ * of this function).
+ *
+ * @param <Rows> Number of rows in the result of h(x, u).
+ * @param rows Number of rows in the result of h(x, u).
+ * @param u Same control input used in the predict step.
+ * @param y Measurement vector.
+ * @param h A vector-valued function of x and u that returns the measurement vector.
+ * @param R Continuous measurement noise covariance matrix (discretized internally).
+ */
+ @SuppressWarnings({"ParameterName", "MethodTypeParameterName"})
+ public <Rows extends Num> void correct(
+ Nat<Rows> rows,
+ Matrix<Inputs, N1> u,
+ Matrix<Rows, N1> y,
+ BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Rows, N1>> h,
+ Matrix<Rows, Rows> R) {
+ correct(rows, u, y, h, R, Matrix::minus, Matrix::plus);
+ }
+
+ /**
+ * Correct the state estimate x-hat using the measurements in y.
+ *
+ * <p>This is useful for when the measurements available during a timestep's Correct() call vary.
+ * The h(x, u) passed to the constructor is used if one is not provided (the two-argument version
+ * of this function).
+ *
+ * @param <Rows> Number of rows in the result of h(x, u).
+ * @param rows Number of rows in the result of h(x, u).
+ * @param u Same control input used in the predict step.
+ * @param y Measurement vector.
+ * @param h A vector-valued function of x and u that returns the measurement vector.
+ * @param R Continuous measurement noise covariance matrix (discretized internally).
+ * @param residualFuncY A function that computes the residual of two measurement vectors (i.e. it
+ * subtracts them.)
+ * @param addFuncX A function that adds two state vectors.
+ */
+ @SuppressWarnings({"ParameterName", "MethodTypeParameterName"})
+ public <Rows extends Num> void correct(
+ Nat<Rows> rows,
+ Matrix<Inputs, N1> u,
+ Matrix<Rows, N1> y,
+ BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Rows, N1>> h,
+ Matrix<Rows, Rows> R,
+ BiFunction<Matrix<Rows, N1>, Matrix<Rows, N1>, Matrix<Rows, N1>> residualFuncY,
+ BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> addFuncX) {
+ final var C = NumericalJacobian.numericalJacobianX(rows, m_states, h, m_xHat, u);
+ final var discR = Discretization.discretizeR(R, m_dtSeconds);
+ // S = CPCᵀ + R
+ final var S = C.times(m_P).times(C.transpose()).plus(discR);
+
+ // We want to put K = PCᵀS⁻¹ into Ax = b form so we can solve it more
+ // efficiently.
+ //
+ // K = PCᵀS⁻¹
+ // KS = PCᵀ
+ // (KS)ᵀ = (PCᵀ)ᵀ
+ // SᵀKᵀ = CPᵀ
+ //
+ // The solution of Ax = b can be found via x = A.solve(b).
+ //
+ // Kᵀ = Sᵀ.solve(CPᵀ)
+ // K = (Sᵀ.solve(CPᵀ))ᵀ
+ final Matrix<States, Rows> K = S.transpose().solve(C.times(m_P.transpose())).transpose();
+
+ // x̂ₖ₊₁⁺ = x̂ₖ₊₁⁻ + K(y − h(x̂ₖ₊₁⁻, uₖ₊₁))
+ m_xHat = addFuncX.apply(m_xHat, K.times(residualFuncY.apply(y, h.apply(m_xHat, u))));
+
+ // Pₖ₊₁⁺ = (I − KC)Pₖ₊₁⁻
+ m_P = Matrix.eye(m_states).minus(K.times(C)).times(m_P);
+ }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/estimator/KalmanFilter.java b/wpimath/src/main/java/edu/wpi/first/math/estimator/KalmanFilter.java
similarity index 62%
rename from wpimath/src/main/java/edu/wpi/first/wpilibj/estimator/KalmanFilter.java
rename to wpimath/src/main/java/edu/wpi/first/math/estimator/KalmanFilter.java
index 99fa2b7..4aea69d 100644
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/estimator/KalmanFilter.java
+++ b/wpimath/src/main/java/edu/wpi/first/math/estimator/KalmanFilter.java
@@ -1,21 +1,18 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
-package edu.wpi.first.wpilibj.estimator;
+package edu.wpi.first.math.estimator;
import edu.wpi.first.math.Drake;
import edu.wpi.first.math.MathSharedStore;
-import edu.wpi.first.wpilibj.math.Discretization;
-import edu.wpi.first.wpilibj.math.StateSpaceUtil;
-import edu.wpi.first.wpilibj.system.LinearSystem;
-import edu.wpi.first.wpiutil.math.Matrix;
-import edu.wpi.first.wpiutil.math.Nat;
-import edu.wpi.first.wpiutil.math.Num;
-import edu.wpi.first.wpiutil.math.numbers.N1;
+import edu.wpi.first.math.Matrix;
+import edu.wpi.first.math.Nat;
+import edu.wpi.first.math.Num;
+import edu.wpi.first.math.StateSpaceUtil;
+import edu.wpi.first.math.numbers.N1;
+import edu.wpi.first.math.system.Discretization;
+import edu.wpi.first.math.system.LinearSystem;
/**
* A Kalman filter combines predictions from a model and measurements to give an estimate of the
@@ -24,51 +21,46 @@
*
* <p>Kalman filters use a K gain matrix to determine whether to trust the model or measurements
* more. Kalman filter theory uses statistics to compute an optimal K gain which minimizes the sum
- * of squares error in the state estimate. This K gain is used to correct the state estimate by
- * some amount of the difference between the actual measurements and the measurements predicted by
- * the model.
+ * of squares error in the state estimate. This K gain is used to correct the state estimate by some
+ * amount of the difference between the actual measurements and the measurements predicted by the
+ * model.
*
* <p>For more on the underlying math, read
* https://file.tavsys.net/control/controls-engineering-in-frc.pdf chapter 9 "Stochastic control
* theory".
*/
@SuppressWarnings("ClassTypeParameterName")
-public class KalmanFilter<States extends Num, Inputs extends Num,
- Outputs extends Num> {
+public class KalmanFilter<States extends Num, Inputs extends Num, Outputs extends Num> {
private final Nat<States> m_states;
private final LinearSystem<States, Inputs, Outputs> m_plant;
- /**
- * The steady-state Kalman gain matrix.
- */
+ /** The steady-state Kalman gain matrix. */
@SuppressWarnings("MemberName")
private final Matrix<States, Outputs> m_K;
- /**
- * The state estimate.
- */
+ /** The state estimate. */
@SuppressWarnings("MemberName")
private Matrix<States, N1> m_xHat;
/**
* Constructs a state-space observer with the given plant.
*
- * @param states A Nat representing the states of the system.
- * @param outputs A Nat representing the outputs of the system.
- * @param plant The plant used for the prediction step.
- * @param stateStdDevs Standard deviations of model states.
+ * @param states A Nat representing the states of the system.
+ * @param outputs A Nat representing the outputs of the system.
+ * @param plant The plant used for the prediction step.
+ * @param stateStdDevs Standard deviations of model states.
* @param measurementStdDevs Standard deviations of measurements.
- * @param dtSeconds Nominal discretization timestep.
+ * @param dtSeconds Nominal discretization timestep.
*/
@SuppressWarnings("LocalVariableName")
public KalmanFilter(
- Nat<States> states, Nat<Outputs> outputs,
+ Nat<States> states,
+ Nat<Outputs> outputs,
LinearSystem<States, Inputs, Outputs> plant,
Matrix<States, N1> stateStdDevs,
Matrix<Outputs, N1> measurementStdDevs,
- double dtSeconds
- ) {
+ double dtSeconds) {
this.m_states = states;
this.m_plant = plant;
@@ -84,34 +76,41 @@
var C = plant.getC();
- // isStabilizable(A^T, C^T) will tell us if the system is observable.
- var isObservable = StateSpaceUtil.isStabilizable(discA.transpose(), C.transpose());
- if (!isObservable) {
- MathSharedStore.reportError("The system passed to the Kalman filter is not observable!",
- Thread.currentThread().getStackTrace());
- throw new IllegalArgumentException(
- "The system passed to the Kalman filter is not observable!");
+ if (!StateSpaceUtil.isDetectable(discA, C)) {
+ var builder =
+ new StringBuilder("The system passed to the Kalman filter is unobservable!\n\nA =\n");
+ builder.append(discA.getStorage().toString());
+ builder.append("\nC =\n");
+ builder.append(C.getStorage().toString());
+ builder.append("\n");
+
+ var msg = builder.toString();
+ MathSharedStore.reportError(msg, Thread.currentThread().getStackTrace());
+ throw new IllegalArgumentException(msg);
}
- var P = new Matrix<>(Drake.discreteAlgebraicRiccatiEquation(
- discA.transpose(), C.transpose(), discQ, discR));
+ var P =
+ new Matrix<>(
+ Drake.discreteAlgebraicRiccatiEquation(discA.transpose(), C.transpose(), discQ, discR));
+ // S = CPCᵀ + R
var S = C.times(P).times(C.transpose()).plus(discR);
- // We want to put K = PC^T S^-1 into Ax = b form so we can solve it more
+ // We want to put K = PCᵀS⁻¹ into Ax = b form so we can solve it more
// efficiently.
//
- // K = PC^T S^-1
- // KS = PC^T
- // (KS)^T = (PC^T)^T
- // S^T K^T = CP^T
+ // K = PCᵀS⁻¹
+ // KS = PCᵀ
+ // (KS)ᵀ = (PCᵀ)ᵀ
+ // SᵀKᵀ = CPᵀ
//
// The solution of Ax = b can be found via x = A.solve(b).
//
- // K^T = S^T.solve(CP^T)
- // K = (S^T.solve(CP^T))^T
- m_K = new Matrix<>(S.transpose().getStorage()
- .solve((C.times(P.transpose())).getStorage()).transpose());
+ // Kᵀ = Sᵀ.solve(CPᵀ)
+ // K = (Sᵀ.solve(CPᵀ))ᵀ
+ m_K =
+ new Matrix<>(
+ S.transpose().getStorage().solve((C.times(P.transpose())).getStorage()).transpose());
reset();
}
@@ -152,7 +151,7 @@
/**
* Set an element of the initial state estimate x-hat.
*
- * @param row Row of x-hat.
+ * @param row Row of x-hat.
* @param value Value for element of x-hat.
*/
public void setXhat(int row, double value) {
@@ -181,7 +180,7 @@
/**
* Project the model into the future with a new control input u.
*
- * @param u New control input from controller.
+ * @param u New control input from controller.
* @param dtSeconds Timestep for prediction.
*/
@SuppressWarnings("ParameterName")
@@ -199,6 +198,7 @@
public void correct(Matrix<Inputs, N1> u, Matrix<Outputs, N1> y) {
final var C = m_plant.getC();
final var D = m_plant.getD();
+ // x̂ₖ₊₁⁺ = x̂ₖ₊₁⁻ + K(y − (Cx̂ₖ₊₁⁻ + Duₖ₊₁))
m_xHat = m_xHat.plus(m_K.times(y.minus(C.times(m_xHat).plus(D.times(u)))));
}
}
diff --git a/wpimath/src/main/java/edu/wpi/first/math/estimator/KalmanFilterLatencyCompensator.java b/wpimath/src/main/java/edu/wpi/first/math/estimator/KalmanFilterLatencyCompensator.java
new file mode 100644
index 0000000..50a83b5
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/estimator/KalmanFilterLatencyCompensator.java
@@ -0,0 +1,161 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.estimator;
+
+import edu.wpi.first.math.Matrix;
+import edu.wpi.first.math.Nat;
+import edu.wpi.first.math.Num;
+import edu.wpi.first.math.numbers.N1;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.function.BiConsumer;
+
+public class KalmanFilterLatencyCompensator<S extends Num, I extends Num, O extends Num> {
+ private static final int kMaxPastObserverStates = 300;
+ // Insertion-ordered; the binary search below assumes ascending timestamps.
+ private final List<Map.Entry<Double, ObserverSnapshot>> m_pastObserverSnapshots;
+
+ KalmanFilterLatencyCompensator() {
+ m_pastObserverSnapshots = new ArrayList<>();
+ }
+
+ /** Clears the observer snapshot buffer. */
+ public void reset() {
+ m_pastObserverSnapshots.clear();
+ }
+
+ /**
+ * Add the observer's current state to the snapshot list, evicting the oldest past the cap.
+ *
+ * @param observer The observer whose x-hat and P are recorded.
+ * @param u The input at the timestamp.
+ * @param localY The local output at the timestamp.
+ * @param timestampSeconds The timestamp of the state, in seconds.
+ */
+ @SuppressWarnings("ParameterName")
+ public void addObserverState(
+ KalmanTypeFilter<S, I, O> observer,
+ Matrix<I, N1> u,
+ Matrix<O, N1> localY,
+ double timestampSeconds) {
+ m_pastObserverSnapshots.add(
+ Map.entry(timestampSeconds, new ObserverSnapshot(observer, u, localY)));
+
+ if (m_pastObserverSnapshots.size() > kMaxPastObserverStates) {
+ m_pastObserverSnapshots.remove(0);
+ }
+ }
+
+ /**
+ * Add past global measurements (such as from vision) to the estimator.
+ *
+ * @param <R> The rows in the global measurement vector.
+ * @param rows The rows in the global measurement vector.
+ * @param observer The observer to apply the past global measurement.
+ * @param nominalDtSeconds The nominal timestep.
+ * @param y The measurement.
+ * @param globalMeasurementCorrect The function that calls correct() on the observer.
+ * @param timestampSeconds The timestamp of the measurement.
+ */
+ @SuppressWarnings("ParameterName")
+ public <R extends Num> void applyPastGlobalMeasurement(
+ Nat<R> rows,
+ KalmanTypeFilter<S, I, O> observer,
+ double nominalDtSeconds,
+ Matrix<R, N1> y,
+ BiConsumer<Matrix<I, N1>, Matrix<R, N1>> globalMeasurementCorrect,
+ double timestampSeconds) {
+ if (m_pastObserverSnapshots.isEmpty()) {
+ // State map was empty, which means that we got a past measurement right at startup. The only
+ // thing we can really do is ignore the measurement.
+ return;
+ }
+
+ // This index starts at one because we use the previous state later on, and we always want to
+ // have a "previous state".
+ int maxIdx = m_pastObserverSnapshots.size() - 1;
+ int low = 1;
+ int high = Math.max(maxIdx, 1);
+
+ while (low != high) {
+ int mid = (low + high) / 2;
+ if (m_pastObserverSnapshots.get(mid).getKey() < timestampSeconds) {
+ // This index and everything under it are less than the requested timestamp. Therefore, we
+ // can discard them.
+ low = mid + 1;
+ } else {
+ // The element at this index is at least as large as t. This means that anything after it
+ // cannot be what we are looking for.
+ high = mid;
+ }
+ }
+
+ // We are simply assigning this index to a new variable to avoid confusion
+ // with variable names.
+ int index = low;
+ double timestamp = timestampSeconds;
+ int indexOfClosestEntry =
+ Math.abs(timestamp - m_pastObserverSnapshots.get(index - 1).getKey())
+ <= Math.abs(
+ timestamp - m_pastObserverSnapshots.get(Math.min(index, maxIdx)).getKey())
+ ? index - 1
+ : index;
+
+ double lastTimestamp =
+ m_pastObserverSnapshots.get(indexOfClosestEntry).getKey() - nominalDtSeconds;
+
+ // We will now go back in time to the state of the system at the time when
+ // the measurement was captured. We will reset the observer to that state,
+ // and apply correction based on the measurement. Then, we will go back
+ // through all observer states until the present and apply past inputs to
+ // get the present estimated state.
+ for (int i = indexOfClosestEntry; i < m_pastObserverSnapshots.size(); i++) {
+ var key = m_pastObserverSnapshots.get(i).getKey();
+ var snapshot = m_pastObserverSnapshots.get(i).getValue();
+
+ if (i == indexOfClosestEntry) {
+ observer.setP(snapshot.errorCovariances);
+ observer.setXhat(snapshot.xHat);
+ }
+
+ observer.predict(snapshot.inputs, key - lastTimestamp);
+ observer.correct(snapshot.inputs, snapshot.localMeasurements);
+
+ if (i == indexOfClosestEntry) {
+ // Note that the measurement is at a timestep close but probably not exactly equal to the
+ // timestep for which we called predict.
+ // This makes the assumption that the dt is small enough that the difference between the
+ // measurement time and the time that the inputs were captured at is very small.
+ globalMeasurementCorrect.accept(snapshot.inputs, y);
+ }
+ lastTimestamp = key;
+
+ m_pastObserverSnapshots.set(
+ i,
+ Map.entry(
+ key, new ObserverSnapshot(observer, snapshot.inputs, snapshot.localMeasurements)));
+ }
+ }
+
+ /** Observer state at a given time; xHat and P are stored by reference, not copied. */
+ @SuppressWarnings("MemberName")
+ public class ObserverSnapshot {
+ public final Matrix<S, N1> xHat;
+ public final Matrix<S, S> errorCovariances;
+ public final Matrix<I, N1> inputs;
+ public final Matrix<O, N1> localMeasurements;
+
+ @SuppressWarnings("ParameterName")
+ private ObserverSnapshot(
+ KalmanTypeFilter<S, I, O> observer, Matrix<I, N1> u, Matrix<O, N1> localY) {
+ this.xHat = observer.getXhat();
+ this.errorCovariances = observer.getP();
+
+ inputs = u;
+ localMeasurements = localY;
+ }
+ }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/math/estimator/KalmanTypeFilter.java b/wpimath/src/main/java/edu/wpi/first/math/estimator/KalmanTypeFilter.java
new file mode 100644
index 0000000..3fd3957
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/estimator/KalmanTypeFilter.java
@@ -0,0 +1,32 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.estimator;
+
+import edu.wpi.first.math.Matrix;
+import edu.wpi.first.math.Num;
+import edu.wpi.first.math.numbers.N1;
+
+@SuppressWarnings({"ParameterName", "InterfaceTypeParameterName"})
+interface KalmanTypeFilter<States extends Num, Inputs extends Num, Outputs extends Num> {
+ Matrix<States, N1> getXhat();
+
+ double getXhat(int row);
+
+ void setXhat(Matrix<States, N1> xHat);
+
+ void setXhat(int row, double value);
+
+ Matrix<States, States> getP();
+
+ double getP(int row, int col);
+
+ void setP(Matrix<States, States> newP);
+
+ void reset();
+
+ void predict(Matrix<Inputs, N1> u, double dtSeconds);
+
+ void correct(Matrix<Inputs, N1> u, Matrix<Outputs, N1> y);
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/math/estimator/MecanumDrivePoseEstimator.java b/wpimath/src/main/java/edu/wpi/first/math/estimator/MecanumDrivePoseEstimator.java
new file mode 100644
index 0000000..42a6adb
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/estimator/MecanumDrivePoseEstimator.java
@@ -0,0 +1,305 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.estimator;
+
+import edu.wpi.first.math.Matrix;
+import edu.wpi.first.math.Nat;
+import edu.wpi.first.math.StateSpaceUtil;
+import edu.wpi.first.math.VecBuilder;
+import edu.wpi.first.math.geometry.Pose2d;
+import edu.wpi.first.math.geometry.Rotation2d;
+import edu.wpi.first.math.geometry.Translation2d;
+import edu.wpi.first.math.kinematics.MecanumDriveKinematics;
+import edu.wpi.first.math.kinematics.MecanumDriveWheelSpeeds;
+import edu.wpi.first.math.numbers.N1;
+import edu.wpi.first.math.numbers.N3;
+import edu.wpi.first.util.WPIUtilJNI;
+import java.util.function.BiConsumer;
+
+/**
+ * This class wraps an {@link UnscentedKalmanFilter Unscented Kalman Filter} to fuse
+ * latency-compensated vision measurements with mecanum drive encoder velocity measurements. It will
+ * correct for noisy measurements and encoder drift. It is intended to be an easy but more accurate
+ * drop-in for {@link edu.wpi.first.math.kinematics.MecanumDriveOdometry}.
+ *
+ * <p>{@link MecanumDrivePoseEstimator#update} should be called every robot loop. If your loops are
+ * faster or slower than the default of 0.02s, then you should change the nominal delta time using
+ * the secondary constructor: {@link MecanumDrivePoseEstimator#MecanumDrivePoseEstimator(Rotation2d,
+ * Pose2d, MecanumDriveKinematics, Matrix, Matrix, Matrix, double)}.
+ *
+ * <p>{@link MecanumDrivePoseEstimator#addVisionMeasurement} can be called as infrequently as you
+ * want; if you never call it, then this class will behave mostly like regular encoder odometry.
+ *
+ * <p>The state-space system used internally has the following states (x), inputs (u), and outputs
+ * (y):
+ *
+ * <p><strong> x = [x, y, theta]ᵀ </strong> in the field coordinate system containing x position, y
+ * position, and heading.
+ *
+ * <p><strong> u = [vx, vy, omega]ᵀ </strong> in the field coordinate system containing x
+ * velocity, y velocity, and angular velocity (change in gyro heading divided by the timestep).
+ *
+ * <p><strong> y = [x, y, theta]ᵀ </strong> from vision containing x position, y position, and
+ * heading; or <strong> y = [theta]ᵀ </strong> containing gyro heading.
+ */
+public class MecanumDrivePoseEstimator {
+ private final UnscentedKalmanFilter<N3, N3, N1> m_observer;
+ private final MecanumDriveKinematics m_kinematics;
+ private final BiConsumer<Matrix<N3, N1>, Matrix<N3, N1>> m_visionCorrect;
+ private final KalmanFilterLatencyCompensator<N3, N3, N1> m_latencyCompensator;
+
+ private final double m_nominalDt; // Seconds
+ private double m_prevTimeSeconds = -1.0;
+
+ private Rotation2d m_gyroOffset;
+ private Rotation2d m_previousAngle;
+
+ private Matrix<N3, N3> m_visionContR;
+
+ /**
+ * Constructs a MecanumDrivePoseEstimator.
+ *
+ * @param gyroAngle The current gyro angle.
+ * @param initialPoseMeters The starting pose estimate.
+ * @param kinematics A correctly-configured kinematics object for your drivetrain.
+ * @param stateStdDevs Standard deviations of model states. Increase these numbers to trust your
+ * model's state estimates less. This matrix is in the form [x, y, theta]ᵀ, with units in
+ * meters and radians.
+ * @param localMeasurementStdDevs Standard deviation of the gyro heading measurement. Increase
+ * this number to trust the gyro less. Encoder velocities enter as inputs, not measurements.
+ * This matrix is in the form [theta], with units in radians.
+ * @param visionMeasurementStdDevs Standard deviations of the vision measurements. Increase these
+ * numbers to trust global measurements from vision less. This matrix is in the form [x, y,
+ * theta]ᵀ, with units in meters and radians.
+ */
+ public MecanumDrivePoseEstimator(
+ Rotation2d gyroAngle,
+ Pose2d initialPoseMeters,
+ MecanumDriveKinematics kinematics,
+ Matrix<N3, N1> stateStdDevs,
+ Matrix<N1, N1> localMeasurementStdDevs,
+ Matrix<N3, N1> visionMeasurementStdDevs) {
+ this(
+ gyroAngle,
+ initialPoseMeters,
+ kinematics,
+ stateStdDevs,
+ localMeasurementStdDevs,
+ visionMeasurementStdDevs,
+ 0.02);
+ }
+
+ /**
+ * Constructs a MecanumDrivePoseEstimator.
+ *
+ * @param gyroAngle The current gyro angle.
+ * @param initialPoseMeters The starting pose estimate.
+ * @param kinematics A correctly-configured kinematics object for your drivetrain.
+ * @param stateStdDevs Standard deviations of model states. Increase these numbers to trust your
+ * model's state estimates less. This matrix is in the form [x, y, theta]ᵀ, with units in
+ * meters and radians.
+ * @param localMeasurementStdDevs Standard deviation of the gyro heading measurement. Increase
+ * this number to trust the gyro less. Encoder velocities enter as inputs, not measurements.
+ * This matrix is in the form [theta], with units in radians.
+ * @param visionMeasurementStdDevs Standard deviations of the vision measurements. Increase these
+ * numbers to trust global measurements from vision less. This matrix is in the form [x, y,
+ * theta]ᵀ, with units in meters and radians.
+ * @param nominalDtSeconds The time in seconds between each robot loop.
+ */
+ @SuppressWarnings("ParameterName")
+ public MecanumDrivePoseEstimator(
+ Rotation2d gyroAngle,
+ Pose2d initialPoseMeters,
+ MecanumDriveKinematics kinematics,
+ Matrix<N3, N1> stateStdDevs,
+ Matrix<N1, N1> localMeasurementStdDevs,
+ Matrix<N3, N1> visionMeasurementStdDevs,
+ double nominalDtSeconds) {
+ m_nominalDt = nominalDtSeconds;
+
+ m_observer =
+ new UnscentedKalmanFilter<>(
+ Nat.N3(),
+ Nat.N1(),
+ (x, u) -> u,
+ (x, u) -> x.extractRowVector(2),
+ stateStdDevs,
+ localMeasurementStdDevs,
+ AngleStatistics.angleMean(2),
+ AngleStatistics.angleMean(0),
+ AngleStatistics.angleResidual(2),
+ AngleStatistics.angleResidual(0),
+ AngleStatistics.angleAdd(2),
+ m_nominalDt);
+ m_kinematics = kinematics;
+ m_latencyCompensator = new KalmanFilterLatencyCompensator<>();
+
+ // Initialize vision R
+ setVisionMeasurementStdDevs(visionMeasurementStdDevs);
+
+ m_visionCorrect =
+ (u, y) ->
+ m_observer.correct(
+ Nat.N3(),
+ u,
+ y,
+ (x, u1) -> x,
+ m_visionContR,
+ AngleStatistics.angleMean(2),
+ AngleStatistics.angleResidual(2),
+ AngleStatistics.angleResidual(2),
+ AngleStatistics.angleAdd(2));
+
+ m_gyroOffset = initialPoseMeters.getRotation().minus(gyroAngle);
+ m_previousAngle = initialPoseMeters.getRotation();
+ m_observer.setXhat(StateSpaceUtil.poseTo3dVector(initialPoseMeters));
+ }
+
+ /**
+ * Sets the pose estimator's trust of global measurements. This might be used to change trust in
+ * vision measurements after the autonomous period, or to change trust as distance to a vision
+ * target increases.
+ *
+ * @param visionMeasurementStdDevs Standard deviations of the vision measurements. Increase these
+ * numbers to trust global measurements from vision less. This matrix is in the form [x, y,
+ * theta]ᵀ, with units in meters and radians.
+ */
+ public void setVisionMeasurementStdDevs(Matrix<N3, N1> visionMeasurementStdDevs) {
+ m_visionContR = StateSpaceUtil.makeCovarianceMatrix(Nat.N3(), visionMeasurementStdDevs);
+ }
+
+ /**
+ * Resets the robot's position on the field.
+ *
+ * <p>You NEED to reset your encoders (to zero) when calling this method.
+ *
+ * <p>The gyroscope angle does not need to be reset in the user's robot code. The library
+ * automatically takes care of offsetting the gyro angle.
+ *
+ * @param poseMeters The position on the field that your robot is at.
+ * @param gyroAngle The angle reported by the gyroscope.
+ */
+ public void resetPosition(Pose2d poseMeters, Rotation2d gyroAngle) {
+ // Reset state estimate and error covariance
+ m_observer.reset();
+ m_latencyCompensator.reset();
+
+ m_observer.setXhat(StateSpaceUtil.poseTo3dVector(poseMeters));
+
+ m_gyroOffset = getEstimatedPosition().getRotation().minus(gyroAngle);
+ m_previousAngle = poseMeters.getRotation();
+ }
+
+ /**
+ * Gets the pose of the robot at the current time as estimated by the Unscented Kalman Filter.
+ *
+ * @return The estimated robot pose in meters.
+ */
+ public Pose2d getEstimatedPosition() {
+ return new Pose2d(
+ m_observer.getXhat(0), m_observer.getXhat(1), new Rotation2d(m_observer.getXhat(2)));
+ }
+
+ /**
+ * Add a vision measurement to the Unscented Kalman Filter. This will correct the odometry pose
+ * estimate while still accounting for measurement noise.
+ *
+ * <p>This method can be called as infrequently as you want, as long as you are calling {@link
+ * MecanumDrivePoseEstimator#update} every loop.
+ *
+ * @param visionRobotPoseMeters The pose of the robot as measured by the vision camera.
+ * @param timestampSeconds The timestamp of the vision measurement in seconds. Note that if you
+ * don't use your own time source by calling {@link MecanumDrivePoseEstimator#updateWithTime}
+ * then you must use a timestamp with an epoch since FPGA startup (i.e. the epoch of this
+ * timestamp is the same epoch as Timer.getFPGATimestamp.) This means that you should use
+ * Timer.getFPGATimestamp as your time source or sync the epochs.
+ */
+ public void addVisionMeasurement(Pose2d visionRobotPoseMeters, double timestampSeconds) {
+ m_latencyCompensator.applyPastGlobalMeasurement(
+ Nat.N3(),
+ m_observer,
+ m_nominalDt,
+ StateSpaceUtil.poseTo3dVector(visionRobotPoseMeters),
+ m_visionCorrect,
+ timestampSeconds);
+ }
+
+ /**
+ * Add a vision measurement to the Unscented Kalman Filter. This will correct the odometry pose
+ * estimate while still accounting for measurement noise.
+ *
+ * <p>This method can be called as infrequently as you want, as long as you are calling {@link
+ * MecanumDrivePoseEstimator#update} every loop.
+ *
+ * <p>Note that the vision measurement standard deviations passed into this method will continue
+ * to apply to future measurements until a subsequent call to {@link
+ * MecanumDrivePoseEstimator#setVisionMeasurementStdDevs(Matrix)} or this method.
+ *
+ * @param visionRobotPoseMeters The pose of the robot as measured by the vision camera.
+ * @param timestampSeconds The timestamp of the vision measurement in seconds. Note that if you
+ * don't use your own time source by calling {@link MecanumDrivePoseEstimator#updateWithTime}
+ * then you must use a timestamp with an epoch since FPGA startup (i.e. the epoch of this
+ * timestamp is the same epoch as Timer.getFPGATimestamp.) This means that you should use
+ * Timer.getFPGATimestamp as your time source in this case.
+ * @param visionMeasurementStdDevs Standard deviations of the vision measurements. Increase these
+ * numbers to trust global measurements from vision less. This matrix is in the form [x, y,
+ * theta]ᵀ, with units in meters and radians.
+ */
+ public void addVisionMeasurement(
+ Pose2d visionRobotPoseMeters,
+ double timestampSeconds,
+ Matrix<N3, N1> visionMeasurementStdDevs) {
+ setVisionMeasurementStdDevs(visionMeasurementStdDevs);
+ addVisionMeasurement(visionRobotPoseMeters, timestampSeconds);
+ }
+
+ /**
+ * Updates the Unscented Kalman Filter using wheel encoder and gyro information. This should be
+ * called every loop, and the correct loop period must be passed into the constructor of this
+ * class.
+ *
+ * @param gyroAngle The current gyro angle.
+ * @param wheelSpeeds The current speeds of the mecanum drive wheels.
+ * @return The estimated pose of the robot in meters.
+ */
+ public Pose2d update(Rotation2d gyroAngle, MecanumDriveWheelSpeeds wheelSpeeds) {
+ return updateWithTime(WPIUtilJNI.now() * 1.0e-6, gyroAngle, wheelSpeeds);
+ }
+
+ /**
+ * Updates the Unscented Kalman Filter using wheel encoder and gyro information. This should be
+ * called every loop, and the correct loop period must be passed into the constructor of this
+ * class.
+ *
+ * @param currentTimeSeconds Time at which this method was called, in seconds.
+ * @param gyroAngle The current gyroscope angle.
+ * @param wheelSpeeds The current speeds of the mecanum drive wheels.
+ * @return The estimated pose of the robot in meters.
+ */
+ @SuppressWarnings("LocalVariableName")
+ public Pose2d updateWithTime(
+ double currentTimeSeconds, Rotation2d gyroAngle, MecanumDriveWheelSpeeds wheelSpeeds) {
+ double dt = m_prevTimeSeconds >= 0 ? currentTimeSeconds - m_prevTimeSeconds : m_nominalDt;
+ m_prevTimeSeconds = currentTimeSeconds;
+ // NOTE(review): dt is not guarded; a repeated timestamp makes dt == 0 and omega non-finite.
+ var angle = gyroAngle.plus(m_gyroOffset);
+ var omega = angle.minus(m_previousAngle).getRadians() / dt;
+
+ var chassisSpeeds = m_kinematics.toChassisSpeeds(wheelSpeeds);
+ var fieldRelativeVelocities =
+ new Translation2d(chassisSpeeds.vxMetersPerSecond, chassisSpeeds.vyMetersPerSecond)
+ .rotateBy(angle);
+
+ var u = VecBuilder.fill(fieldRelativeVelocities.getX(), fieldRelativeVelocities.getY(), omega);
+ m_previousAngle = angle;
+
+ var localY = VecBuilder.fill(angle.getRadians());
+ m_latencyCompensator.addObserverState(m_observer, u, localY, currentTimeSeconds);
+ m_observer.predict(u, dt);
+ m_observer.correct(u, localY);
+
+ return getEstimatedPosition();
+ }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/estimator/MerweScaledSigmaPoints.java b/wpimath/src/main/java/edu/wpi/first/math/estimator/MerweScaledSigmaPoints.java
similarity index 69%
rename from wpimath/src/main/java/edu/wpi/first/wpilibj/estimator/MerweScaledSigmaPoints.java
rename to wpimath/src/main/java/edu/wpi/first/math/estimator/MerweScaledSigmaPoints.java
index 56e9288..fb0628b 100644
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/estimator/MerweScaledSigmaPoints.java
+++ b/wpimath/src/main/java/edu/wpi/first/math/estimator/MerweScaledSigmaPoints.java
@@ -1,35 +1,28 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
-package edu.wpi.first.wpilibj.estimator;
+package edu.wpi.first.math.estimator;
+import edu.wpi.first.math.Matrix;
+import edu.wpi.first.math.Nat;
+import edu.wpi.first.math.Num;
+import edu.wpi.first.math.numbers.N1;
import org.ejml.simple.SimpleMatrix;
-import edu.wpi.first.wpiutil.math.Matrix;
-import edu.wpi.first.wpiutil.math.Nat;
-import edu.wpi.first.wpiutil.math.Num;
-import edu.wpi.first.wpiutil.math.numbers.N1;
-
/**
- * Generates sigma points and weights according to Van der Merwe's 2004
- * dissertation[1] for the UnscentedKalmanFilter class.
+ * Generates sigma points and weights according to Van der Merwe's 2004 dissertation[1] for the
+ * UnscentedKalmanFilter class.
*
- * <p>It parametrizes the sigma points using alpha, beta, kappa terms, and is the
- * version seen in most publications. Unless you know better, this should be
- * your default choice.
+ * <p>It parametrizes the sigma points using alpha, beta, kappa terms, and is the version seen in
+ * most publications. Unless you know better, this should be your default choice.
*
- * <p>States is the dimensionality of the state. 2*States+1 weights will be
- * generated.
+ * <p>States is the dimensionality of the state. 2*States+1 weights will be generated.
*
- * <p>[1] R. Van der Merwe "Sigma-Point Kalman Filters for Probabilitic
- * Inference in Dynamic State-Space Models" (Doctoral dissertation)
+ * <p>[1] R. Van der Merwe "Sigma-Point Kalman Filters for Probabilistic Inference in Dynamic
+ * State-Space Models" (Doctoral dissertation)
*/
public class MerweScaledSigmaPoints<S extends Num> {
-
private final double m_alpha;
private final int m_kappa;
private final Nat<S> m_states;
@@ -40,11 +33,11 @@
* Constructs a generator for Van der Merwe scaled sigma points.
*
* @param states an instance of Num that represents the number of states.
- * @param alpha Determines the spread of the sigma points around the mean.
- * Usually a small positive value (1e-3).
- * @param beta Incorporates prior knowledge of the distribution of the mean.
- * For Gaussian distributions, beta = 2 is optimal.
- * @param kappa Secondary scaling parameter usually set to 0 or 3 - States.
+ * @param alpha Determines the spread of the sigma points around the mean. Usually a small
+ * positive value (1e-3).
+ * @param beta Incorporates prior knowledge of the distribution of the mean. For Gaussian
+ * distributions, beta = 2 is optimal.
+ * @param kappa Secondary scaling parameter usually set to 0 or 3 - States.
*/
public MerweScaledSigmaPoints(Nat<S> states, double alpha, double beta, int kappa) {
this.m_states = states;
@@ -74,27 +67,24 @@
}
/**
- * Computes the sigma points for an unscented Kalman filter given the mean
- * (x) and covariance(P) of the filter.
+ * Computes the sigma points for an unscented Kalman filter given the mean (x) and covariance(P)
+ * of the filter.
*
* @param x An array of the means.
* @param P Covariance of the filter.
- * @return Two dimensional array of sigma points. Each column contains all of
- * the sigmas for one dimension in the problem space. Ordered by
- * Xi_0, Xi_{1..n}, Xi_{n+1..2n}.
+ * @return Two dimensional array of sigma points. Each column contains all of the sigmas for one
+ * dimension in the problem space. Ordered by Xi_0, Xi_{1..n}, Xi_{n+1..2n}.
*/
@SuppressWarnings({"ParameterName", "LocalVariableName"})
- public Matrix<S, ?> sigmaPoints(
- Matrix<S, N1> x,
- Matrix<S, S> P) {
+ public Matrix<S, ?> sigmaPoints(Matrix<S, N1> x, Matrix<S, S> P) {
double lambda = Math.pow(m_alpha, 2) * (m_states.getNum() + m_kappa) - m_states.getNum();
var intermediate = P.times(lambda + m_states.getNum());
var U = intermediate.lltDecompose(true); // Lower triangular
// 2 * states + 1 by states
- Matrix<S, ?> sigmas = new Matrix<>(
- new SimpleMatrix(m_states.getNum(), 2 * m_states.getNum() + 1));
+ Matrix<S, ?> sigmas =
+ new Matrix<>(new SimpleMatrix(m_states.getNum(), 2 * m_states.getNum() + 1));
sigmas.setColumn(0, x);
for (int k = 0; k < m_states.getNum(); k++) {
var xPlusU = x.plus(U.extractColumnVector(k));
diff --git a/wpimath/src/main/java/edu/wpi/first/math/estimator/SwerveDrivePoseEstimator.java b/wpimath/src/main/java/edu/wpi/first/math/estimator/SwerveDrivePoseEstimator.java
new file mode 100644
index 0000000..9e48728
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/estimator/SwerveDrivePoseEstimator.java
@@ -0,0 +1,305 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.estimator;
+
+import edu.wpi.first.math.Matrix;
+import edu.wpi.first.math.Nat;
+import edu.wpi.first.math.StateSpaceUtil;
+import edu.wpi.first.math.VecBuilder;
+import edu.wpi.first.math.geometry.Pose2d;
+import edu.wpi.first.math.geometry.Rotation2d;
+import edu.wpi.first.math.geometry.Translation2d;
+import edu.wpi.first.math.kinematics.SwerveDriveKinematics;
+import edu.wpi.first.math.kinematics.SwerveModuleState;
+import edu.wpi.first.math.numbers.N1;
+import edu.wpi.first.math.numbers.N3;
+import edu.wpi.first.util.WPIUtilJNI;
+import java.util.function.BiConsumer;
+
+/**
+ * This class wraps an {@link UnscentedKalmanFilter Unscented Kalman Filter} to fuse
+ * latency-compensated vision measurements with swerve drive encoder velocity measurements. It will
+ * correct for noisy measurements and encoder drift. It is intended to be an easy but more accurate
+ * drop-in for {@link edu.wpi.first.math.kinematics.SwerveDriveOdometry}.
+ *
+ * <p>{@link SwerveDrivePoseEstimator#update} should be called every robot loop. If your loops are
+ * faster or slower than the default of 0.02s, then you should change the nominal delta time using
+ * the secondary constructor: {@link SwerveDrivePoseEstimator#SwerveDrivePoseEstimator(Rotation2d,
+ * Pose2d, SwerveDriveKinematics, Matrix, Matrix, Matrix, double)}.
+ *
+ * <p>{@link SwerveDrivePoseEstimator#addVisionMeasurement} can be called as infrequently as you
+ * want; if you never call it, then this class will behave mostly like regular encoder odometry.
+ *
+ * <p>The state-space system used internally has the following states (x), inputs (u), and outputs
+ * (y):
+ *
+ * <p><strong> x = [x, y, theta]ᵀ </strong> in the field coordinate system containing x position, y
+ * position, and heading.
+ *
+ * <p><strong> u = [vx, vy, omega]ᵀ </strong> in the field coordinate system containing x
+ * velocity, y velocity, and angular velocity (change in gyro heading divided by the timestep).
+ *
+ * <p><strong> y = [x, y, theta]ᵀ </strong> from vision containing x position, y position, and
+ * heading; or <strong> y = [theta]ᵀ </strong> containing gyro heading.
+ */
+public class SwerveDrivePoseEstimator {
+ private final UnscentedKalmanFilter<N3, N3, N1> m_observer;
+ private final SwerveDriveKinematics m_kinematics;
+ private final BiConsumer<Matrix<N3, N1>, Matrix<N3, N1>> m_visionCorrect;
+ private final KalmanFilterLatencyCompensator<N3, N3, N1> m_latencyCompensator;
+
+ private final double m_nominalDt; // Seconds
+ private double m_prevTimeSeconds = -1.0;
+
+ private Rotation2d m_gyroOffset;
+ private Rotation2d m_previousAngle;
+
+ private Matrix<N3, N3> m_visionContR;
+
+ /**
+ * Constructs a SwerveDrivePoseEstimator.
+ *
+ * @param gyroAngle The current gyro angle.
+ * @param initialPoseMeters The starting pose estimate.
+ * @param kinematics A correctly-configured kinematics object for your drivetrain.
+ * @param stateStdDevs Standard deviations of model states. Increase these numbers to trust your
+ * model's state estimates less. This matrix is in the form [x, y, theta]ᵀ, with units in
+ * meters and radians.
+ * @param localMeasurementStdDevs Standard deviation of the gyro heading measurement. Increase
+ * this number to trust the gyro less. Encoder velocities enter as inputs, not measurements.
+ * This matrix is in the form [theta], with units in radians.
+ * @param visionMeasurementStdDevs Standard deviations of the vision measurements. Increase these
+ * numbers to trust global measurements from vision less. This matrix is in the form [x, y,
+ * theta]ᵀ, with units in meters and radians.
+ */
+ public SwerveDrivePoseEstimator(
+ Rotation2d gyroAngle,
+ Pose2d initialPoseMeters,
+ SwerveDriveKinematics kinematics,
+ Matrix<N3, N1> stateStdDevs,
+ Matrix<N1, N1> localMeasurementStdDevs,
+ Matrix<N3, N1> visionMeasurementStdDevs) {
+ this(
+ gyroAngle,
+ initialPoseMeters,
+ kinematics,
+ stateStdDevs,
+ localMeasurementStdDevs,
+ visionMeasurementStdDevs,
+ 0.02);
+ }
+
+ /**
+ * Constructs a SwerveDrivePoseEstimator.
+ *
+ * @param gyroAngle The current gyro angle.
+ * @param initialPoseMeters The starting pose estimate.
+ * @param kinematics A correctly-configured kinematics object for your drivetrain.
+ * @param stateStdDevs Standard deviations of model states. Increase these numbers to trust your
+ * model's state estimates less. This matrix is in the form [x, y, theta]ᵀ, with units in
+ * meters and radians.
+ * @param localMeasurementStdDevs Standard deviation of the gyro heading measurement. Increase
+ * this number to trust the gyro less. Encoder velocities enter as inputs, not measurements.
+ * This matrix is in the form [theta], with units in radians.
+ * @param visionMeasurementStdDevs Standard deviations of the vision measurements. Increase these
+ * numbers to trust global measurements from vision less. This matrix is in the form [x, y,
+ * theta]ᵀ, with units in meters and radians.
+ * @param nominalDtSeconds The time in seconds between each robot loop.
+ */
+ @SuppressWarnings("ParameterName")
+ public SwerveDrivePoseEstimator(
+ Rotation2d gyroAngle,
+ Pose2d initialPoseMeters,
+ SwerveDriveKinematics kinematics,
+ Matrix<N3, N1> stateStdDevs,
+ Matrix<N1, N1> localMeasurementStdDevs,
+ Matrix<N3, N1> visionMeasurementStdDevs,
+ double nominalDtSeconds) {
+ m_nominalDt = nominalDtSeconds;
+
+ m_observer =
+ new UnscentedKalmanFilter<>(
+ Nat.N3(),
+ Nat.N1(),
+ (x, u) -> u,
+ (x, u) -> x.extractRowVector(2),
+ stateStdDevs,
+ localMeasurementStdDevs,
+ AngleStatistics.angleMean(2),
+ AngleStatistics.angleMean(0),
+ AngleStatistics.angleResidual(2),
+ AngleStatistics.angleResidual(0),
+ AngleStatistics.angleAdd(2),
+ m_nominalDt);
+ m_kinematics = kinematics;
+ m_latencyCompensator = new KalmanFilterLatencyCompensator<>();
+
+ // Initialize vision R
+ setVisionMeasurementStdDevs(visionMeasurementStdDevs);
+
+ m_visionCorrect =
+ (u, y) ->
+ m_observer.correct(
+ Nat.N3(),
+ u,
+ y,
+ (x, u1) -> x,
+ m_visionContR,
+ AngleStatistics.angleMean(2),
+ AngleStatistics.angleResidual(2),
+ AngleStatistics.angleResidual(2),
+ AngleStatistics.angleAdd(2));
+
+ m_gyroOffset = initialPoseMeters.getRotation().minus(gyroAngle);
+ m_previousAngle = initialPoseMeters.getRotation();
+ m_observer.setXhat(StateSpaceUtil.poseTo3dVector(initialPoseMeters));
+ }
+
+ /**
+ * Sets the pose estimator's trust of global measurements. This might be used to change trust in
+ * vision measurements after the autonomous period, or to change trust as distance to a vision
+ * target increases.
+ *
+ * @param visionMeasurementStdDevs Standard deviations of the vision measurements. Increase these
+ * numbers to trust global measurements from vision less. This matrix is in the form [x, y,
+ * theta]ᵀ, with units in meters and radians.
+ */
+ public void setVisionMeasurementStdDevs(Matrix<N3, N1> visionMeasurementStdDevs) {
+ m_visionContR = StateSpaceUtil.makeCovarianceMatrix(Nat.N3(), visionMeasurementStdDevs);
+ }
+
+ /**
+ * Resets the robot's position on the field.
+ *
+ * <p>You NEED to reset your encoders (to zero) when calling this method.
+ *
+ * <p>The gyroscope angle does not need to be reset in the user's robot code. The library
+ * automatically takes care of offsetting the gyro angle.
+ *
+ * @param poseMeters The position on the field that your robot is at.
+ * @param gyroAngle The angle reported by the gyroscope.
+ */
+ public void resetPosition(Pose2d poseMeters, Rotation2d gyroAngle) {
+ // Reset state estimate and error covariance
+ m_observer.reset();
+ m_latencyCompensator.reset();
+
+ m_observer.setXhat(StateSpaceUtil.poseTo3dVector(poseMeters));
+
+ m_gyroOffset = getEstimatedPosition().getRotation().minus(gyroAngle);
+ m_previousAngle = poseMeters.getRotation();
+ }
+
+ /**
+ * Gets the pose of the robot at the current time as estimated by the Unscented Kalman Filter.
+ *
+ * @return The estimated robot pose in meters.
+ */
+ public Pose2d getEstimatedPosition() {
+ return new Pose2d(
+ m_observer.getXhat(0), m_observer.getXhat(1), new Rotation2d(m_observer.getXhat(2)));
+ }
+
+ /**
+ * Add a vision measurement to the Unscented Kalman Filter. This will correct the odometry pose
+ * estimate while still accounting for measurement noise.
+ *
+ * <p>This method can be called as infrequently as you want, as long as you are calling {@link
+ * SwerveDrivePoseEstimator#update} every loop.
+ *
+ * @param visionRobotPoseMeters The pose of the robot as measured by the vision camera.
+ * @param timestampSeconds The timestamp of the vision measurement in seconds. Note that if you
+ * don't use your own time source by calling {@link SwerveDrivePoseEstimator#updateWithTime}
+ * then you must use a timestamp with an epoch since FPGA startup (i.e. the epoch of this
+ * timestamp is the same epoch as Timer.getFPGATimestamp.) This means that you should use
+ * Timer.getFPGATimestamp as your time source or sync the epochs.
+ */
+ public void addVisionMeasurement(Pose2d visionRobotPoseMeters, double timestampSeconds) {
+ m_latencyCompensator.applyPastGlobalMeasurement(
+ Nat.N3(),
+ m_observer,
+ m_nominalDt,
+ StateSpaceUtil.poseTo3dVector(visionRobotPoseMeters),
+ m_visionCorrect,
+ timestampSeconds);
+ }
+
+ /**
+ * Add a vision measurement to the Unscented Kalman Filter. This will correct the odometry pose
+ * estimate while still accounting for measurement noise.
+ *
+ * <p>This method can be called as infrequently as you want, as long as you are calling {@link
+ * SwerveDrivePoseEstimator#update} every loop.
+ *
+ * <p>Note that the vision measurement standard deviations passed into this method will continue
+ * to apply to future measurements until a subsequent call to {@link
+ * SwerveDrivePoseEstimator#setVisionMeasurementStdDevs(Matrix)} or this method.
+ *
+ * @param visionRobotPoseMeters The pose of the robot as measured by the vision camera.
+ * @param timestampSeconds The timestamp of the vision measurement in seconds. Note that if you
+ * don't use your own time source by calling {@link SwerveDrivePoseEstimator#updateWithTime}
+ * then you must use a timestamp with an epoch since FPGA startup (i.e. the epoch of this
+ * timestamp is the same epoch as Timer.getFPGATimestamp.) This means that you should use
+ * Timer.getFPGATimestamp as your time source in this case.
+ * @param visionMeasurementStdDevs Standard deviations of the vision measurements. Increase these
+ * numbers to trust global measurements from vision less. This matrix is in the form [x, y,
+ * theta]ᵀ, with units in meters and radians.
+ */
+ public void addVisionMeasurement(
+ Pose2d visionRobotPoseMeters,
+ double timestampSeconds,
+ Matrix<N3, N1> visionMeasurementStdDevs) {
+ setVisionMeasurementStdDevs(visionMeasurementStdDevs);
+ addVisionMeasurement(visionRobotPoseMeters, timestampSeconds);
+ }
+
+ /**
+ * Updates the Unscented Kalman Filter using wheel encoder and gyro information. This should be
+ * called every loop, and the correct loop period must be passed into the constructor of this
+ * class.
+ *
+ * @param gyroAngle The current gyro angle.
+ * @param moduleStates The current velocities and rotations of the swerve modules.
+ * @return The estimated pose of the robot in meters.
+ */
+ public Pose2d update(Rotation2d gyroAngle, SwerveModuleState... moduleStates) {
+ return updateWithTime(WPIUtilJNI.now() * 1.0e-6, gyroAngle, moduleStates);
+ }
+
+ /**
+ * Updates the Unscented Kalman Filter using wheel encoder and gyro information. This should be
+ * called every loop, and the correct loop period must be passed into the constructor of this
+ * class.
+ *
+ * @param currentTimeSeconds Time at which this method was called, in seconds.
+ * @param gyroAngle The current gyroscope angle.
+ * @param moduleStates The current velocities and rotations of the swerve modules.
+ * @return The estimated pose of the robot in meters.
+ */
+ @SuppressWarnings("LocalVariableName")
+ public Pose2d updateWithTime(
+ double currentTimeSeconds, Rotation2d gyroAngle, SwerveModuleState... moduleStates) {
+ double dt = m_prevTimeSeconds >= 0 ? currentTimeSeconds - m_prevTimeSeconds : m_nominalDt;
+ m_prevTimeSeconds = currentTimeSeconds;
+ // NOTE(review): dt is not guarded; a repeated timestamp makes dt == 0 and omega non-finite.
+ var angle = gyroAngle.plus(m_gyroOffset);
+ var omega = angle.minus(m_previousAngle).getRadians() / dt;
+
+ var chassisSpeeds = m_kinematics.toChassisSpeeds(moduleStates);
+ var fieldRelativeVelocities =
+ new Translation2d(chassisSpeeds.vxMetersPerSecond, chassisSpeeds.vyMetersPerSecond)
+ .rotateBy(angle);
+
+ var u = VecBuilder.fill(fieldRelativeVelocities.getX(), fieldRelativeVelocities.getY(), omega);
+ m_previousAngle = angle;
+
+ var localY = VecBuilder.fill(angle.getRadians());
+ m_latencyCompensator.addObserverState(m_observer, u, localY, currentTimeSeconds);
+ m_observer.predict(u, dt);
+ m_observer.correct(u, localY);
+
+ return getEstimatedPosition();
+ }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/math/estimator/UnscentedKalmanFilter.java b/wpimath/src/main/java/edu/wpi/first/math/estimator/UnscentedKalmanFilter.java
new file mode 100644
index 0000000..ffc2c15
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/estimator/UnscentedKalmanFilter.java
@@ -0,0 +1,436 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.estimator;
+
+import edu.wpi.first.math.Matrix;
+import edu.wpi.first.math.Nat;
+import edu.wpi.first.math.Num;
+import edu.wpi.first.math.Pair;
+import edu.wpi.first.math.StateSpaceUtil;
+import edu.wpi.first.math.numbers.N1;
+import edu.wpi.first.math.system.Discretization;
+import edu.wpi.first.math.system.NumericalIntegration;
+import edu.wpi.first.math.system.NumericalJacobian;
+import java.util.function.BiFunction;
+import org.ejml.simple.SimpleMatrix;
+
+/**
+ * A Kalman filter combines predictions from a model and measurements to give an estimate of the
+ * true system state. This is useful because many states cannot be measured directly as a result of
+ * sensor noise, or because the state is "hidden".
+ *
+ * <p>Kalman filters use a K gain matrix to determine whether to trust the model or measurements
+ * more. Kalman filter theory uses statistics to compute an optimal K gain which minimizes the sum
+ * of squares error in the state estimate. This K gain is used to correct the state estimate by some
+ * amount of the difference between the actual measurements and the measurements predicted by the
+ * model.
+ *
+ * <p>An unscented Kalman filter uses nonlinear state and measurement models. It propagates the
+ * error covariance using sigma points chosen to approximate the true probability distribution.
+ *
+ * <p>For more on the underlying math, read
+ * https://file.tavsys.net/control/controls-engineering-in-frc.pdf chapter 9 "Stochastic control
+ * theory".
+ */
+@SuppressWarnings({"MemberName", "ClassTypeParameterName"})
+public class UnscentedKalmanFilter<States extends Num, Inputs extends Num, Outputs extends Num>
+    implements KalmanTypeFilter<States, Inputs, Outputs> {
+  private final Nat<States> m_states;
+  private final Nat<Outputs> m_outputs;
+
+  private final BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> m_f;
+  private final BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> m_h;
+
+  private BiFunction<Matrix<States, ?>, Matrix<?, N1>, Matrix<States, N1>> m_meanFuncX;
+  private BiFunction<Matrix<Outputs, ?>, Matrix<?, N1>, Matrix<Outputs, N1>> m_meanFuncY;
+  private BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> m_residualFuncX;
+  private BiFunction<Matrix<Outputs, N1>, Matrix<Outputs, N1>, Matrix<Outputs, N1>> m_residualFuncY;
+  private BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> m_addFuncX;
+
+  private Matrix<States, N1> m_xHat;
+  private Matrix<States, States> m_P;
+  private final Matrix<States, States> m_contQ;
+  private final Matrix<Outputs, Outputs> m_contR;
+  private Matrix<States, ?> m_sigmasF;
+  private double m_dtSeconds;
+
+  private final MerweScaledSigmaPoints<States> m_pts;
+
+  /**
+   * Constructs an Unscented Kalman Filter.
+   *
+   * @param states A Nat representing the number of states.
+   * @param outputs A Nat representing the number of outputs.
+   * @param f A vector-valued function of x and u that returns the derivative of the state vector.
+   * @param h A vector-valued function of x and u that returns the measurement vector.
+   * @param stateStdDevs Standard deviations of model states.
+   * @param measurementStdDevs Standard deviations of measurements.
+   * @param nominalDtSeconds Nominal discretization timestep.
+   */
+  @SuppressWarnings("LambdaParameterName")
+  public UnscentedKalmanFilter(
+      Nat<States> states,
+      Nat<Outputs> outputs,
+      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
+      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> h,
+      Matrix<States, N1> stateStdDevs,
+      Matrix<Outputs, N1> measurementStdDevs,
+      double nominalDtSeconds) {
+    this(
+        states,
+        outputs,
+        f,
+        h,
+        stateStdDevs,
+        measurementStdDevs,
+        (sigmas, Wm) -> sigmas.times(Matrix.changeBoundsUnchecked(Wm)),
+        (sigmas, Wm) -> sigmas.times(Matrix.changeBoundsUnchecked(Wm)),
+        Matrix::minus,
+        Matrix::minus,
+        Matrix::plus,
+        nominalDtSeconds);
+  }
+
+  /**
+   * Constructs an unscented Kalman filter with custom mean, residual, and addition functions. Using
+   * custom functions for arithmetic can be useful if you have angles in the state or measurements,
+   * because they allow you to correctly account for the modular nature of angle arithmetic.
+   *
+   * @param states A Nat representing the number of states.
+   * @param outputs A Nat representing the number of outputs.
+   * @param f A vector-valued function of x and u that returns the derivative of the state vector.
+   * @param h A vector-valued function of x and u that returns the measurement vector.
+   * @param stateStdDevs Standard deviations of model states.
+   * @param measurementStdDevs Standard deviations of measurements.
+   * @param meanFuncX A function that computes the mean of 2 * States + 1 state vectors using a
+   *     given set of weights.
+   * @param meanFuncY A function that computes the mean of 2 * States + 1 measurement vectors using
+   *     a given set of weights.
+   * @param residualFuncX A function that computes the residual of two state vectors (i.e. it
+   *     subtracts them.)
+   * @param residualFuncY A function that computes the residual of two measurement vectors (i.e. it
+   *     subtracts them.)
+   * @param addFuncX A function that adds two state vectors.
+   * @param nominalDtSeconds Nominal discretization timestep.
+   */
+  @SuppressWarnings("ParameterName")
+  public UnscentedKalmanFilter(
+      Nat<States> states,
+      Nat<Outputs> outputs,
+      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
+      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> h,
+      Matrix<States, N1> stateStdDevs,
+      Matrix<Outputs, N1> measurementStdDevs,
+      BiFunction<Matrix<States, ?>, Matrix<?, N1>, Matrix<States, N1>> meanFuncX,
+      BiFunction<Matrix<Outputs, ?>, Matrix<?, N1>, Matrix<Outputs, N1>> meanFuncY,
+      BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> residualFuncX,
+      BiFunction<Matrix<Outputs, N1>, Matrix<Outputs, N1>, Matrix<Outputs, N1>> residualFuncY,
+      BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> addFuncX,
+      double nominalDtSeconds) {
+    this.m_states = states;
+    this.m_outputs = outputs;
+
+    m_f = f;
+    m_h = h;
+
+    m_meanFuncX = meanFuncX;
+    m_meanFuncY = meanFuncY;
+    m_residualFuncX = residualFuncX;
+    m_residualFuncY = residualFuncY;
+    m_addFuncX = addFuncX;
+
+    m_dtSeconds = nominalDtSeconds;
+
+    m_contQ = StateSpaceUtil.makeCovarianceMatrix(states, stateStdDevs);
+    m_contR = StateSpaceUtil.makeCovarianceMatrix(outputs, measurementStdDevs);
+
+    m_pts = new MerweScaledSigmaPoints<>(states);
+
+    reset();
+  }
+
+  /**
+   * Computes the weighted mean (using meanFunc and Wm) and covariance (using residualFunc and Wc)
+   * of the sigma points stored in the columns of sigmas. No noise term is added to the covariance.
+   */
+  @SuppressWarnings({"ParameterName", "LocalVariableName"})
+  static <S extends Num, C extends Num> Pair<Matrix<C, N1>, Matrix<C, C>> unscentedTransform(
+      Nat<S> s,
+      Nat<C> dim,
+      Matrix<C, ?> sigmas,
+      Matrix<?, N1> Wm,
+      Matrix<?, N1> Wc,
+      BiFunction<Matrix<C, ?>, Matrix<?, N1>, Matrix<C, N1>> meanFunc,
+      BiFunction<Matrix<C, N1>, Matrix<C, N1>, Matrix<C, N1>> residualFunc) {
+    if (sigmas.getNumRows() != dim.getNum() || sigmas.getNumCols() != 2 * s.getNum() + 1) {
+      throw new IllegalArgumentException(
+          "Sigmas must be covDim by 2 * states + 1! Got "
+              + sigmas.getNumRows()
+              + " by "
+              + sigmas.getNumCols());
+    }
+
+    if (Wm.getNumRows() != 2 * s.getNum() + 1 || Wm.getNumCols() != 1) {
+      throw new IllegalArgumentException(
+          "Wm must be 2 * states + 1 by 1! Got " + Wm.getNumRows() + " by " + Wm.getNumCols());
+    }
+
+    if (Wc.getNumRows() != 2 * s.getNum() + 1 || Wc.getNumCols() != 1) {
+      throw new IllegalArgumentException(
+          "Wc must be 2 * states + 1 by 1! Got " + Wc.getNumRows() + " by " + Wc.getNumCols());
+    }
+
+    // Mean of the sigma points weighted by Wm (a plain weighted sum for the default meanFunc)
+    Matrix<C, N1> x = meanFunc.apply(sigmas, Wm);
+
+    // Covariance P = Σₖ Wc[k] yₖyₖᵀ = Y diag(Wc) Yᵀ, where yₖ is the residual of sigma point k
+    Matrix<C, ?> y = new Matrix<>(new SimpleMatrix(dim.getNum(), 2 * s.getNum() + 1));
+    for (int i = 0; i < 2 * s.getNum() + 1; i++) {
+      // y[:, i] = sigmas[:, i] - x
+      y.setColumn(i, residualFunc.apply(sigmas.extractColumnVector(i), x));
+    }
+    Matrix<C, C> P =
+        y.times(Matrix.changeBoundsUnchecked(Wc.diag()))
+            .times(Matrix.changeBoundsUnchecked(y.transpose()));
+
+    return new Pair<>(x, P);
+  }
+
+  /**
+   * Returns the error covariance matrix P.
+   *
+   * @return the error covariance matrix P.
+   */
+  @Override
+  public Matrix<States, States> getP() {
+    return m_P;
+  }
+
+  /**
+   * Returns an element of the error covariance matrix P.
+   *
+   * @param row Row of P.
+   * @param col Column of P.
+   * @return the value of the error covariance matrix P at (row, col).
+   */
+  @Override
+  public double getP(int row, int col) {
+    return m_P.get(row, col);
+  }
+
+  /**
+   * Sets the entire error covariance matrix P.
+   *
+   * @param newP The new value of P to use.
+   */
+  @Override
+  public void setP(Matrix<States, States> newP) {
+    m_P = newP;
+  }
+
+  /**
+   * Returns the state estimate x-hat.
+   *
+   * @return the state estimate x-hat.
+   */
+  @Override
+  public Matrix<States, N1> getXhat() {
+    return m_xHat;
+  }
+
+  /**
+   * Returns an element of the state estimate x-hat.
+   *
+   * @param row Row of x-hat.
+   * @return the value of the state estimate x-hat at the given row.
+   */
+  @Override
+  public double getXhat(int row) {
+    return m_xHat.get(row, 0);
+  }
+
+  /**
+   * Set initial state estimate x-hat.
+   *
+   * @param xHat The state estimate x-hat.
+   */
+  @SuppressWarnings("ParameterName")
+  @Override
+  public void setXhat(Matrix<States, N1> xHat) {
+    m_xHat = xHat;
+  }
+
+  /**
+   * Set an element of the initial state estimate x-hat.
+   *
+   * @param row Row of x-hat.
+   * @param value Value for element of x-hat.
+   */
+  @Override
+  public void setXhat(int row, double value) {
+    m_xHat.set(row, 0, value);
+  }
+
+  /** Resets the observer: x-hat, P, and the propagated sigma points become zero matrices. */
+  @Override
+  public void reset() {
+    m_xHat = new Matrix<>(m_states, Nat.N1());
+    m_P = new Matrix<>(m_states, m_states);
+    m_sigmasF = new Matrix<>(new SimpleMatrix(m_states.getNum(), 2 * m_states.getNum() + 1));
+  }
+
+  /**
+   * Project the model into the future with a new control input u.
+   *
+   * @param u New control input from controller.
+   * @param dtSeconds Timestep for prediction; also used to discretize R in later correct() calls.
+   */
+  @SuppressWarnings({"LocalVariableName", "ParameterName"})
+  @Override
+  public void predict(Matrix<Inputs, N1> u, double dtSeconds) {
+    // Discretize Q before projecting mean and covariance forward
+    Matrix<States, States> contA =
+        NumericalJacobian.numericalJacobianX(m_states, m_states, m_f, m_xHat, u);
+    var discQ = Discretization.discretizeAQTaylor(contA, m_contQ, dtSeconds).getSecond();
+
+    var sigmas = m_pts.sigmaPoints(m_xHat, m_P);
+
+    for (int i = 0; i < m_pts.getNumSigmas(); ++i) {
+      // Propagate each sigma point through the continuous-time dynamics f with RK4
+      Matrix<States, N1> x = sigmas.extractColumnVector(i);
+      m_sigmasF.setColumn(i, NumericalIntegration.rk4(m_f, x, u, dtSeconds));
+    }
+
+    var ret =
+        unscentedTransform(
+            m_states,
+            m_states,
+            m_sigmasF,
+            m_pts.getWm(),
+            m_pts.getWc(),
+            m_meanFuncX,
+            m_residualFuncX);
+
+    m_xHat = ret.getFirst();
+    m_P = ret.getSecond().plus(discQ);
+    m_dtSeconds = dtSeconds;
+  }
+
+  /**
+   * Correct the state estimate x-hat using the measurements in y.
+   *
+   * @param u Same control input used in the predict step.
+   * @param y Measurement vector.
+   */
+  @SuppressWarnings("ParameterName")
+  @Override
+  public void correct(Matrix<Inputs, N1> u, Matrix<Outputs, N1> y) {
+    correct(
+        m_outputs, u, y, m_h, m_contR, m_meanFuncY, m_residualFuncY, m_residualFuncX, m_addFuncX);
+  }
+
+  /**
+   * Correct the state estimate x-hat using the measurements in y.
+   *
+   * <p>This is useful for when the measurements available during a timestep's Correct() call vary.
+   * The h(x, u) passed to the constructor is used if one is not provided (the two-argument version
+   * of this function). The measurement mean and residual are plain matrix arithmetic; the state
+   * residual and addition functions are the ones passed to the constructor.
+   *
+   * @param <R> Number of measurements in y.
+   * @param rows Number of rows in y.
+   * @param u Same control input used in the predict step.
+   * @param y Measurement vector.
+   * @param h A vector-valued function of x and u that returns the measurement vector.
+   * @param R Measurement noise covariance matrix (continuous-time).
+   */
+  @SuppressWarnings({"ParameterName", "LambdaParameterName", "LocalVariableName"})
+  public <R extends Num> void correct(
+      Nat<R> rows,
+      Matrix<Inputs, N1> u,
+      Matrix<R, N1> y,
+      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<R, N1>> h,
+      Matrix<R, R> R) {
+    BiFunction<Matrix<R, ?>, Matrix<?, N1>, Matrix<R, N1>> meanFuncY =
+        (sigmas, Wm) -> sigmas.times(Matrix.changeBoundsUnchecked(Wm));
+    BiFunction<Matrix<R, N1>, Matrix<R, N1>, Matrix<R, N1>> residualFuncY = Matrix::minus;
+    correct(rows, u, y, h, R, meanFuncY, residualFuncY, m_residualFuncX, m_addFuncX);
+  }
+
+  /**
+   * Correct the state estimate x-hat using the measurements in y.
+   *
+   * <p>This is useful for when the measurements available during a timestep's Correct() call vary.
+   * The h(x, u) passed to the constructor is used if one is not provided (the two-argument version
+   * of this function).
+   *
+   * @param <R> Number of measurements in y.
+   * @param rows Number of rows in y.
+   * @param u Same control input used in the predict step.
+   * @param y Measurement vector.
+   * @param h A vector-valued function of x and u that returns the measurement vector.
+   * @param R Measurement noise covariance matrix (continuous-time).
+   * @param meanFuncY A function that computes the mean of 2 * States + 1 measurement vectors using
+   *     a given set of weights.
+   * @param residualFuncY A function that computes the residual of two measurement vectors (i.e. it
+   *     subtracts them.)
+   * @param residualFuncX A function that computes the residual of two state vectors (i.e. it
+   *     subtracts them.)
+   * @param addFuncX A function that adds two state vectors.
+   */
+  @SuppressWarnings({"ParameterName", "LocalVariableName"})
+  public <R extends Num> void correct(
+      Nat<R> rows,
+      Matrix<Inputs, N1> u,
+      Matrix<R, N1> y,
+      BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<R, N1>> h,
+      Matrix<R, R> R,
+      BiFunction<Matrix<R, ?>, Matrix<?, N1>, Matrix<R, N1>> meanFuncY,
+      BiFunction<Matrix<R, N1>, Matrix<R, N1>, Matrix<R, N1>> residualFuncY,
+      BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> residualFuncX,
+      BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> addFuncX) {
+    final var discR = Discretization.discretizeR(R, m_dtSeconds);
+
+    // Regenerate sigma points from the current x̂ and P and map them into measurement space.
+    // Pxy below pairs each of these points with its image under h; predict()'s m_sigmasF would be
+    // stale here if correct() has already updated x̂ and P since the last predict().
+    Matrix<R, ?> sigmasH = new Matrix<>(new SimpleMatrix(rows.getNum(), 2 * m_states.getNum() + 1));
+    var sigmas = m_pts.sigmaPoints(m_xHat, m_P);
+    for (int i = 0; i < m_pts.getNumSigmas(); i++) {
+      Matrix<R, N1> hRet = h.apply(sigmas.extractColumnVector(i), u);
+      sigmasH.setColumn(i, hRet);
+    }
+
+    // Mean and covariance of prediction passed through unscented transform
+    var transRet =
+        unscentedTransform(
+            m_states, rows, sigmasH, m_pts.getWm(), m_pts.getWc(), meanFuncY, residualFuncY);
+    var yHat = transRet.getFirst();
+    var Py = transRet.getSecond().plus(discR);
+
+    // Compute cross covariance of the state and the measurements
+    Matrix<States, R> Pxy = new Matrix<>(m_states, rows);
+    for (int i = 0; i < m_pts.getNumSigmas(); i++) {
+      // Pxy += (sigmas[:, i] - x̂)(sigmas_h[:, i] - ŷ)ᵀ W_c[i]
+      var dx = residualFuncX.apply(sigmas.extractColumnVector(i), m_xHat);
+      var dy = residualFuncY.apply(sigmasH.extractColumnVector(i), yHat).transpose();
+
+      Pxy = Pxy.plus(dx.times(dy).times(m_pts.getWc(i)));
+    }
+
+    // K = P_{xy} P_y⁻¹
+    // Kᵀ = P_yᵀ⁻¹ P_{xy}ᵀ
+    // P_yᵀKᵀ = P_{xy}ᵀ
+    // Kᵀ = P_yᵀ.solve(P_{xy}ᵀ)
+    // K = (P_yᵀ.solve(P_{xy}ᵀ))ᵀ
+    Matrix<States, R> K = new Matrix<>(Py.transpose().solve(Pxy.transpose()).transpose());
+
+    // x̂ₖ₊₁⁺ = x̂ₖ₊₁⁻ + K(y − ŷ)
+    m_xHat = addFuncX.apply(m_xHat, K.times(residualFuncY.apply(y, yHat)));
+
+    // Pₖ₊₁⁺ = Pₖ₊₁⁻ − KP_yKᵀ
+    m_P = m_P.minus(K.times(Py).times(K.transpose()));
+  }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/math/filter/LinearFilter.java b/wpimath/src/main/java/edu/wpi/first/math/filter/LinearFilter.java
new file mode 100644
index 0000000..93d6bea
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/filter/LinearFilter.java
@@ -0,0 +1,267 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.filter;
+
+import edu.wpi.first.math.MathSharedStore;
+import edu.wpi.first.math.MathUsageId;
+import edu.wpi.first.util.CircularBuffer;
+import java.util.Arrays;
+import org.ejml.simple.SimpleMatrix;
+
+/**
+ * This class implements a linear, digital filter. All types of FIR and IIR filters are supported.
+ * Static factory methods are provided to create commonly used types of filters.
+ *
+ * <p>Filters are of the form: y[n] = (b0 x[n] + b1 x[n-1] + ... + bP x[n-P]) - (a0 y[n-1] + a1
+ * y[n-2] + ... + aQ y[n-Q-1])
+ *
+ * <p>Where: y[n] is the output at time "n"; x[n] is the input at time "n"; y[n-1] is the output
+ * from the LAST time step ("n-1"); x[n-1] is the input from the LAST time step ("n-1"); b0...bP
+ * are the "feedforward" (FIR) gains; a0...aQ are the "feedback" (IIR) gains. IMPORTANT! Note the
+ * "-" sign in front of the feedback term! This is a common convention in signal processing.
+ *
+ * <p>What can linear filters do? Basically, they can filter, or diminish, the effects of
+ * undesirable input frequencies. High frequencies, or rapid changes, can be indicative of sensor
+ * noise or be otherwise undesirable. A "low pass" filter smooths out the signal, reducing the
+ * impact of these high frequency components. Likewise, a "high pass" filter gets rid of slow-moving
+ * signal components, letting you detect large changes more easily.
+ *
+ * <p>Example FRC applications of filters: getting rid of noise from an analog sensor input (note:
+ * the roboRIO's FPGA can do this faster in hardware); smoothing out joystick input to prevent the
+ * wheels from slipping or the robot from tipping; smoothing motor commands so that unnecessary
+ * strain isn't put on electrical or mechanical components; and, with clever gains, building a PID
+ * controller out of this class!
+ *
+ * <p>For more on filters, we highly recommend the following articles:<br>
+ * https://en.wikipedia.org/wiki/Linear_filter<br>
+ * https://en.wikipedia.org/wiki/Iir_filter<br>
+ * https://en.wikipedia.org/wiki/Fir_filter<br>
+ *
+ * <p>Note 1: calculate() should be called by the user on a known, regular period. You can use a
+ * Notifier for this or do it "inline" with code in a periodic function.
+ *
+ * <p>Note 2: For ALL filters, gains are necessarily a function of frequency. If you make a filter
+ * that works well for you at, say, 100Hz, you will most definitely need to adjust the gains if you
+ * then want to run it at 200Hz! Combining this with Note 1 - the impetus is on YOU as a developer
+ * to make sure calculate() gets called at the desired, constant frequency!
+ */
+public class LinearFilter {
+  private final CircularBuffer m_inputs;
+  private final CircularBuffer m_outputs;
+  private final double[] m_inputGains;
+  private final double[] m_outputGains;
+
+  private static int instances;
+
+  /**
+   * Create a linear FIR or IIR filter.
+   *
+   * @param ffGains The "feedforward" or FIR gains (b0...bP in the class description); copied.
+   * @param fbGains The "feedback" or IIR gains (a0...aQ in the class description); copied.
+   */
+  public LinearFilter(double[] ffGains, double[] fbGains) {
+    m_inputs = new CircularBuffer(ffGains.length);
+    m_outputs = new CircularBuffer(fbGains.length);
+    m_inputGains = Arrays.copyOf(ffGains, ffGains.length);
+    m_outputGains = Arrays.copyOf(fbGains, fbGains.length);
+
+    instances++;
+    MathSharedStore.reportUsage(MathUsageId.kFilter_Linear, instances);
+  }
+
+  /**
+   * Creates a one-pole IIR low-pass filter of the form: y[n] = (1-gain) x[n] + gain y[n-1] where
+   * gain = e<sup>-period / T</sup> and T is the time constant in seconds.
+   *
+   * <p>Note: T = 1 / (2 pi f) where f is the cutoff frequency in Hz, the frequency above which the
+   * input starts to attenuate.
+   *
+   * <p>This filter is stable for time constants greater than zero.
+   *
+   * @param timeConstant The discrete-time time constant in seconds.
+   * @param period The period in seconds between samples taken by the user.
+   * @return Linear filter.
+   */
+  public static LinearFilter singlePoleIIR(double timeConstant, double period) {
+    double gain = Math.exp(-period / timeConstant);
+    double[] ffGains = {1.0 - gain};
+    double[] fbGains = {-gain};
+
+    return new LinearFilter(ffGains, fbGains);
+  }
+
+  /**
+   * Creates a first-order high-pass filter of the form: y[n] = gain x[n] + (-gain) x[n-1] + gain
+   * y[n-1] where gain = e<sup>-period / T</sup> and T is the time constant in seconds.
+   *
+   * <p>Note: T = 1 / (2 pi f) where f is the cutoff frequency in Hz, the frequency below which the
+   * input starts to attenuate.
+   *
+   * <p>This filter is stable for time constants greater than zero.
+   *
+   * @param timeConstant The discrete-time time constant in seconds.
+   * @param period The period in seconds between samples taken by the user.
+   * @return Linear filter.
+   */
+  public static LinearFilter highPass(double timeConstant, double period) {
+    double gain = Math.exp(-period / timeConstant);
+    double[] ffGains = {gain, -gain};
+    double[] fbGains = {-gain};
+
+    return new LinearFilter(ffGains, fbGains);
+  }
+
+  /**
+   * Creates a K-tap FIR moving average filter of the form: y[n] = 1/K (x[n] + ... + x[n-K+1]).
+   *
+   * <p>This filter is always stable.
+   *
+   * @param taps The number of samples to average over. Higher = smoother but slower.
+   * @return Linear filter.
+   * @throws IllegalArgumentException if number of taps is less than 1.
+   */
+  public static LinearFilter movingAverage(int taps) {
+    if (taps <= 0) {
+      throw new IllegalArgumentException("Number of taps was not at least 1");
+    }
+
+    double[] ffGains = new double[taps];
+    Arrays.fill(ffGains, 1.0 / taps);
+
+    double[] fbGains = new double[0];
+
+    return new LinearFilter(ffGains, fbGains);
+  }
+
+  /**
+   * Creates a backward finite difference filter that computes the nth derivative of the input given
+   * the specified number of samples.
+   *
+   * <p>For example, a first derivative filter that uses two samples and a sample period of 20 ms
+   * would be
+   *
+   * <pre><code>
+   * LinearFilter.backwardFiniteDifference(1, 2, 0.02);
+   * </code></pre>
+   *
+   * @param derivative The order of the derivative to compute.
+   * @param samples The number of samples to use to compute the given derivative. This must be one
+   *     more than the order of derivative or higher.
+   * @param period The period in seconds between samples taken by the user.
+   * @return Linear filter.
+   * @throws IllegalArgumentException if derivative is less than 1 or not less than samples.
+   */
+  @SuppressWarnings("LocalVariableName")
+  public static LinearFilter backwardFiniteDifference(int derivative, int samples, double period) {
+    // See
+    // https://en.wikipedia.org/wiki/Finite_difference_coefficient#Arbitrary_stencil_points
+    //
+    // <p>For a given list of stencil points s of length n and the order of
+    // derivative d < n, the finite difference coefficients can be obtained by
+    // solving the following linear system for the vector a.
+    //
+    // <pre>
+    // [s₁⁰   ⋯  sₙ⁰  ][a₁]      [ δ₀,d  ]
+    // [ ⋮    ⋱  ⋮    ][⋮ ] = d! [  ⋮    ]
+    // [s₁ⁿ⁻¹ ⋯ sₙⁿ⁻¹][aₙ]      [δₙ₋₁,d]
+    // </pre>
+    //
+    // <p>where δᵢ,ⱼ are the Kronecker delta. For backward finite difference,
+    // the stencil points are the range [-n + 1, 0]. The FIR gains are the
+    // elements of the vector a in reverse order divided by hᵈ, where h = period.
+    //
+    // <p>The order of accuracy of the approximation is of the form O(hⁿ⁻ᵈ).
+
+    if (derivative < 1) {
+      throw new IllegalArgumentException(
+          "Order of derivative must be greater than or equal to one.");
+    }
+
+    if (samples <= 0) {
+      throw new IllegalArgumentException("Number of samples must be greater than zero.");
+    }
+
+    if (derivative >= samples) {
+      throw new IllegalArgumentException(
+          "Order of derivative must be less than number of samples.");
+    }
+
+    var S = new SimpleMatrix(samples, samples);
+    for (int row = 0; row < samples; ++row) {
+      for (int col = 0; col < samples; ++col) {
+        double s = 1 - samples + col;
+        S.set(row, col, Math.pow(s, row));
+      }
+    }
+
+    // Fill in Kronecker deltas: https://en.wikipedia.org/wiki/Kronecker_delta
+    var d = new SimpleMatrix(samples, 1);
+    for (int i = 0; i < samples; ++i) {
+      d.set(i, 0, (i == derivative) ? factorial(derivative) : 0.0);
+    }
+
+    var a = S.solve(d).divide(Math.pow(period, derivative));
+
+    // Reverse gains list
+    double[] ffGains = new double[samples];
+    for (int i = 0; i < samples; ++i) {
+      ffGains[i] = a.get(samples - i - 1, 0);
+    }
+
+    double[] fbGains = new double[0];
+
+    return new LinearFilter(ffGains, fbGains);
+  }
+
+  /** Reset the filter state. */
+  public void reset() {
+    m_inputs.clear();
+    m_outputs.clear();
+  }
+
+  /**
+   * Calculates the next value of the filter.
+   *
+   * @param input Current input value.
+   * @return The filtered value at this step.
+   */
+  public double calculate(double input) {
+    double retVal = 0.0;
+
+    // Rotate the inputs
+    if (m_inputGains.length > 0) {
+      m_inputs.addFirst(input);
+    }
+
+    // Calculate the new value
+    for (int i = 0; i < m_inputGains.length; i++) {
+      retVal += m_inputs.get(i) * m_inputGains[i];
+    }
+    for (int i = 0; i < m_outputGains.length; i++) {
+      retVal -= m_outputs.get(i) * m_outputGains[i];
+    }
+
+    // Rotate the outputs
+    if (m_outputGains.length > 0) {
+      m_outputs.addFirst(retVal);
+    }
+
+    return retVal;
+  }
+
+  /**
+   * Factorial of n as a double; an int result would silently overflow for n above 12.
+   *
+   * @param n Argument of which to take factorial.
+   * @return n!, or 1 when n is less than 2.
+   */
+  private static double factorial(int n) {
+    if (n < 2) {
+      return 1.0;
+    } else {
+      return n * factorial(n - 1);
+    }
+  }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/MedianFilter.java b/wpimath/src/main/java/edu/wpi/first/math/filter/MedianFilter.java
similarity index 69%
rename from wpimath/src/main/java/edu/wpi/first/wpilibj/MedianFilter.java
rename to wpimath/src/main/java/edu/wpi/first/math/filter/MedianFilter.java
index 18998b0..c24f6e9 100644
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/MedianFilter.java
+++ b/wpimath/src/main/java/edu/wpi/first/math/filter/MedianFilter.java
@@ -1,22 +1,18 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
-package edu.wpi.first.wpilibj;
+package edu.wpi.first.math.filter;
+import edu.wpi.first.util.CircularBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
-import edu.wpi.first.wpiutil.CircularBuffer;
-
/**
- * A class that implements a moving-window median filter. Useful for reducing measurement noise,
- * especially with processes that generate occasional, extreme outliers (such as values from
- * vision processing, LIDAR, or ultrasonic sensors).
+ * A class that implements a moving-window median filter. Useful for reducing measurement noise,
+ * especially with processes that generate occasional, extreme outliers (such as values from vision
+ * processing, LIDAR, or ultrasonic sensors).
*/
public class MedianFilter {
private final CircularBuffer m_valueBuffer;
@@ -61,13 +57,13 @@
// and remove from ordered list
if (curSize > m_size) {
m_orderedValues.remove(m_valueBuffer.removeLast());
- curSize = curSize - 1;
+ --curSize;
}
// Add next value to circular buffer
m_valueBuffer.addFirst(next);
- if (curSize % 2 == 1) {
+ if (curSize % 2 != 0) {
// If size is odd, return middle element of sorted list
return m_orderedValues.get(curSize / 2);
} else {
@@ -76,9 +72,7 @@
}
}
- /**
- * Resets the filter, clearing the window of all elements.
- */
+ /** Resets the filter, clearing the window of all elements. */
public void reset() {
m_orderedValues.clear();
m_valueBuffer.clear();
diff --git a/wpimath/src/main/java/edu/wpi/first/math/filter/SlewRateLimiter.java b/wpimath/src/main/java/edu/wpi/first/math/filter/SlewRateLimiter.java
new file mode 100644
index 0000000..d3aa7d8
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/filter/SlewRateLimiter.java
@@ -0,0 +1,66 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.filter;
+
+import edu.wpi.first.math.MathUtil;
+import edu.wpi.first.util.WPIUtilJNI;
+
+/**
+ * Restricts how fast an input value is allowed to change, which makes it useful for voltage,
+ * setpoint, and output ramps. Slew-rate limiting suits controlled quantities such as velocities
+ * or voltages; when the quantity is a position, a {@link
+ * edu.wpi.first.math.trajectory.TrapezoidProfile} is usually the better choice.
+ */
+public class SlewRateLimiter {
+  private final double m_rateLimit;
+  private double m_prevVal;
+  private double m_prevTime;
+
+  /**
+   * Constructs a limiter that starts at the given value and moves at most rateLimit per second.
+   *
+   * @param rateLimit The rate-of-change limit, in units per second.
+   * @param initialValue The initial value of the input.
+   */
+  public SlewRateLimiter(double rateLimit, double initialValue) {
+    m_prevTime = WPIUtilJNI.now() * 1e-6;
+    m_prevVal = initialValue;
+    m_rateLimit = rateLimit;
+  }
+
+  /**
+   * Constructs a limiter that starts at zero.
+   *
+   * @param rateLimit The rate-of-change limit, in units per second.
+   */
+  public SlewRateLimiter(double rateLimit) {
+    this(rateLimit, 0.0);
+  }
+
+  /**
+   * Steps the output toward the input, by at most the rate limit times the elapsed time.
+   *
+   * @param input The input value whose slew rate is to be limited.
+   * @return The filtered value, which will not change faster than the slew rate.
+   */
+  public double calculate(double input) {
+    final double nowSeconds = WPIUtilJNI.now() * 1e-6;
+    final double maxStep = m_rateLimit * (nowSeconds - m_prevTime);
+    final double step = MathUtil.clamp(input - m_prevVal, -maxStep, maxStep);
+    m_prevTime = nowSeconds;
+    m_prevVal += step;
+    return m_prevVal;
+  }
+
+  /**
+   * Snaps the limiter to the given value without applying the rate limit.
+   *
+   * @param value The value to reset to.
+   */
+  public void reset(double value) {
+    m_prevTime = WPIUtilJNI.now() * 1e-6;
+    m_prevVal = value;
+  }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/geometry/Pose2d.java b/wpimath/src/main/java/edu/wpi/first/math/geometry/Pose2d.java
similarity index 66%
rename from wpimath/src/main/java/edu/wpi/first/wpilibj/geometry/Pose2d.java
rename to wpimath/src/main/java/edu/wpi/first/math/geometry/Pose2d.java
index a0e3b9a..6033b89 100644
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/geometry/Pose2d.java
+++ b/wpimath/src/main/java/edu/wpi/first/math/geometry/Pose2d.java
@@ -1,22 +1,16 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
-package edu.wpi.first.wpilibj.geometry;
-
-import java.util.Objects;
+package edu.wpi.first.math.geometry;
import com.fasterxml.jackson.annotation.JsonAutoDetect;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
+import java.util.Objects;
-/**
- * Represents a 2d pose containing translational and rotational elements.
- */
+/** Represents a 2d pose containing translational and rotational elements. */
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonAutoDetect(getterVisibility = JsonAutoDetect.Visibility.NONE)
public class Pose2d {
@@ -24,8 +18,8 @@
private final Rotation2d m_rotation;
/**
- * Constructs a pose at the origin facing toward the positive X axis.
- * (Translation2d{0, 0} and Rotation{0})
+ * Constructs a pose at the origin facing toward the positive X axis. (Translation2d{0, 0} and
+ * Rotation{0})
*/
public Pose2d() {
m_translation = new Translation2d();
@@ -36,37 +30,34 @@
* Constructs a pose with the specified translation and rotation.
*
* @param translation The translational component of the pose.
- * @param rotation The rotational component of the pose.
+ * @param rotation The rotational component of the pose.
*/
@JsonCreator
- public Pose2d(@JsonProperty(required = true, value = "translation") Translation2d translation,
- @JsonProperty(required = true, value = "rotation") Rotation2d rotation) {
+ public Pose2d(
+ @JsonProperty(required = true, value = "translation") Translation2d translation,
+ @JsonProperty(required = true, value = "rotation") Rotation2d rotation) {
m_translation = translation;
m_rotation = rotation;
}
/**
- * Convenience constructors that takes in x and y values directly instead of
- * having to construct a Translation2d.
+ * Convenience constructors that takes in x and y values directly instead of having to construct a
+ * Translation2d.
*
- * @param x The x component of the translational component of the pose.
- * @param y The y component of the translational component of the pose.
+ * @param x The x component of the translational component of the pose.
+ * @param y The y component of the translational component of the pose.
* @param rotation The rotational component of the pose.
*/
- @SuppressWarnings("ParameterName")
public Pose2d(double x, double y, Rotation2d rotation) {
m_translation = new Translation2d(x, y);
m_rotation = rotation;
}
/**
- * Transforms the pose by the given transformation and returns the new
- * transformed pose.
+ * Transforms the pose by the given transformation and returns the new transformed pose.
*
- * <p>The matrix multiplication is as follows
- * [x_new] [cos, -sin, 0][transform.x]
- * [y_new] += [sin, cos, 0][transform.y]
- * [t_new] [0, 0, 1][transform.t]
+ * <p>The matrix multiplication is as follows [x_new] [cos, -sin, 0][transform.x] [y_new] += [sin,
+ * cos, 0][transform.y] [t_new] [0, 0, 1][transform.t]
*
* @param other The transform to transform the pose by.
* @return The transformed pose.
@@ -125,26 +116,26 @@
}
/**
- * Transforms the pose by the given transformation and returns the new pose.
- * See + operator for the matrix multiplication performed.
+ * Transforms the pose by the given transformation and returns the new pose. See + operator for
+ * the matrix multiplication performed.
*
* @param other The transform to transform the pose by.
* @return The transformed pose.
*/
public Pose2d transformBy(Transform2d other) {
- return new Pose2d(m_translation.plus(other.getTranslation().rotateBy(m_rotation)),
+ return new Pose2d(
+ m_translation.plus(other.getTranslation().rotateBy(m_rotation)),
m_rotation.plus(other.getRotation()));
}
/**
* Returns the other pose relative to the current pose.
*
- * <p>This function can often be used for trajectory tracking or pose
- * stabilization algorithms to get the error between the reference and the
- * current pose.
+ * <p>This function can often be used for trajectory tracking or pose stabilization algorithms to
+ * get the error between the reference and the current pose.
*
- * @param other The pose that is the origin of the new coordinate frame that
- * the current pose will be converted into.
+ * @param other The pose that is the origin of the new coordinate frame that the current pose will
+ * be converted into.
* @return The current pose relative to the new origin pose.
*/
public Pose2d relativeTo(Pose2d other) {
@@ -155,25 +146,23 @@
/**
* Obtain a new Pose2d from a (constant curvature) velocity.
*
- * <p>See <a href="https://file.tavsys.net/control/controls-engineering-in-frc.pdf">
- * Controls Engineering in the FIRST Robotics Competition</a> section 10.2 "Pose exponential" for
- * a derivation.
+ * <p>See <a href="https://file.tavsys.net/control/controls-engineering-in-frc.pdf">Controls
+ * Engineering in the FIRST Robotics Competition</a> section 10.2 "Pose exponential" for a
+ * derivation.
*
- * <p>The twist is a change in pose in the robot's coordinate frame since the
- * previous pose update. When the user runs exp() on the previous known
- * field-relative pose with the argument being the twist, the user will
- * receive the new field-relative pose.
+ * <p>The twist is a change in pose in the robot's coordinate frame since the previous pose
+ * update. When the user runs exp() on the previous known field-relative pose with the argument
+ * being the twist, the user will receive the new field-relative pose.
*
- * <p>"Exp" represents the pose exponential, which is solving a differential
- * equation moving the pose forward in time.
+ * <p>"Exp" represents the pose exponential, which is solving a differential equation moving the
+ * pose forward in time.
*
- * @param twist The change in pose in the robot's coordinate frame since the
- * previous pose update. For example, if a non-holonomic robot moves forward
- * 0.01 meters and changes angle by 0.5 degrees since the previous pose update,
- * the twist would be Twist2d{0.01, 0.0, toRadians(0.5)}
+ * @param twist The change in pose in the robot's coordinate frame since the previous pose update.
+ * For example, if a non-holonomic robot moves forward 0.01 meters and changes angle by 0.5
+ * degrees since the previous pose update, the twist would be Twist2d{0.01, 0.0,
+ * toRadians(0.5)}
* @return The new pose of the robot.
*/
- @SuppressWarnings("LocalVariableName")
public Pose2d exp(Twist2d twist) {
double dx = twist.dx;
double dy = twist.dy;
@@ -191,15 +180,17 @@
s = sinTheta / dtheta;
c = (1 - cosTheta) / dtheta;
}
- var transform = new Transform2d(new Translation2d(dx * s - dy * c, dx * c + dy * s),
- new Rotation2d(cosTheta, sinTheta));
+ var transform =
+ new Transform2d(
+ new Translation2d(dx * s - dy * c, dx * c + dy * s),
+ new Rotation2d(cosTheta, sinTheta));
return this.plus(transform);
}
/**
- * Returns a Twist2d that maps this pose to the end pose. If c is the output
- * of a.Log(b), then a.Exp(c) would yield b.
+ * Returns a Twist2d that maps this pose to the end pose. If c is the output of a.Log(b), then
+ * a.Exp(c) would yield b.
*
* @param end The end pose for the transformation.
* @return The twist that maps this to end.
@@ -218,9 +209,11 @@
halfThetaByTanOfHalfDtheta = -(halfDtheta * transform.getRotation().getSin()) / cosMinusOne;
}
- Translation2d translationPart = transform.getTranslation().rotateBy(
- new Rotation2d(halfThetaByTanOfHalfDtheta, -halfDtheta)
- ).times(Math.hypot(halfThetaByTanOfHalfDtheta, halfDtheta));
+ Translation2d translationPart =
+ transform
+ .getTranslation()
+ .rotateBy(new Rotation2d(halfThetaByTanOfHalfDtheta, -halfDtheta))
+ .times(Math.hypot(halfThetaByTanOfHalfDtheta, halfDtheta));
return new Twist2d(translationPart.getX(), translationPart.getY(), dtheta);
}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/geometry/Rotation2d.java b/wpimath/src/main/java/edu/wpi/first/math/geometry/Rotation2d.java
similarity index 77%
rename from wpimath/src/main/java/edu/wpi/first/wpilibj/geometry/Rotation2d.java
rename to wpimath/src/main/java/edu/wpi/first/math/geometry/Rotation2d.java
index e1c25eb..74ef228 100644
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/geometry/Rotation2d.java
+++ b/wpimath/src/main/java/edu/wpi/first/math/geometry/Rotation2d.java
@@ -1,23 +1,16 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
-package edu.wpi.first.wpilibj.geometry;
-
-import java.util.Objects;
+package edu.wpi.first.math.geometry;
import com.fasterxml.jackson.annotation.JsonAutoDetect;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
+import java.util.Objects;
-/**
- * A rotation in a 2d coordinate frame represented a point on the unit circle
- * (cosine and sine).
- */
+/** A rotation in a 2d coordinate frame represented a point on the unit circle (cosine and sine). */
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonAutoDetect(getterVisibility = JsonAutoDetect.Visibility.NONE)
public class Rotation2d {
@@ -25,9 +18,7 @@
private final double m_cos;
private final double m_sin;
- /**
- * Constructs a Rotation2d with a default angle of 0 degrees.
- */
+ /** Constructs a Rotation2d with a default angle of 0 degrees. */
public Rotation2d() {
m_value = 0.0;
m_cos = 1.0;
@@ -35,8 +26,7 @@
}
/**
- * Constructs a Rotation2d with the given radian value.
- * The x and y don't have to be normalized.
+ * Constructs a Rotation2d with the given radian value. The x and y don't have to be normalized.
*
* @param value The value of the angle in radians.
*/
@@ -48,13 +38,11 @@
}
/**
- * Constructs a Rotation2d with the given x and y (cosine and sine)
- * components.
+ * Constructs a Rotation2d with the given x and y (cosine and sine) components.
*
* @param x The x component or cosine of the rotation.
* @param y The y component or sine of the rotation.
*/
- @SuppressWarnings("ParameterName")
public Rotation2d(double x, double y) {
double magnitude = Math.hypot(x, y);
if (magnitude > 1e-6) {
@@ -78,11 +66,9 @@
}
/**
- * Adds two rotations together, with the result being bounded between -pi and
- * pi.
+ * Adds two rotations together, with the result being bounded between -pi and pi.
*
- * <p>For example, Rotation2d.fromDegrees(30) + Rotation2d.fromDegrees(60) =
- * Rotation2d{-pi/2}
+ * <p>For example, Rotation2d.fromDegrees(30) + Rotation2d.fromDegrees(60) = Rotation2d{-pi/2}
*
* @param other The rotation to add.
* @return The sum of the two rotations.
@@ -92,11 +78,9 @@
}
/**
- * Subtracts the new rotation from the current rotation and returns the new
- * rotation.
+ * Subtracts the new rotation from the current rotation and returns the new rotation.
*
- * <p>For example, Rotation2d.fromDegrees(10) - Rotation2d.fromDegrees(100) =
- * Rotation2d{-pi/2}
+ * <p>For example, Rotation2d.fromDegrees(10) - Rotation2d.fromDegrees(100) = Rotation2d{-pi/2}
*
* @param other The rotation to subtract.
* @return The difference between the two rotations.
@@ -106,8 +90,8 @@
}
/**
- * Takes the inverse of the current rotation. This is simply the negative of
- * the current angular value.
+ * Takes the inverse of the current rotation. This is simply the negative of the current angular
+ * value.
*
* @return The inverse of the current rotation.
*/
@@ -129,18 +113,19 @@
* Adds the new rotation to the current rotation using a rotation matrix.
*
* <p>The matrix multiplication is as follows:
+ *
+ * <pre>
* [cos_new] [other.cos, -other.sin][cos]
* [sin_new] = [other.sin, other.cos][sin]
- * value_new = atan2(cos_new, sin_new)
+ * value_new = atan2(sin_new, cos_new)
+ * </pre>
*
* @param other The rotation to rotate by.
* @return The new rotated Rotation2d.
*/
public Rotation2d rotateBy(Rotation2d other) {
return new Rotation2d(
- m_cos * other.m_cos - m_sin * other.m_sin,
- m_cos * other.m_sin + m_sin * other.m_cos
- );
+ m_cos * other.m_cos - m_sin * other.m_sin, m_cos * other.m_sin + m_sin * other.m_cos);
}
/**
@@ -203,7 +188,8 @@
@Override
public boolean equals(Object obj) {
if (obj instanceof Rotation2d) {
- return Math.abs(((Rotation2d) obj).m_value - m_value) < 1E-9;
+ var other = (Rotation2d) obj;
+ return Math.hypot(m_cos - other.m_cos, m_sin - other.m_sin) < 1E-9;
}
return false;
}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/geometry/Transform2d.java b/wpimath/src/main/java/edu/wpi/first/math/geometry/Transform2d.java
similarity index 76%
rename from wpimath/src/main/java/edu/wpi/first/wpilibj/geometry/Transform2d.java
rename to wpimath/src/main/java/edu/wpi/first/math/geometry/Transform2d.java
index 16746d5..dd35670 100644
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/geometry/Transform2d.java
+++ b/wpimath/src/main/java/edu/wpi/first/math/geometry/Transform2d.java
@@ -1,17 +1,12 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
-package edu.wpi.first.wpilibj.geometry;
+package edu.wpi.first.math.geometry;
import java.util.Objects;
-/**
- * Represents a transformation for a Pose2d.
- */
+/** Represents a transformation for a Pose2d. */
public class Transform2d {
private final Translation2d m_translation;
private final Rotation2d m_rotation;
@@ -20,13 +15,15 @@
* Constructs the transform that maps the initial pose to the final pose.
*
* @param initial The initial pose for the transformation.
- * @param last The final pose for the transformation.
+ * @param last The final pose for the transformation.
*/
public Transform2d(Pose2d initial, Pose2d last) {
// We are rotating the difference between the translations
// using a clockwise rotation matrix. This transforms the global
// delta into a local delta (relative to the initial pose).
- m_translation = last.getTranslation().minus(initial.getTranslation())
+ m_translation =
+ last.getTranslation()
+ .minus(initial.getTranslation())
.rotateBy(initial.getRotation().unaryMinus());
m_rotation = last.getRotation().minus(initial.getRotation());
@@ -36,16 +33,14 @@
* Constructs a transform with the given translation and rotation components.
*
* @param translation Translational component of the transform.
- * @param rotation Rotational component of the transform.
+ * @param rotation Rotational component of the transform.
*/
public Transform2d(Translation2d translation, Rotation2d rotation) {
m_translation = translation;
m_rotation = rotation;
}
- /**
- * Constructs the identity transform -- maps an initial pose to itself.
- */
+ /** Constructs the identity transform -- maps an initial pose to itself. */
public Transform2d() {
m_translation = new Translation2d();
m_rotation = new Rotation2d();
@@ -62,6 +57,16 @@
}
/**
+ * Composes two transformations.
+ *
+ * @param other The transform to compose with this one.
+ * @return The composition of the two transformations.
+ */
+ public Transform2d plus(Transform2d other) {
+ return new Transform2d(new Pose2d(), new Pose2d().transformBy(this).transformBy(other));
+ }
+
+ /**
* Returns the translation component of the transformation.
*
* @return The translational component of the transform.
@@ -106,7 +111,8 @@
// We are rotating the difference between the translations
// using a clockwise rotation matrix. This transforms the global
// delta into a local delta (relative to the initial pose).
- return new Transform2d(getTranslation().unaryMinus().rotateBy(getRotation().unaryMinus()),
+ return new Transform2d(
+ getTranslation().unaryMinus().rotateBy(getRotation().unaryMinus()),
getRotation().unaryMinus());
}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/geometry/Translation2d.java b/wpimath/src/main/java/edu/wpi/first/math/geometry/Translation2d.java
similarity index 71%
rename from wpimath/src/main/java/edu/wpi/first/wpilibj/geometry/Translation2d.java
rename to wpimath/src/main/java/edu/wpi/first/math/geometry/Translation2d.java
index 5365759..251c078 100644
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/geometry/Translation2d.java
+++ b/wpimath/src/main/java/edu/wpi/first/math/geometry/Translation2d.java
@@ -1,26 +1,21 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
-package edu.wpi.first.wpilibj.geometry;
-
-import java.util.Objects;
+package edu.wpi.first.math.geometry;
import com.fasterxml.jackson.annotation.JsonAutoDetect;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
+import java.util.Objects;
/**
- * Represents a translation in 2d space.
- * This object can be used to represent a point or a vector.
+ * Represents a translation in 2d space. This object can be used to represent a point or a vector.
*
- * <p>This assumes that you are using conventional mathematical axes.
- * When the robot is placed on the origin, facing toward the X direction,
- * moving forward increases the X, whereas moving to the left increases the Y.
+ * <p>This assumes that you are using conventional mathematical axes. When the robot is placed on
+ * the origin, facing toward the X direction, moving forward increases the X, whereas moving to the
+ * left increases the Y.
*/
@SuppressWarnings({"ParameterName", "MemberName"})
@JsonIgnoreProperties(ignoreUnknown = true)
@@ -29,33 +24,31 @@
private final double m_x;
private final double m_y;
- /**
- * Constructs a Translation2d with X and Y components equal to zero.
- */
+ /** Constructs a Translation2d with X and Y components equal to zero. */
public Translation2d() {
this(0.0, 0.0);
}
/**
- * Constructs a Translation2d with the X and Y components equal to the
- * provided values.
+ * Constructs a Translation2d with the X and Y components equal to the provided values.
*
* @param x The x component of the translation.
* @param y The y component of the translation.
*/
@JsonCreator
- public Translation2d(@JsonProperty(required = true, value = "x") double x,
- @JsonProperty(required = true, value = "y") double y) {
+ public Translation2d(
+ @JsonProperty(required = true, value = "x") double x,
+ @JsonProperty(required = true, value = "y") double y) {
m_x = x;
m_y = y;
}
/**
- * Constructs a Translation2d with the provided distance and angle. This is
- * essentially converting from polar coordinates to Cartesian coordinates.
+ * Constructs a Translation2d with the provided distance and angle. This is essentially converting
+ * from polar coordinates to Cartesian coordinates.
*
* @param distance The distance from the origin to the end of the translation.
- * @param angle The angle between the x-axis and the translation vector.
+ * @param angle The angle between the x-axis and the translation vector.
*/
public Translation2d(double distance, Rotation2d angle) {
m_x = distance * angle.getCos();
@@ -65,8 +58,8 @@
/**
* Calculates the distance between two translations in 2d space.
*
- * <p>This function uses the pythagorean theorem to calculate the distance.
- * distance = sqrt((x2 - x1)^2 + (y2 - y1)^2)
+ * <p>This function uses the pythagorean theorem to calculate the distance. distance = sqrt((x2 -
+ * x1)^2 + (y2 - y1)^2)
*
* @param other The translation to compute the distance to.
* @return The distance between the two translations.
@@ -107,30 +100,24 @@
/**
* Applies a rotation to the translation in 2d space.
*
- * <p>This multiplies the translation vector by a counterclockwise rotation
- * matrix of the given angle.
- * [x_new] [other.cos, -other.sin][x]
- * [y_new] = [other.sin, other.cos][y]
+ * <p>This multiplies the translation vector by a counterclockwise rotation matrix of the given
+ * angle. [x_new] [other.cos, -other.sin][x] [y_new] = [other.sin, other.cos][y]
*
- * <p>For example, rotating a Translation2d of {2, 0} by 90 degrees will return a
- * Translation2d of {0, 2}.
+ * <p>For example, rotating a Translation2d of {2, 0} by 90 degrees will return a Translation2d of
+ * {0, 2}.
*
* @param other The rotation to rotate the translation by.
* @return The new rotated translation.
*/
public Translation2d rotateBy(Rotation2d other) {
return new Translation2d(
- m_x * other.getCos() - m_y * other.getSin(),
- m_x * other.getSin() + m_y * other.getCos()
- );
+ m_x * other.getCos() - m_y * other.getSin(), m_x * other.getSin() + m_y * other.getCos());
}
/**
- * Adds two translations in 2d space and returns the sum. This is similar to
- * vector addition.
+ * Adds two translations in 2d space and returns the sum. This is similar to vector addition.
*
- * <p>For example, Translation2d{1.0, 2.5} + Translation2d{2.0, 5.5} =
- * Translation2d{3.0, 8.0}
+ * <p>For example, Translation2d{1.0, 2.5} + Translation2d{2.0, 5.5} = Translation2d{3.0, 8.0}
*
* @param other The translation to add.
* @return The sum of the translations.
@@ -140,11 +127,9 @@
}
/**
- * Subtracts the other translation from the other translation and returns the
- * difference.
+ * Subtracts the other translation from the other translation and returns the difference.
*
- * <p>For example, Translation2d{5.0, 4.0} - Translation2d{1.0, 2.0} =
- * Translation2d{4.0, 2.0}
+ * <p>For example, Translation2d{5.0, 4.0} - Translation2d{1.0, 2.0} = Translation2d{4.0, 2.0}
*
* @param other The translation to subtract.
* @return The difference between the two translations.
@@ -154,9 +139,8 @@
}
/**
- * Returns the inverse of the current translation. This is equivalent to
- * rotating by 180 degrees, flipping the point over both axes, or simply
- * negating both components of the translation.
+ * Returns the inverse of the current translation. This is equivalent to rotating by 180 degrees,
+ * flipping the point over both axes, or simply negating both components of the translation.
*
* @return The inverse of the current translation.
*/
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/geometry/Twist2d.java b/wpimath/src/main/java/edu/wpi/first/math/geometry/Twist2d.java
similarity index 62%
rename from wpimath/src/main/java/edu/wpi/first/wpilibj/geometry/Twist2d.java
rename to wpimath/src/main/java/edu/wpi/first/math/geometry/Twist2d.java
index 1482902..c73d236 100644
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/geometry/Twist2d.java
+++ b/wpimath/src/main/java/edu/wpi/first/math/geometry/Twist2d.java
@@ -1,43 +1,33 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
-package edu.wpi.first.wpilibj.geometry;
+package edu.wpi.first.math.geometry;
import java.util.Objects;
/**
- * A change in distance along arc since the last pose update. We can use ideas
- * from differential calculus to create new Pose2ds from a Twist2d and vise
- * versa.
+ * A change in distance along arc since the last pose update. We can use ideas from differential
+ * calculus to create new Pose2ds from a Twist2d and vise versa.
*
* <p>A Twist can be used to represent a difference between two poses.
*/
@SuppressWarnings("MemberName")
public class Twist2d {
- /**
- * Linear "dx" component.
- */
+ /** Linear "dx" component. */
public double dx;
- /**
- * Linear "dy" component.
- */
+ /** Linear "dy" component. */
public double dy;
- /**
- * Angular "dtheta" component (radians).
- */
+ /** Angular "dtheta" component (radians). */
public double dtheta;
- public Twist2d() {
- }
+ public Twist2d() {}
/**
* Constructs a Twist2d with the given values.
+ *
* @param dx Change in x direction relative to robot.
* @param dy Change in y direction relative to robot.
* @param dtheta Change in angle relative to robot.
diff --git a/wpimath/src/main/java/edu/wpi/first/math/kinematics/ChassisSpeeds.java b/wpimath/src/main/java/edu/wpi/first/math/kinematics/ChassisSpeeds.java
new file mode 100644
index 0000000..451c008
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/kinematics/ChassisSpeeds.java
@@ -0,0 +1,78 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.kinematics;
+
+import edu.wpi.first.math.geometry.Rotation2d;
+
+/**
+ * Represents the speed of a robot chassis. Although this struct contains similar members compared
+ * to a Twist2d, they do NOT represent the same thing. Whereas a Twist2d represents a change in pose
+ * w.r.t to the robot frame of reference, this ChassisSpeeds struct represents a velocity w.r.t to
+ * the robot frame of reference.
+ *
+ * <p>A strictly non-holonomic drivetrain, such as a differential drive, should never have a dy
+ * component because it can never move sideways. Holonomic drivetrains such as swerve and mecanum
+ * will often have all three components.
+ */
+@SuppressWarnings("MemberName")
+public class ChassisSpeeds {
+ /** Represents forward velocity w.r.t the robot frame of reference. (Fwd is +) */
+ public double vxMetersPerSecond;
+
+ /** Represents sideways velocity w.r.t the robot frame of reference. (Left is +) */
+ public double vyMetersPerSecond;
+
+ /** Represents the angular velocity of the robot frame. (CCW is +) */
+ public double omegaRadiansPerSecond;
+
+ /** Constructs a ChassisSpeeds with zeros for dx, dy, and theta. */
+ public ChassisSpeeds() {}
+
+ /**
+ * Constructs a ChassisSpeeds object.
+ *
+ * @param vxMetersPerSecond Forward velocity.
+ * @param vyMetersPerSecond Sideways velocity.
+ * @param omegaRadiansPerSecond Angular velocity.
+ */
+ public ChassisSpeeds(
+ double vxMetersPerSecond, double vyMetersPerSecond, double omegaRadiansPerSecond) {
+ this.vxMetersPerSecond = vxMetersPerSecond;
+ this.vyMetersPerSecond = vyMetersPerSecond;
+ this.omegaRadiansPerSecond = omegaRadiansPerSecond;
+ }
+
+ /**
+ * Converts a user provided field-relative set of speeds into a robot-relative ChassisSpeeds
+ * object.
+ *
+ * @param vxMetersPerSecond The component of speed in the x direction relative to the field.
+ * Positive x is away from your alliance wall.
+ * @param vyMetersPerSecond The component of speed in the y direction relative to the field.
+ * Positive y is to your left when standing behind the alliance wall.
+ * @param omegaRadiansPerSecond The angular rate of the robot.
+ * @param robotAngle The angle of the robot as measured by a gyroscope. The robot's angle is
+ * considered to be zero when it is facing directly away from your alliance station wall.
+ * Remember that this should be CCW positive.
+ * @return ChassisSpeeds object representing the speeds in the robot's frame of reference.
+ */
+ public static ChassisSpeeds fromFieldRelativeSpeeds(
+ double vxMetersPerSecond,
+ double vyMetersPerSecond,
+ double omegaRadiansPerSecond,
+ Rotation2d robotAngle) {
+ return new ChassisSpeeds(
+ vxMetersPerSecond * robotAngle.getCos() + vyMetersPerSecond * robotAngle.getSin(),
+ -vxMetersPerSecond * robotAngle.getSin() + vyMetersPerSecond * robotAngle.getCos(),
+ omegaRadiansPerSecond);
+ }
+
+ @Override
+ public String toString() {
+ return String.format(
+ "ChassisSpeeds(Vx: %.2f m/s, Vy: %.2f m/s, Omega: %.2f rad/s)",
+ vxMetersPerSecond, vyMetersPerSecond, omegaRadiansPerSecond);
+ }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/math/kinematics/DifferentialDriveKinematics.java b/wpimath/src/main/java/edu/wpi/first/math/kinematics/DifferentialDriveKinematics.java
new file mode 100644
index 0000000..7984e39
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/kinematics/DifferentialDriveKinematics.java
@@ -0,0 +1,61 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.kinematics;
+
+import edu.wpi.first.math.MathSharedStore;
+import edu.wpi.first.math.MathUsageId;
+
+/**
+ * Helper class that converts a chassis velocity (dx and dtheta components) to left and right wheel
+ * velocities for a differential drive.
+ *
+ * <p>Inverse kinematics converts a desired chassis speed into left and right velocity components
+ * whereas forward kinematics converts left and right component velocities into a linear and angular
+ * chassis speed.
+ */
+@SuppressWarnings("MemberName")
+public class DifferentialDriveKinematics {
+ public final double trackWidthMeters;
+
+ /**
+ * Constructs a differential drive kinematics object.
+ *
+ * @param trackWidthMeters The track width of the drivetrain. Theoretically, this is the distance
+ * between the left wheels and right wheels. However, the empirical value may be larger than
+ * the physical measured value due to scrubbing effects.
+ */
+ public DifferentialDriveKinematics(double trackWidthMeters) {
+ this.trackWidthMeters = trackWidthMeters;
+ MathSharedStore.reportUsage(MathUsageId.kKinematics_DifferentialDrive, 1);
+ }
+
+ /**
+ * Returns a chassis speed from left and right component velocities using forward kinematics.
+ *
+ * @param wheelSpeeds The left and right velocities.
+ * @return The chassis speed.
+ */
+ public ChassisSpeeds toChassisSpeeds(DifferentialDriveWheelSpeeds wheelSpeeds) {
+ return new ChassisSpeeds(
+ (wheelSpeeds.leftMetersPerSecond + wheelSpeeds.rightMetersPerSecond) / 2,
+ 0,
+ (wheelSpeeds.rightMetersPerSecond - wheelSpeeds.leftMetersPerSecond) / trackWidthMeters);
+ }
+
+ /**
+ * Returns left and right component velocities from a chassis speed using inverse kinematics.
+ *
+ * @param chassisSpeeds The linear and angular (dx and dtheta) components that represent the
+ * chassis' speed.
+ * @return The left and right velocities.
+ */
+ public DifferentialDriveWheelSpeeds toWheelSpeeds(ChassisSpeeds chassisSpeeds) {
+ return new DifferentialDriveWheelSpeeds(
+ chassisSpeeds.vxMetersPerSecond
+ - trackWidthMeters / 2 * chassisSpeeds.omegaRadiansPerSecond,
+ chassisSpeeds.vxMetersPerSecond
+ + trackWidthMeters / 2 * chassisSpeeds.omegaRadiansPerSecond);
+ }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/math/kinematics/DifferentialDriveOdometry.java b/wpimath/src/main/java/edu/wpi/first/math/kinematics/DifferentialDriveOdometry.java
new file mode 100644
index 0000000..0139573
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/kinematics/DifferentialDriveOdometry.java
@@ -0,0 +1,113 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.kinematics;
+
+import edu.wpi.first.math.MathSharedStore;
+import edu.wpi.first.math.MathUsageId;
+import edu.wpi.first.math.geometry.Pose2d;
+import edu.wpi.first.math.geometry.Rotation2d;
+import edu.wpi.first.math.geometry.Twist2d;
+
+/**
+ * Tracks the pose of a differential drive robot on the field by integrating left/right encoder
+ * distances together with a gyroscope heading.
+ *
+ * <p>Teams can use odometry during the autonomous period for complex tasks like path following,
+ * and for latency compensation when using computer-vision systems.
+ *
+ * <p>Reset your encoders to zero before using this class, and again whenever the pose is reset:
+ * the previous encoder distances are assumed to be zero at both of those points.
+ */
+public class DifferentialDriveOdometry {
+  private Pose2d m_poseMeters;
+  // Added to raw gyro readings so they agree with the pose heading.
+  private Rotation2d m_gyroOffset;
+  private Rotation2d m_previousAngle;
+  // Encoder distances seen on the previous update, used to compute per-update deltas.
+  private double m_prevLeftDistance;
+  private double m_prevRightDistance;
+
+  /**
+   * Constructs a DifferentialDriveOdometry object.
+   *
+   * @param gyroAngle The angle currently reported by the gyroscope.
+   * @param initialPoseMeters The starting position of the robot on the field.
+   */
+  public DifferentialDriveOdometry(Rotation2d gyroAngle, Pose2d initialPoseMeters) {
+    m_poseMeters = initialPoseMeters;
+    m_previousAngle = initialPoseMeters.getRotation();
+    m_gyroOffset = m_previousAngle.minus(gyroAngle);
+    MathSharedStore.reportUsage(MathUsageId.kOdometry_DifferentialDrive, 1);
+  }
+
+  /**
+   * Constructs a DifferentialDriveOdometry object starting at the field origin.
+   *
+   * @param gyroAngle The angle currently reported by the gyroscope.
+   */
+  public DifferentialDriveOdometry(Rotation2d gyroAngle) {
+    this(gyroAngle, new Pose2d());
+  }
+
+  /**
+   * Resets the robot's position on the field.
+   *
+   * <p>You NEED to reset your encoders (to zero) when calling this method, because the stored
+   * previous encoder distances are cleared to zero here.
+   *
+   * <p>The gyroscope need not be reset; a new offset is computed from the given angle.
+   *
+   * @param poseMeters The position on the field that your robot is at.
+   * @param gyroAngle The angle currently reported by the gyroscope.
+   */
+  public void resetPosition(Pose2d poseMeters, Rotation2d gyroAngle) {
+    m_poseMeters = poseMeters;
+    m_previousAngle = poseMeters.getRotation();
+    m_gyroOffset = m_previousAngle.minus(gyroAngle);
+
+    // Encoders are expected to read zero again after a reset.
+    m_prevLeftDistance = m_prevRightDistance = 0.0;
+  }
+
+  /**
+   * Returns the current estimated position of the robot on the field.
+   *
+   * @return The pose of the robot (x and y are in meters).
+   */
+  public Pose2d getPoseMeters() {
+    return m_poseMeters;
+  }
+
+  /**
+   * Advances the pose estimate from cumulative encoder distances and the gyro heading. Integrating
+   * distances (rather than velocities) is more numerically accurate and also suits teams using
+   * lower CPR encoders.
+   *
+   * @param gyroAngle The angle currently reported by the gyroscope.
+   * @param leftDistanceMeters Total distance measured by the left encoder since its last reset.
+   * @param rightDistanceMeters Total distance measured by the right encoder since its last reset.
+   * @return The updated pose of the robot.
+   */
+  public Pose2d update(
+      Rotation2d gyroAngle, double leftDistanceMeters, double rightDistanceMeters) {
+    // Distance each side has traveled since the previous call.
+    double leftDelta = leftDistanceMeters - m_prevLeftDistance;
+    double rightDelta = rightDistanceMeters - m_prevRightDistance;
+    m_prevLeftDistance = leftDistanceMeters;
+    m_prevRightDistance = rightDistanceMeters;
+
+    // Apply the gyro offset so the heading agrees with the pose set at construction/reset.
+    Rotation2d heading = gyroAngle.plus(m_gyroOffset);
+    double headingChange = heading.minus(m_previousAngle).getRadians();
+    m_previousAngle = heading;
+
+    // A differential drive cannot move sideways, so the twist has no dy component.
+    Twist2d twist = new Twist2d((leftDelta + rightDelta) / 2.0, 0.0, headingChange);
+    Pose2d integrated = m_poseMeters.exp(twist);
+    // Heading comes straight from the gyro; only the translation is taken from the integration.
+    m_poseMeters = new Pose2d(integrated.getTranslation(), heading);
+    return m_poseMeters;
+  }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/math/kinematics/DifferentialDriveWheelSpeeds.java b/wpimath/src/main/java/edu/wpi/first/math/kinematics/DifferentialDriveWheelSpeeds.java
new file mode 100644
index 0000000..b84eeba
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/kinematics/DifferentialDriveWheelSpeeds.java
@@ -0,0 +1,55 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.kinematics;
+
+/** Represents the wheel speeds for a differential drive drivetrain. */
+@SuppressWarnings("MemberName")
+public class DifferentialDriveWheelSpeeds {
+ /** Speed of the left side of the robot. */
+ public double leftMetersPerSecond;
+
+ /** Speed of the right side of the robot. */
+ public double rightMetersPerSecond;
+
+ /** Constructs a DifferentialDriveWheelSpeeds with zeros for left and right speeds. */
+ public DifferentialDriveWheelSpeeds() {}
+
+ /**
+ * Constructs a DifferentialDriveWheelSpeeds.
+ *
+ * @param leftMetersPerSecond The left speed.
+ * @param rightMetersPerSecond The right speed.
+ */
+ public DifferentialDriveWheelSpeeds(double leftMetersPerSecond, double rightMetersPerSecond) {
+ this.leftMetersPerSecond = leftMetersPerSecond;
+ this.rightMetersPerSecond = rightMetersPerSecond;
+ }
+
+ /**
+ * Normalizes the wheel speeds using some max attainable speed. Sometimes, after inverse
+ * kinematics, the requested speed from a/several modules may be above the max attainable speed
+ * for the driving motor on that module. To fix this issue, one can "normalize" all the wheel
+ * speeds to make sure that all requested module speeds are below the absolute threshold, while
+ * maintaining the ratio of speeds between modules.
+ *
+ * @param attainableMaxSpeedMetersPerSecond The absolute max speed that a wheel can reach.
+ */
+ public void normalize(double attainableMaxSpeedMetersPerSecond) {
+ double realMaxSpeed = Math.max(Math.abs(leftMetersPerSecond), Math.abs(rightMetersPerSecond));
+
+ if (realMaxSpeed > attainableMaxSpeedMetersPerSecond) {
+ leftMetersPerSecond = leftMetersPerSecond / realMaxSpeed * attainableMaxSpeedMetersPerSecond;
+ rightMetersPerSecond =
+ rightMetersPerSecond / realMaxSpeed * attainableMaxSpeedMetersPerSecond;
+ }
+ }
+
+ @Override
+ public String toString() {
+ return String.format(
+ "DifferentialDriveWheelSpeeds(Left: %.2f m/s, Right: %.2f m/s)",
+ leftMetersPerSecond, rightMetersPerSecond);
+ }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/math/kinematics/MecanumDriveKinematics.java b/wpimath/src/main/java/edu/wpi/first/math/kinematics/MecanumDriveKinematics.java
new file mode 100644
index 0000000..3ea39e5
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/kinematics/MecanumDriveKinematics.java
@@ -0,0 +1,171 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.kinematics;
+
+import edu.wpi.first.math.MathSharedStore;
+import edu.wpi.first.math.MathUsageId;
+import edu.wpi.first.math.geometry.Translation2d;
+import org.ejml.simple.SimpleMatrix;
+
+/**
+ * Helper class that converts a chassis velocity (dx, dy, and dtheta components) into individual
+ * wheel speeds.
+ *
+ * <p>The inverse kinematics (converting from a desired chassis velocity to individual wheel speeds)
+ * uses the relative locations of the wheels with respect to the center of rotation. The center of
+ * rotation for inverse kinematics is also variable. This means that you can set your set your
+ * center of rotation in a corner of the robot to perform special evasion maneuvers.
+ *
+ * <p>Forward kinematics (converting an array of wheel speeds into the overall chassis motion) is
+ * performs the exact opposite of what inverse kinematics does. Since this is an overdetermined
+ * system (more equations than variables), we use a least-squares approximation.
+ *
+ * <p>The inverse kinematics: [wheelSpeeds] = [wheelLocations] * [chassisSpeeds] We take the
+ * Moore-Penrose pseudoinverse of [wheelLocations] and then multiply by [wheelSpeeds] to get our
+ * chassis speeds.
+ *
+ * <p>Forward kinematics is also used for odometry -- determining the position of the robot on the
+ * field using encoders and a gyro.
+ */
+public class MecanumDriveKinematics {
+ private final SimpleMatrix m_inverseKinematics;
+ private final SimpleMatrix m_forwardKinematics;
+
+ private final Translation2d m_frontLeftWheelMeters;
+ private final Translation2d m_frontRightWheelMeters;
+ private final Translation2d m_rearLeftWheelMeters;
+ private final Translation2d m_rearRightWheelMeters;
+
+ private Translation2d m_prevCoR = new Translation2d();
+
+ /**
+ * Constructs a mecanum drive kinematics object.
+ *
+ * @param frontLeftWheelMeters The location of the front-left wheel relative to the physical
+ * center of the robot.
+ * @param frontRightWheelMeters The location of the front-right wheel relative to the physical
+ * center of the robot.
+ * @param rearLeftWheelMeters The location of the rear-left wheel relative to the physical center
+ * of the robot.
+ * @param rearRightWheelMeters The location of the rear-right wheel relative to the physical
+ * center of the robot.
+ */
+ public MecanumDriveKinematics(
+ Translation2d frontLeftWheelMeters,
+ Translation2d frontRightWheelMeters,
+ Translation2d rearLeftWheelMeters,
+ Translation2d rearRightWheelMeters) {
+ m_frontLeftWheelMeters = frontLeftWheelMeters;
+ m_frontRightWheelMeters = frontRightWheelMeters;
+ m_rearLeftWheelMeters = rearLeftWheelMeters;
+ m_rearRightWheelMeters = rearRightWheelMeters;
+
+ m_inverseKinematics = new SimpleMatrix(4, 3);
+
+ setInverseKinematics(
+ frontLeftWheelMeters, frontRightWheelMeters, rearLeftWheelMeters, rearRightWheelMeters);
+ m_forwardKinematics = m_inverseKinematics.pseudoInverse();
+
+ MathSharedStore.reportUsage(MathUsageId.kKinematics_MecanumDrive, 1);
+ }
+
+ /**
+ * Performs inverse kinematics to return the wheel speeds from a desired chassis velocity. This
+ * method is often used to convert joystick values into wheel speeds.
+ *
+ * <p>This function also supports variable centers of rotation. During normal operations, the
+ * center of rotation is usually the same as the physical center of the robot; therefore, the
+ * argument is defaulted to that use case. However, if you wish to change the center of rotation
+ * for evasive maneuvers, vision alignment, or for any other use case, you can do so.
+ *
+ * @param chassisSpeeds The desired chassis speed.
+ * @param centerOfRotationMeters The center of rotation. For example, if you set the center of
+ * rotation at one corner of the robot and provide a chassis speed that only has a dtheta
+ * component, the robot will rotate around that corner.
+ * @return The wheel speeds. Use caution because they are not normalized. Sometimes, a user input
+ * may cause one of the wheel speeds to go above the attainable max velocity. Use the {@link
+ * MecanumDriveWheelSpeeds#normalize(double)} function to rectify this issue.
+ */
+ public MecanumDriveWheelSpeeds toWheelSpeeds(
+ ChassisSpeeds chassisSpeeds, Translation2d centerOfRotationMeters) {
+ // We have a new center of rotation. We need to compute the matrix again.
+ if (!centerOfRotationMeters.equals(m_prevCoR)) {
+ var fl = m_frontLeftWheelMeters.minus(centerOfRotationMeters);
+ var fr = m_frontRightWheelMeters.minus(centerOfRotationMeters);
+ var rl = m_rearLeftWheelMeters.minus(centerOfRotationMeters);
+ var rr = m_rearRightWheelMeters.minus(centerOfRotationMeters);
+
+ setInverseKinematics(fl, fr, rl, rr);
+ m_prevCoR = centerOfRotationMeters;
+ }
+
+ var chassisSpeedsVector = new SimpleMatrix(3, 1);
+ chassisSpeedsVector.setColumn(
+ 0,
+ 0,
+ chassisSpeeds.vxMetersPerSecond,
+ chassisSpeeds.vyMetersPerSecond,
+ chassisSpeeds.omegaRadiansPerSecond);
+
+ var wheelsVector = m_inverseKinematics.mult(chassisSpeedsVector);
+ return new MecanumDriveWheelSpeeds(
+ wheelsVector.get(0, 0),
+ wheelsVector.get(1, 0),
+ wheelsVector.get(2, 0),
+ wheelsVector.get(3, 0));
+ }
+
+ /**
+ * Performs inverse kinematics. See {@link #toWheelSpeeds(ChassisSpeeds, Translation2d)} for more
+ * information.
+ *
+ * @param chassisSpeeds The desired chassis speed.
+ * @return The wheel speeds.
+ */
+ public MecanumDriveWheelSpeeds toWheelSpeeds(ChassisSpeeds chassisSpeeds) {
+ return toWheelSpeeds(chassisSpeeds, new Translation2d());
+ }
+
+ /**
+ * Performs forward kinematics to return the resulting chassis state from the given wheel speeds.
+ * This method is often used for odometry -- determining the robot's position on the field using
+ * data from the real-world speed of each wheel on the robot.
+ *
+ * @param wheelSpeeds The current mecanum drive wheel speeds.
+ * @return The resulting chassis speed.
+ */
+ public ChassisSpeeds toChassisSpeeds(MecanumDriveWheelSpeeds wheelSpeeds) {
+ var wheelSpeedsVector = new SimpleMatrix(4, 1);
+ wheelSpeedsVector.setColumn(
+ 0,
+ 0,
+ wheelSpeeds.frontLeftMetersPerSecond,
+ wheelSpeeds.frontRightMetersPerSecond,
+ wheelSpeeds.rearLeftMetersPerSecond,
+ wheelSpeeds.rearRightMetersPerSecond);
+ var chassisSpeedsVector = m_forwardKinematics.mult(wheelSpeedsVector);
+
+ return new ChassisSpeeds(
+ chassisSpeedsVector.get(0, 0),
+ chassisSpeedsVector.get(1, 0),
+ chassisSpeedsVector.get(2, 0));
+ }
+
+ /**
+ * Construct inverse kinematics matrix from wheel locations.
+ *
+ * @param fl The location of the front-left wheel relative to the physical center of the robot.
+ * @param fr The location of the front-right wheel relative to the physical center of the robot.
+ * @param rl The location of the rear-left wheel relative to the physical center of the robot.
+ * @param rr The location of the rear-right wheel relative to the physical center of the robot.
+ */
+ private void setInverseKinematics(
+ Translation2d fl, Translation2d fr, Translation2d rl, Translation2d rr) {
+ m_inverseKinematics.setRow(0, 0, 1, -1, -(fl.getX() + fl.getY()));
+ m_inverseKinematics.setRow(1, 0, 1, 1, fr.getX() - fr.getY());
+ m_inverseKinematics.setRow(2, 0, 1, 1, rl.getX() - rl.getY());
+ m_inverseKinematics.setRow(3, 0, 1, -1, -(rr.getX() + rr.getY()));
+ }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/math/kinematics/MecanumDriveMotorVoltages.java b/wpimath/src/main/java/edu/wpi/first/math/kinematics/MecanumDriveMotorVoltages.java
new file mode 100644
index 0000000..b504acc
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/kinematics/MecanumDriveMotorVoltages.java
@@ -0,0 +1,51 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.kinematics;
+
+/** Per-motor voltages for the four motors of a mecanum drive drivetrain. */
+@SuppressWarnings("MemberName")
+public class MecanumDriveMotorVoltages {
+  /** Voltage for the front-left motor. */
+  public double frontLeftVoltage;
+
+  /** Voltage for the front-right motor. */
+  public double frontRightVoltage;
+
+  /** Voltage for the rear-left motor. */
+  public double rearLeftVoltage;
+
+  /** Voltage for the rear-right motor. */
+  public double rearRightVoltage;
+
+  /** Constructs a MecanumDriveMotorVoltages with all four voltages set to zero. */
+  public MecanumDriveMotorVoltages() {}
+
+  /**
+   * Constructs a MecanumDriveMotorVoltages from individual motor voltages.
+   *
+   * @param frontLeftVoltage Voltage for the front-left motor.
+   * @param frontRightVoltage Voltage for the front-right motor.
+   * @param rearLeftVoltage Voltage for the rear-left motor.
+   * @param rearRightVoltage Voltage for the rear-right motor.
+   */
+  public MecanumDriveMotorVoltages(
+      double frontLeftVoltage,
+      double frontRightVoltage,
+      double rearLeftVoltage,
+      double rearRightVoltage) {
+    this.frontLeftVoltage = frontLeftVoltage;
+    this.frontRightVoltage = frontRightVoltage;
+    this.rearLeftVoltage = rearLeftVoltage;
+    this.rearRightVoltage = rearRightVoltage;
+  }
+
+  @Override
+  public String toString() {
+    return String.format(
+        "MecanumDriveMotorVoltages(Front Left: %.2f V, Front Right: %.2f V, "
+            + "Rear Left: %.2f V, Rear Right: %.2f V)",
+        frontLeftVoltage, frontRightVoltage, rearLeftVoltage, rearRightVoltage);
+  }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/math/kinematics/MecanumDriveOdometry.java b/wpimath/src/main/java/edu/wpi/first/math/kinematics/MecanumDriveOdometry.java
new file mode 100644
index 0000000..b3c79c9
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/kinematics/MecanumDriveOdometry.java
@@ -0,0 +1,125 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.kinematics;
+
+import edu.wpi.first.math.MathSharedStore;
+import edu.wpi.first.math.MathUsageId;
+import edu.wpi.first.math.geometry.Pose2d;
+import edu.wpi.first.math.geometry.Rotation2d;
+import edu.wpi.first.math.geometry.Twist2d;
+import edu.wpi.first.util.WPIUtilJNI;
+
+/**
+ * Tracks the pose of a mecanum drive robot on the field by integrating the chassis velocity
+ * computed from mecanum wheel speeds, together with a gyroscope heading.
+ *
+ * <p>Teams can use odometry during the autonomous period for complex tasks like path following,
+ * and for latency compensation when using computer-vision systems.
+ */
+public class MecanumDriveOdometry {
+  private final MecanumDriveKinematics m_kinematics;
+  private Pose2d m_poseMeters;
+  // A negative value marks "no update yet", so the first update uses a zero period.
+  private double m_prevTimeSeconds = -1;
+  private Rotation2d m_gyroOffset;
+  private Rotation2d m_previousAngle;
+
+  /**
+   * Constructs a MecanumDriveOdometry object.
+   *
+   * @param kinematics The mecanum drive kinematics for your drivetrain.
+   * @param gyroAngle The angle currently reported by the gyroscope.
+   * @param initialPoseMeters The starting position of the robot on the field.
+   */
+  public MecanumDriveOdometry(
+      MecanumDriveKinematics kinematics, Rotation2d gyroAngle, Pose2d initialPoseMeters) {
+    m_kinematics = kinematics;
+    m_poseMeters = initialPoseMeters;
+    m_previousAngle = initialPoseMeters.getRotation();
+    m_gyroOffset = m_previousAngle.minus(gyroAngle);
+    MathSharedStore.reportUsage(MathUsageId.kOdometry_MecanumDrive, 1);
+  }
+
+  /**
+   * Constructs a MecanumDriveOdometry object starting at the field origin.
+   *
+   * @param kinematics The mecanum drive kinematics for your drivetrain.
+   * @param gyroAngle The angle currently reported by the gyroscope.
+   */
+  public MecanumDriveOdometry(MecanumDriveKinematics kinematics, Rotation2d gyroAngle) {
+    this(kinematics, gyroAngle, new Pose2d());
+  }
+
+  /**
+   * Resets the robot's position on the field.
+   *
+   * <p>The gyroscope need not be reset; a new offset is computed from the given angle so that
+   * subsequent headings agree with the new pose.
+   *
+   * @param poseMeters The position on the field that your robot is at.
+   * @param gyroAngle The angle currently reported by the gyroscope.
+   */
+  public void resetPosition(Pose2d poseMeters, Rotation2d gyroAngle) {
+    m_poseMeters = poseMeters;
+    m_previousAngle = poseMeters.getRotation();
+    m_gyroOffset = m_previousAngle.minus(gyroAngle);
+  }
+
+  /**
+   * Returns the current estimated position of the robot on the field.
+   *
+   * @return The pose of the robot (x and y are in meters).
+   */
+  public Pose2d getPoseMeters() {
+    return m_poseMeters;
+  }
+
+  /**
+   * Advances the pose estimate by integrating the chassis velocity over the time elapsed since
+   * the previous update. The chassis velocity comes from forward kinematics, but the heading is
+   * taken from the gyroscope rather than from the angular rate computed by kinematics.
+   *
+   * <p>On the first call the elapsed time is treated as zero, so only the heading is updated.
+   *
+   * @param currentTimeSeconds The current time in seconds.
+   * @param gyroAngle The angle currently reported by the gyroscope.
+   * @param wheelSpeeds The current wheel speeds.
+   * @return The updated pose of the robot.
+   */
+  public Pose2d updateWithTime(
+      double currentTimeSeconds, Rotation2d gyroAngle, MecanumDriveWheelSpeeds wheelSpeeds) {
+    double period = 0.0;
+    if (m_prevTimeSeconds >= 0) {
+      period = currentTimeSeconds - m_prevTimeSeconds;
+    }
+    m_prevTimeSeconds = currentTimeSeconds;
+
+    Rotation2d heading = gyroAngle.plus(m_gyroOffset);
+    ChassisSpeeds velocity = m_kinematics.toChassisSpeeds(wheelSpeeds);
+    // Translation comes from integrating velocity over the period; rotation from the gyro.
+    double dx = velocity.vxMetersPerSecond * period;
+    double dy = velocity.vyMetersPerSecond * period;
+    double dtheta = heading.minus(m_previousAngle).getRadians();
+    Pose2d integrated = m_poseMeters.exp(new Twist2d(dx, dy, dtheta));
+    m_previousAngle = heading;
+    m_poseMeters = new Pose2d(integrated.getTranslation(), heading);
+    return m_poseMeters;
+  }
+
+  /**
+   * Advances the pose estimate using the current time from {@code WPIUtilJNI.now()} to compute the
+   * elapsed period. The chassis velocity comes from forward kinematics, but the heading is taken
+   * from the gyroscope rather than from the angular rate computed by kinematics.
+   *
+   * <p>See {@link #updateWithTime(double, Rotation2d, MecanumDriveWheelSpeeds)} for details.
+   *
+   * @param gyroAngle The angle currently reported by the gyroscope.
+   * @param wheelSpeeds The current wheel speeds.
+   * @return The updated pose of the robot.
+   */
+  public Pose2d update(Rotation2d gyroAngle, MecanumDriveWheelSpeeds wheelSpeeds) {
+    return updateWithTime(WPIUtilJNI.now() * 1.0e-6, gyroAngle, wheelSpeeds);
+  }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/math/kinematics/MecanumDriveWheelSpeeds.java b/wpimath/src/main/java/edu/wpi/first/math/kinematics/MecanumDriveWheelSpeeds.java
new file mode 100644
index 0000000..7a159fe
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/kinematics/MecanumDriveWheelSpeeds.java
@@ -0,0 +1,86 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.kinematics;
+
+import java.util.stream.DoubleStream;
+
+/** Represents the wheel speeds for a mecanum drive drivetrain. */
+@SuppressWarnings("MemberName")
+public class MecanumDriveWheelSpeeds {
+  /** Speed of the front left wheel in meters per second. */
+  public double frontLeftMetersPerSecond;
+
+  /** Speed of the front right wheel in meters per second. */
+  public double frontRightMetersPerSecond;
+
+  /** Speed of the rear left wheel in meters per second. */
+  public double rearLeftMetersPerSecond;
+
+  /** Speed of the rear right wheel in meters per second. */
+  public double rearRightMetersPerSecond;
+
+  /** Constructs a MecanumDriveWheelSpeeds with zeros for all member fields. */
+  public MecanumDriveWheelSpeeds() {}
+
+  /**
+   * Constructs a MecanumDriveWheelSpeeds.
+   *
+   * @param frontLeftMetersPerSecond Speed of the front left wheel.
+   * @param frontRightMetersPerSecond Speed of the front right wheel.
+   * @param rearLeftMetersPerSecond Speed of the rear left wheel.
+   * @param rearRightMetersPerSecond Speed of the rear right wheel.
+   */
+  public MecanumDriveWheelSpeeds(
+      double frontLeftMetersPerSecond,
+      double frontRightMetersPerSecond,
+      double rearLeftMetersPerSecond,
+      double rearRightMetersPerSecond) {
+    this.frontLeftMetersPerSecond = frontLeftMetersPerSecond;
+    this.frontRightMetersPerSecond = frontRightMetersPerSecond;
+    this.rearLeftMetersPerSecond = rearLeftMetersPerSecond;
+    this.rearRightMetersPerSecond = rearRightMetersPerSecond;
+  }
+
+  /**
+   * Scales all wheel speeds down, if necessary, so that no wheel's speed magnitude exceeds a
+   * maximum attainable speed. After inverse kinematics, one or more wheels may be asked to spin
+   * faster than its motor can; this method divides every speed by the largest magnitude and
+   * multiplies by the limit, preserving the ratios between wheels (including their signs).
+   *
+   * @param attainableMaxSpeedMetersPerSecond The maximum speed magnitude any wheel can reach.
+   *     Speeds are left untouched if no magnitude exceeds this value.
+   */
+  public void normalize(double attainableMaxSpeedMetersPerSecond) {
+    // Compare magnitudes so that large reverse (negative) speeds are scaled down as well.
+    double realMaxSpeed =
+        DoubleStream.of(
+                frontLeftMetersPerSecond,
+                frontRightMetersPerSecond,
+                rearLeftMetersPerSecond,
+                rearRightMetersPerSecond)
+            .map(Math::abs)
+            .max()
+            .getAsDouble();
+
+    if (realMaxSpeed > attainableMaxSpeedMetersPerSecond) {
+      double limit = attainableMaxSpeedMetersPerSecond;
+      frontLeftMetersPerSecond = frontLeftMetersPerSecond / realMaxSpeed * limit;
+      frontRightMetersPerSecond = frontRightMetersPerSecond / realMaxSpeed * limit;
+      rearLeftMetersPerSecond = rearLeftMetersPerSecond / realMaxSpeed * limit;
+      rearRightMetersPerSecond = rearRightMetersPerSecond / realMaxSpeed * limit;
+    }
+  }
+
+  @Override
+  public String toString() {
+    return String.format(
+        "MecanumDriveWheelSpeeds(Front Left: %.2f m/s, Front Right: %.2f m/s, "
+            + "Rear Left: %.2f m/s, Rear Right: %.2f m/s)",
+        frontLeftMetersPerSecond,
+        frontRightMetersPerSecond,
+        rearLeftMetersPerSecond,
+        rearRightMetersPerSecond);
+  }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/math/kinematics/SwerveDriveKinematics.java b/wpimath/src/main/java/edu/wpi/first/math/kinematics/SwerveDriveKinematics.java
new file mode 100644
index 0000000..4c6d9cc
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/kinematics/SwerveDriveKinematics.java
@@ -0,0 +1,194 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.kinematics;
+
+import edu.wpi.first.math.MathSharedStore;
+import edu.wpi.first.math.MathUsageId;
+import edu.wpi.first.math.geometry.Rotation2d;
+import edu.wpi.first.math.geometry.Translation2d;
+import java.util.Arrays;
+import java.util.Collections;
+import org.ejml.simple.SimpleMatrix;
+
+/**
+ * Helper class that converts between a chassis velocity (dx, dy, and dtheta components) and
+ * individual swerve module states (speed and angle).
+ *
+ * <p>Inverse kinematics (chassis velocity to module states) uses the positions of the modules
+ * relative to the center of rotation. The center of rotation may be changed on each call, which
+ * lets you, for example, place it at a corner of the robot to perform special evasion
+ * maneuvers.
+ *
+ * <p>Forward kinematics (module states to chassis velocity) is the reverse operation. Each module
+ * contributes two equations and at least two modules are required, so the system has more
+ * equations than its three unknowns and is solved with a least-squares approximation.
+ *
+ * <p>The inverse kinematics can be written as [moduleStates] = [moduleLocations] *
+ * [chassisSpeeds]. Forward kinematics multiplies [moduleStates] by the Moore-Penrose
+ * pseudoinverse of [moduleLocations] to recover the chassis speeds.
+ *
+ * <p>Forward kinematics is also used for odometry, i.e. estimating the position of the robot on
+ * the field from encoders and a gyro.
+ */
+public class SwerveDriveKinematics {
+  // (2 * numModules) x 3 matrix mapping [vx, vy, omega] to per-module x/y velocities; rebuilt
+  // whenever the center of rotation changes.
+  private final SimpleMatrix m_inverseKinematics;
+  // Pseudoinverse of the inverse kinematics about the robot center; computed once.
+  private final SimpleMatrix m_forwardKinematics;
+
+  private final int m_numModules;
+  // Copy of the module locations, in the order they were passed to the constructor.
+  private final Translation2d[] m_modules;
+  // Center of rotation that m_inverseKinematics was last built for.
+  private Translation2d m_prevCoR = new Translation2d();
+
+  /**
+   * Constructs a swerve drive kinematics object from a variable number of module locations. The
+   * order of the locations is the order in which inverse kinematics returns module states, and
+   * the order in which the forward kinematics methods expect them.
+   *
+   * @param wheelsMeters The module locations relative to the physical center of the robot.
+   * @throws IllegalArgumentException if fewer than two module locations are given.
+   */
+  public SwerveDriveKinematics(Translation2d... wheelsMeters) {
+    if (wheelsMeters.length < 2) {
+      throw new IllegalArgumentException("A swerve drive requires at least two modules");
+    }
+    m_numModules = wheelsMeters.length;
+    m_modules = Arrays.copyOf(wheelsMeters, m_numModules);
+
+    // Rows 2i and 2i+1 give module i's x and y velocity for a chassis velocity [vx, vy, omega].
+    m_inverseKinematics = new SimpleMatrix(m_numModules * 2, 3);
+    for (int i = 0; i < m_numModules; i++) {
+      Translation2d module = m_modules[i];
+      m_inverseKinematics.setRow(2 * i, 0, 1, 0, -module.getY());
+      m_inverseKinematics.setRow(2 * i + 1, 0, 0, 1, +module.getX());
+    }
+    m_forwardKinematics = m_inverseKinematics.pseudoInverse();
+
+    MathSharedStore.reportUsage(MathUsageId.kKinematics_SwerveDrive, 1);
+  }
+
+  /**
+   * Performs inverse kinematics, converting a desired chassis velocity into module states. This is
+   * commonly used to turn joystick input into module speeds and angles.
+   *
+   * <p>The center of rotation can be varied. Normally it is the physical center of the robot,
+   * which is what {@link #toSwerveModuleStates(ChassisSpeeds)} uses, but it may be moved for
+   * evasive maneuvers, vision alignment, or any other purpose. Passing a center of rotation
+   * different from the previous call rebuilds the internal inverse kinematics matrix.
+   *
+   * @param chassisSpeeds The desired chassis speed.
+   * @param centerOfRotationMeters The center of rotation. For example, placing it at one corner of
+   *     the robot and requesting a chassis speed with only an angular component makes the robot
+   *     rotate around that corner.
+   * @return An array of module states in the order the module locations were passed to the
+   *     constructor. The speeds are not normalized: some user inputs may yield speeds above the
+   *     attainable maximum. Use {@link #normalizeWheelSpeeds(SwerveModuleState[], double)
+   *     normalizeWheelSpeeds} to rescale them.
+   */
+  @SuppressWarnings("LocalVariableName")
+  public SwerveModuleState[] toSwerveModuleStates(
+      ChassisSpeeds chassisSpeeds, Translation2d centerOfRotationMeters) {
+    // Only rebuild the matrix when the center of rotation differs from the last one used.
+    if (!centerOfRotationMeters.equals(m_prevCoR)) {
+      for (int i = 0; i < m_numModules; i++) {
+        // Coefficients mapping omega to this module's x and y velocity about the new center.
+        double omegaToVx = -m_modules[i].getY() + centerOfRotationMeters.getY();
+        double omegaToVy = +m_modules[i].getX() - centerOfRotationMeters.getX();
+        m_inverseKinematics.setRow(2 * i, 0, 1, 0, omegaToVx);
+        m_inverseKinematics.setRow(2 * i + 1, 0, 0, 1, omegaToVy);
+      }
+      m_prevCoR = centerOfRotationMeters;
+    }
+
+    // Column vector of the requested chassis velocity: [vx, vy, omega].
+    var chassisSpeedsVector = new SimpleMatrix(3, 1);
+    chassisSpeedsVector.set(0, 0, chassisSpeeds.vxMetersPerSecond);
+    chassisSpeedsVector.set(1, 0, chassisSpeeds.vyMetersPerSecond);
+    chassisSpeedsVector.set(2, 0, chassisSpeeds.omegaRadiansPerSecond);
+
+    var moduleVelocities = m_inverseKinematics.mult(chassisSpeedsVector);
+    var moduleStates = new SwerveModuleState[m_numModules];
+    for (int i = 0; i < m_numModules; i++) {
+      // Each module's velocity vector gives its speed (magnitude) and angle (direction).
+      double vx = moduleVelocities.get(2 * i, 0);
+      double vy = moduleVelocities.get(2 * i + 1, 0);
+      moduleStates[i] = new SwerveModuleState(Math.hypot(vx, vy), new Rotation2d(vx, vy));
+    }
+
+    return moduleStates;
+  }
+
+  /**
+   * Performs inverse kinematics about the physical center of the robot. See {@link
+   * #toSwerveModuleStates(ChassisSpeeds, Translation2d)} for details.
+   *
+   * @param chassisSpeeds The desired chassis speed.
+   * @return An array containing the module states, which are not normalized.
+   */
+  public SwerveModuleState[] toSwerveModuleStates(ChassisSpeeds chassisSpeeds) {
+    return toSwerveModuleStates(chassisSpeeds, new Translation2d());
+  }
+
+  /**
+   * Performs forward kinematics, converting measured module states into the resulting chassis
+   * velocity. This is often used for odometry, i.e. estimating the robot's position on the field
+   * from the real-world speed and angle of each module.
+   *
+   * @param wheelStates The module states as measured by their encoders, in the same order as the
+   *     module locations passed to the constructor.
+   * @return The least-squares estimate of the chassis speed.
+   * @throws IllegalArgumentException if the number of states differs from the number of modules.
+   */
+  public ChassisSpeeds toChassisSpeeds(SwerveModuleState... wheelStates) {
+    if (wheelStates.length != m_numModules) {
+      throw new IllegalArgumentException(
+          "Number of modules is not consistent with number of wheel locations provided in "
+              + "constructor");
+    }
+
+    // Decompose each module's speed into x and y components, two rows per module.
+    var moduleStatesMatrix = new SimpleMatrix(m_numModules * 2, 1);
+    for (int i = 0; i < m_numModules; i++) {
+      SwerveModuleState state = wheelStates[i];
+      moduleStatesMatrix.set(2 * i, 0, state.speedMetersPerSecond * state.angle.getCos());
+      moduleStatesMatrix.set(2 * i + 1, 0, state.speedMetersPerSecond * state.angle.getSin());
+    }
+
+    // Apply the pseudoinverse computed in the constructor.
+    var chassisVector = m_forwardKinematics.mult(moduleStatesMatrix);
+    return new ChassisSpeeds(
+        chassisVector.get(0, 0), chassisVector.get(1, 0), chassisVector.get(2, 0));
+  }
+
+  /**
+   * Scales module speeds down, if necessary, so that no module's speed magnitude exceeds a maximum
+   * attainable speed. After inverse kinematics, one or more modules may be asked to drive faster
+   * than its motor can; this method divides every speed by the largest magnitude and multiplies by
+   * the limit, preserving the ratios between modules (including their signs). An empty array is
+   * left unchanged.
+   *
+   * @param moduleStates Reference to array of module states. The array will be mutated with the
+   *     normalized speeds!
+   * @param attainableMaxSpeedMetersPerSecond The maximum speed magnitude any module can reach.
+   */
+  public static void normalizeWheelSpeeds(
+      SwerveModuleState[] moduleStates, double attainableMaxSpeedMetersPerSecond) {
+    // Use magnitudes: module states may carry negative speeds, which must be scaled down too.
+    double realMaxSpeed = 0.0;
+    for (SwerveModuleState moduleState : moduleStates) {
+      realMaxSpeed = Math.max(realMaxSpeed, Math.abs(moduleState.speedMetersPerSecond));
+    }
+
+    if (realMaxSpeed > attainableMaxSpeedMetersPerSecond) {
+      for (SwerveModuleState moduleState : moduleStates) {
+        moduleState.speedMetersPerSecond =
+            moduleState.speedMetersPerSecond / realMaxSpeed * attainableMaxSpeedMetersPerSecond;
+      }
+    }
+  }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/math/kinematics/SwerveDriveOdometry.java b/wpimath/src/main/java/edu/wpi/first/math/kinematics/SwerveDriveOdometry.java
new file mode 100644
index 0000000..8b4161e
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/kinematics/SwerveDriveOdometry.java
@@ -0,0 +1,129 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.kinematics;
+
+import edu.wpi.first.math.MathSharedStore;
+import edu.wpi.first.math.MathUsageId;
+import edu.wpi.first.math.geometry.Pose2d;
+import edu.wpi.first.math.geometry.Rotation2d;
+import edu.wpi.first.math.geometry.Twist2d;
+import edu.wpi.first.util.WPIUtilJNI;
+
+/**
+ * Class for swerve drive odometry. Odometry allows you to track the robot's position on the field
+ * over a course of a match using readings from your swerve drive encoders and swerve azimuth
+ * encoders.
+ *
+ * <p>Teams can use odometry during the autonomous period for complex tasks like path following.
+ * Furthermore, odometry can be used for latency compensation when using computer-vision systems.
+ */
+ public class SwerveDriveOdometry {
+ private final SwerveDriveKinematics m_kinematics;
+ private Pose2d m_poseMeters;
+ private double m_prevTimeSeconds = -1; // Negative until the first update supplies a timestamp.
+
+ private Rotation2d m_gyroOffset; // Added to raw gyro readings to get the pose's rotation.
+ private Rotation2d m_previousAngle; // Heading at the last update, construction, or reset.
+
+ /**
+ * Creates odometry that starts tracking from the given pose.
+ *
+ * @param kinematics The swerve drive kinematics for your drivetrain.
+ * @param gyroAngle The angle reported by the gyroscope.
+ * @param initialPose The starting position of the robot on the field.
+ */
+ public SwerveDriveOdometry(
+ SwerveDriveKinematics kinematics, Rotation2d gyroAngle, Pose2d initialPose) {
+ m_kinematics = kinematics;
+ m_poseMeters = initialPose;
+ m_previousAngle = initialPose.getRotation();
+ m_gyroOffset = initialPose.getRotation().minus(gyroAngle);
+ MathSharedStore.reportUsage(MathUsageId.kOdometry_SwerveDrive, 1);
+ }
+
+ /**
+ * Creates odometry that starts tracking from the origin (a default-constructed {@code Pose2d}).
+ *
+ * @param kinematics The swerve drive kinematics for your drivetrain.
+ * @param gyroAngle The angle reported by the gyroscope.
+ */
+ public SwerveDriveOdometry(SwerveDriveKinematics kinematics, Rotation2d gyroAngle) {
+ this(kinematics, gyroAngle, new Pose2d());
+ }
+
+ /**
+ * Overwrites the tracked pose with the given one.
+ *
+ * <p>There is no need to zero the gyroscope in robot code first: the offset between the given
+ * gyro reading and the new pose's rotation is recomputed here.
+ *
+ * @param pose The position on the field that your robot is at.
+ * @param gyroAngle The angle reported by the gyroscope.
+ */
+ public void resetPosition(Pose2d pose, Rotation2d gyroAngle) {
+ m_poseMeters = pose;
+ m_previousAngle = pose.getRotation();
+ m_gyroOffset = pose.getRotation().minus(gyroAngle);
+ }
+
+ /**
+ * Gets the most recently computed robot pose.
+ *
+ * @return The pose of the robot (x and y are in meters).
+ */
+ public Pose2d getPoseMeters() {
+ return m_poseMeters;
+ }
+
+ /**
+ * Integrates the pose forward using forward kinematics. The caller supplies the current
+ * timestamp; the difference from the previous call's timestamp (zero on the first call) is the
+ * period over which the chassis velocity is integrated. The heading change comes from the gyro
+ * angle rather than the angular rate computed by forward kinematics. The result is also stored
+ * so that {@link #getPoseMeters()} returns it.
+ *
+ * @param currentTimeSeconds The current time in seconds.
+ * @param gyroAngle The angle reported by the gyroscope.
+ * @param moduleStates The current state of all swerve modules. Please provide the states in the
+ * same order in which you instantiated your SwerveDriveKinematics.
+ * @return The new pose of the robot.
+ */
+ public Pose2d updateWithTime(
+ double currentTimeSeconds, Rotation2d gyroAngle, SwerveModuleState... moduleStates) {
+ double dtSeconds = m_prevTimeSeconds >= 0 ? currentTimeSeconds - m_prevTimeSeconds : 0.0;
+ m_prevTimeSeconds = currentTimeSeconds;
+
+ // Heading in the pose's frame: the raw gyro reading shifted by the stored offset.
+ Rotation2d heading = gyroAngle.plus(m_gyroOffset);
+ ChassisSpeeds speeds = m_kinematics.toChassisSpeeds(moduleStates);
+ // Translation comes from the module states; rotation comes from the gyro heading change.
+ var twist =
+ new Twist2d(
+ speeds.vxMetersPerSecond * dtSeconds,
+ speeds.vyMetersPerSecond * dtSeconds,
+ heading.minus(m_previousAngle).getRadians());
+ Pose2d integrated = m_poseMeters.exp(twist);
+
+ m_previousAngle = heading;
+ m_poseMeters = new Pose2d(integrated.getTranslation(), heading);
+ return m_poseMeters;
+ }
+
+ /**
+ * Integrates the pose forward using forward kinematics, using {@code WPIUtilJNI.now()} scaled by
+ * 1e-6 as the current time in seconds. The period is the time elapsed since the previous update
+ * (zero on the first update). The heading change comes from the gyro angle rather than the
+ * angular rate computed by forward kinematics. This simply delegates to
+ * {@link #updateWithTime(double, Rotation2d, SwerveModuleState...)}.
+ *
+ * @param gyroAngle The angle reported by the gyroscope.
+ * @param moduleStates The current state of all swerve modules. Please provide the states in the
+ * same order in which you instantiated your SwerveDriveKinematics.
+ * @return The new pose of the robot.
+ */
+ public Pose2d update(Rotation2d gyroAngle, SwerveModuleState... moduleStates) {
+ return updateWithTime(WPIUtilJNI.now() * 1.0e-6, gyroAngle, moduleStates);
+ }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/math/kinematics/SwerveModuleState.java b/wpimath/src/main/java/edu/wpi/first/math/kinematics/SwerveModuleState.java
new file mode 100644
index 0000000..6a9c48c
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/kinematics/SwerveModuleState.java
@@ -0,0 +1,85 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.kinematics;
+
+import edu.wpi.first.math.geometry.Rotation2d;
+import java.util.Objects;
+
+ /** Represents the state of one swerve module: its wheel speed and its angle. */
+ @SuppressWarnings("MemberName")
+ public class SwerveModuleState implements Comparable<SwerveModuleState> {
+ /** Speed of the wheel of the module. */
+ public double speedMetersPerSecond;
+
+ /** Angle of the module. */
+ public Rotation2d angle = Rotation2d.fromDegrees(0);
+
+ /** Constructs a SwerveModuleState with zeros for speed and angle. */
+ public SwerveModuleState() {}
+
+ /**
+ * Constructs a SwerveModuleState.
+ *
+ * @param speedMetersPerSecond The speed of the wheel of the module.
+ * @param angle The angle of the module.
+ */
+ public SwerveModuleState(double speedMetersPerSecond, Rotation2d angle) {
+ this.speedMetersPerSecond = speedMetersPerSecond;
+ this.angle = angle;
+ }
+
+ /** Two states are equal only when both their speeds and their angles are equal. */
+ @Override
+ public boolean equals(Object obj) {
+ if (obj instanceof SwerveModuleState) {
+ var other = (SwerveModuleState) obj;
+ return Double.compare(speedMetersPerSecond, other.speedMetersPerSecond) == 0
+ && Objects.equals(angle, other.angle);
+ }
+ return false;
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(speedMetersPerSecond); // Speed only; still consistent with equals().
+ }
+
+ /**
+ * Compares two swerve module states by speed alone; the angle is ignored, so this ordering is
+ * not consistent with {@link #equals(Object)} (states differing only in angle compare as 0).
+ *
+ * @param other The other swerve module.
+ * @return A positive value if this speed is greater, 0 if both are equal, negative if it is less.
+ */
+ @Override
+ public int compareTo(SwerveModuleState other) {
+ return Double.compare(this.speedMetersPerSecond, other.speedMetersPerSecond);
+ }
+
+ @Override
+ public String toString() {
+ return String.format(
+ "SwerveModuleState(Speed: %.2f m/s, Angle: %s)", speedMetersPerSecond, angle);
+ }
+
+ /**
+ * Minimize the change in heading the desired swerve module state would require: if the desired
+ * angle is more than 90 degrees from the current one, aim the opposite way and reverse the wheel.
+ * With PIDController continuous input enabled, a wheel then never turns more than 90 degrees.
+ *
+ * @param desiredState The desired state.
+ * @param currentAngle The current module angle.
+ * @return Optimized swerve module state.
+ */
+ public static SwerveModuleState optimize(
+ SwerveModuleState desiredState, Rotation2d currentAngle) {
+ var delta = desiredState.angle.minus(currentAngle);
+ if (Math.abs(delta.getDegrees()) > 90.0) {
+ var flipped = desiredState.angle.rotateBy(Rotation2d.fromDegrees(180.0));
+ return new SwerveModuleState(-desiredState.speedMetersPerSecond, flipped);
+ }
+ return new SwerveModuleState(desiredState.speedMetersPerSecond, desiredState.angle);
+ }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/math/spline/CubicHermiteSpline.java b/wpimath/src/main/java/edu/wpi/first/math/spline/CubicHermiteSpline.java
new file mode 100644
index 0000000..9bbeaf6
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/spline/CubicHermiteSpline.java
@@ -0,0 +1,137 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.spline;
+
+import org.ejml.simple.SimpleMatrix;
+
+public class CubicHermiteSpline extends Spline {
+ private static SimpleMatrix hermiteBasis;
+ private final SimpleMatrix m_coefficients;
+
+ /**
+ * Constructs a cubic hermite spline with the specified control vectors. Each control vector
+ * contains info about the location of the point and its first derivative.
+ *
+ * @param xInitialControlVector The control vector for the initial point in the x dimension.
+ * @param xFinalControlVector The control vector for the final point in the x dimension.
+ * @param yInitialControlVector The control vector for the initial point in the y dimension.
+ * @param yFinalControlVector The control vector for the final point in the y dimension.
+ */
+ @SuppressWarnings("ParameterName")
+ public CubicHermiteSpline(
+ double[] xInitialControlVector,
+ double[] xFinalControlVector,
+ double[] yInitialControlVector,
+ double[] yFinalControlVector) {
+ super(3);
+
+ // Populate the coefficients for the actual spline equations.
+ // Row 0 is x coefficients
+ // Row 1 is y coefficients
+ final var hermite = makeHermiteBasis();
+ final var x = getControlVectorFromArrays(xInitialControlVector, xFinalControlVector);
+ final var y = getControlVectorFromArrays(yInitialControlVector, yFinalControlVector);
+
+ final var xCoeffs = (hermite.mult(x)).transpose();
+ final var yCoeffs = (hermite.mult(y)).transpose();
+
+ m_coefficients = new SimpleMatrix(6, 4);
+
+ for (int i = 0; i < 4; i++) {
+ m_coefficients.set(0, i, xCoeffs.get(0, i));
+ m_coefficients.set(1, i, yCoeffs.get(0, i));
+
+ // Populate Row 2 and Row 3 with the derivatives of the equations above.
+ // Then populate row 4 and 5 with the second derivatives.
+ // Here, we are multiplying by (3 - i) to manually take the derivative. The
+ // power of the term in index 0 is 3, index 1 is 2 and so on. To find the
+ // coefficient of the derivative, we can use the power rule and multiply
+ // the existing coefficient by its power.
+ m_coefficients.set(2, i, m_coefficients.get(0, i) * (3 - i));
+ m_coefficients.set(3, i, m_coefficients.get(1, i) * (3 - i));
+ }
+
+ for (int i = 0; i < 3; i++) {
+ // Here, we are multiplying by (2 - i) to manually take the derivative. The
+ // power of the term in index 0 is 2, index 1 is 1 and so on. To find the
+ // coefficient of the derivative, we can use the power rule and multiply
+ // the existing coefficient by its power.
+ m_coefficients.set(4, i, m_coefficients.get(2, i) * (2 - i));
+ m_coefficients.set(5, i, m_coefficients.get(3, i) * (2 - i));
+ }
+ }
+
+ /**
+ * Returns the coefficients matrix.
+ *
+ * @return The coefficients matrix.
+ */
+ @Override
+ protected SimpleMatrix getCoefficients() {
+ return m_coefficients;
+ }
+
+ /**
+ * Returns the hermite basis matrix for cubic hermite spline interpolation.
+ *
+ * @return The hermite basis matrix for cubic hermite spline interpolation.
+ */
+ private SimpleMatrix makeHermiteBasis() {
+ if (hermiteBasis == null) {
+ // Given P(i), P'(i), P(i+1), P'(i+1), the control vectors, we want to find
+ // the coefficients of the spline P(t) = a3 * t^3 + a2 * t^2 + a1 * t + a0.
+ //
+ // P(i) = P(0) = a0
+ // P'(i) = P'(0) = a1
+ // P(i+1) = P(1) = a3 + a2 + a1 + a0
+ // P'(i+1) = P'(1) = 3 * a3 + 2 * a2 + a1
+ //
+ // [ P(i) ] = [ 0 0 0 1 ][ a3 ]
+ // [ P'(i) ] = [ 0 0 1 0 ][ a2 ]
+ // [ P(i+1) ] = [ 1 1 1 1 ][ a1 ]
+ // [ P'(i+1) ] = [ 3 2 1 0 ][ a0 ]
+ //
+ // To solve for the coefficients, we can invert the 4x4 matrix and move it
+ // to the other side of the equation.
+ //
+ // [ a3 ] = [ 2 1 -2 1 ][ P(i) ]
+ // [ a2 ] = [ -3 -2 3 -1 ][ P'(i) ]
+ // [ a1 ] = [ 0 1 0 0 ][ P(i+1) ]
+ // [ a0 ] = [ 1 0 0 0 ][ P'(i+1) ]
+ hermiteBasis =
+ new SimpleMatrix(
+ 4,
+ 4,
+ true,
+ new double[] {
+ +2.0, +1.0, -2.0, +1.0, -3.0, -2.0, +3.0, -1.0, +0.0, +1.0, +0.0, +0.0, +1.0, +0.0,
+ +0.0, +0.0
+ });
+ }
+ return hermiteBasis;
+ }
+
+ /**
+ * Returns the control vector for each dimension as a matrix from the user-provided arrays in the
+ * constructor.
+ *
+ * @param initialVector The control vector for the initial point.
+ * @param finalVector The control vector for the final point.
+ * @return The control vector matrix for a dimension.
+ */
+ private SimpleMatrix getControlVectorFromArrays(double[] initialVector, double[] finalVector) {
+ if (initialVector.length != 2 || finalVector.length != 2) {
+ throw new IllegalArgumentException("Size of vectors must be 2");
+ }
+ return new SimpleMatrix(
+ 4,
+ 1,
+ true,
+ new double[] {
+ initialVector[0], initialVector[1],
+ finalVector[0], finalVector[1]
+ });
+ }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/math/spline/PoseWithCurvature.java b/wpimath/src/main/java/edu/wpi/first/math/spline/PoseWithCurvature.java
new file mode 100644
index 0000000..8bad7b1
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/spline/PoseWithCurvature.java
@@ -0,0 +1,33 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.spline;
+
+import edu.wpi.first.math.geometry.Pose2d;
+
+ /** Represents a pair of a pose and the curvature at that pose. */
+ @SuppressWarnings("MemberName")
+ public class PoseWithCurvature {
+ /** The pose. */
+ public Pose2d poseMeters;
+
+ /** The curvature, in radians per meter. */
+ public double curvatureRadPerMeter;
+
+ /**
+ * Creates a pose/curvature pair from the given values.
+ *
+ * @param poseMeters The pose.
+ * @param curvatureRadPerMeter The curvature.
+ */
+ public PoseWithCurvature(Pose2d poseMeters, double curvatureRadPerMeter) {
+ this.poseMeters = poseMeters;
+ this.curvatureRadPerMeter = curvatureRadPerMeter;
+ }
+
+ /** Creates a pair holding a default-constructed {@code Pose2d} and zero curvature. */
+ public PoseWithCurvature() {
+ this(new Pose2d(), 0.0);
+ }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/math/spline/QuinticHermiteSpline.java b/wpimath/src/main/java/edu/wpi/first/math/spline/QuinticHermiteSpline.java
new file mode 100644
index 0000000..4017044
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/spline/QuinticHermiteSpline.java
@@ -0,0 +1,145 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.spline;
+
+import org.ejml.simple.SimpleMatrix;
+
+public class QuinticHermiteSpline extends Spline {
+ private static SimpleMatrix hermiteBasis;
+ private final SimpleMatrix m_coefficients;
+
+ /**
+ * Constructs a quintic hermite spline with the specified control vectors. Each control vector
+ * contains into about the location of the point, its first derivative, and its second derivative.
+ *
+ * @param xInitialControlVector The control vector for the initial point in the x dimension.
+ * @param xFinalControlVector The control vector for the final point in the x dimension.
+ * @param yInitialControlVector The control vector for the initial point in the y dimension.
+ * @param yFinalControlVector The control vector for the final point in the y dimension.
+ */
+ @SuppressWarnings("ParameterName")
+ public QuinticHermiteSpline(
+ double[] xInitialControlVector,
+ double[] xFinalControlVector,
+ double[] yInitialControlVector,
+ double[] yFinalControlVector) {
+ super(5);
+
+ // Populate the coefficients for the actual spline equations.
+ // Row 0 is x coefficients
+ // Row 1 is y coefficients
+ final var hermite = makeHermiteBasis();
+ final var x = getControlVectorFromArrays(xInitialControlVector, xFinalControlVector);
+ final var y = getControlVectorFromArrays(yInitialControlVector, yFinalControlVector);
+
+ final var xCoeffs = (hermite.mult(x)).transpose();
+ final var yCoeffs = (hermite.mult(y)).transpose();
+
+ m_coefficients = new SimpleMatrix(6, 6);
+
+ for (int i = 0; i < 6; i++) {
+ m_coefficients.set(0, i, xCoeffs.get(0, i));
+ m_coefficients.set(1, i, yCoeffs.get(0, i));
+ }
+ for (int i = 0; i < 6; i++) {
+ // Populate Row 2 and Row 3 with the derivatives of the equations above.
+ // Here, we are multiplying by (5 - i) to manually take the derivative. The
+ // power of the term in index 0 is 5, index 1 is 4 and so on. To find the
+ // coefficient of the derivative, we can use the power rule and multiply
+ // the existing coefficient by its power.
+ m_coefficients.set(2, i, m_coefficients.get(0, i) * (5 - i));
+ m_coefficients.set(3, i, m_coefficients.get(1, i) * (5 - i));
+ }
+ for (int i = 0; i < 5; i++) {
+ // Then populate row 4 and 5 with the second derivatives.
+ // Here, we are multiplying by (4 - i) to manually take the derivative. The
+ // power of the term in index 0 is 4, index 1 is 3 and so on. To find the
+ // coefficient of the derivative, we can use the power rule and multiply
+ // the existing coefficient by its power.
+ m_coefficients.set(4, i, m_coefficients.get(2, i) * (4 - i));
+ m_coefficients.set(5, i, m_coefficients.get(3, i) * (4 - i));
+ }
+ }
+
+ /**
+ * Returns the coefficients matrix.
+ *
+ * @return The coefficients matrix.
+ */
+ @Override
+ protected SimpleMatrix getCoefficients() {
+ return m_coefficients;
+ }
+
+ /**
+ * Returns the hermite basis matrix for quintic hermite spline interpolation.
+ *
+ * @return The hermite basis matrix for quintic hermite spline interpolation.
+ */
+ private SimpleMatrix makeHermiteBasis() {
+ if (hermiteBasis == null) {
+ // Given P(i), P'(i), P''(i), P(i+1), P'(i+1), P''(i+1), the control
+ // vectors, we want to find the coefficients of the spline
+ // P(t) = a5 * t^5 + a4 * t^4 + a3 * t^3 + a2 * t^2 + a1 * t + a0.
+ //
+ // P(i) = P(0) = a0
+ // P'(i) = P'(0) = a1
+ // P''(i) = P''(0) = 2 * a2
+ // P(i+1) = P(1) = a5 + a4 + a3 + a2 + a1 + a0
+ // P'(i+1) = P'(1) = 5 * a5 + 4 * a4 + 3 * a3 + 2 * a2 + a1
+ // P''(i+1) = P''(1) = 20 * a5 + 12 * a4 + 6 * a3 + 2 * a2
+ //
+ // [ P(i) ] = [ 0 0 0 0 0 1 ][ a5 ]
+ // [ P'(i) ] = [ 0 0 0 0 1 0 ][ a4 ]
+ // [ P''(i) ] = [ 0 0 0 2 0 0 ][ a3 ]
+ // [ P(i+1) ] = [ 1 1 1 1 1 1 ][ a2 ]
+ // [ P'(i+1) ] = [ 5 4 3 2 1 0 ][ a1 ]
+ // [ P''(i+1) ] = [ 20 12 6 2 0 0 ][ a0 ]
+ //
+ // To solve for the coefficients, we can invert the 6x6 matrix and move it
+ // to the other side of the equation.
+ //
+ // [ a5 ] = [ -6.0 -3.0 -0.5 6.0 -3.0 0.5 ][ P(i) ]
+ // [ a4 ] = [ 15.0 8.0 1.5 -15.0 7.0 -1.0 ][ P'(i) ]
+ // [ a3 ] = [ -10.0 -6.0 -1.5 10.0 -4.0 0.5 ][ P''(i) ]
+ // [ a2 ] = [ 0.0 0.0 0.5 0.0 0.0 0.0 ][ P(i+1) ]
+ // [ a1 ] = [ 0.0 1.0 0.0 0.0 0.0 0.0 ][ P'(i+1) ]
+ // [ a0 ] = [ 1.0 0.0 0.0 0.0 0.0 0.0 ][ P''(i+1) ]
+ hermiteBasis =
+ new SimpleMatrix(
+ 6,
+ 6,
+ true,
+ new double[] {
+ -06.0, -03.0, -00.5, +06.0, -03.0, +00.5, +15.0, +08.0, +01.5, -15.0, +07.0, -01.0,
+ -10.0, -06.0, -01.5, +10.0, -04.0, +00.5, +00.0, +00.0, +00.5, +00.0, +00.0, +00.0,
+ +00.0, +01.0, +00.0, +00.0, +00.0, +00.0, +01.0, +00.0, +00.0, +00.0, +00.0, +00.0
+ });
+ }
+ return hermiteBasis;
+ }
+
+ /**
+ * Returns the control vector for each dimension as a matrix from the user-provided arrays in the
+ * constructor.
+ *
+ * @param initialVector The control vector for the initial point.
+ * @param finalVector The control vector for the final point.
+ * @return The control vector matrix for a dimension.
+ */
+ private SimpleMatrix getControlVectorFromArrays(double[] initialVector, double[] finalVector) {
+ if (initialVector.length != 3 || finalVector.length != 3) {
+ throw new IllegalArgumentException("Size of vectors must be 3");
+ }
+ return new SimpleMatrix(
+ 6,
+ 1,
+ true,
+ new double[] {
+ initialVector[0], initialVector[1], initialVector[2],
+ finalVector[0], finalVector[1], finalVector[2]
+ });
+ }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/spline/Spline.java b/wpimath/src/main/java/edu/wpi/first/math/spline/Spline.java
similarity index 70%
rename from wpimath/src/main/java/edu/wpi/first/wpilibj/spline/Spline.java
rename to wpimath/src/main/java/edu/wpi/first/math/spline/Spline.java
index 57c388f..5451eea 100644
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/spline/Spline.java
+++ b/wpimath/src/main/java/edu/wpi/first/math/spline/Spline.java
@@ -1,23 +1,15 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
-package edu.wpi.first.wpilibj.spline;
+package edu.wpi.first.math.spline;
+import edu.wpi.first.math.geometry.Pose2d;
+import edu.wpi.first.math.geometry.Rotation2d;
import java.util.Arrays;
-
import org.ejml.simple.SimpleMatrix;
-import edu.wpi.first.wpilibj.geometry.Pose2d;
-import edu.wpi.first.wpilibj.geometry.Rotation2d;
-
-/**
- * Represents a two-dimensional parametric spline that interpolates between two
- * points.
- */
+/** Represents a two-dimensional parametric spline that interpolates between two points. */
public abstract class Spline {
private final int m_degree;
@@ -82,20 +74,16 @@
}
// Find the curvature.
- final double curvature =
- (dx * ddy - ddx * dy) / ((dx * dx + dy * dy) * Math.hypot(dx, dy));
+ final double curvature = (dx * ddy - ddx * dy) / ((dx * dx + dy * dy) * Math.hypot(dx, dy));
- return new PoseWithCurvature(
- new Pose2d(x, y, new Rotation2d(dx, dy)),
- curvature
- );
+ return new PoseWithCurvature(new Pose2d(x, y, new Rotation2d(dx, dy)), curvature);
}
/**
* Represents a control vector for a spline.
*
- * <p>Each element in each array represents the value of the derivative at the index. For
- * example, the value of x[2] is the second derivative in the x dimension.
+ * <p>Each element in each array represents the value of the derivative at the index. For example,
+ * the value of x[2] is the second derivative in the x dimension.
*/
@SuppressWarnings("MemberName")
public static class ControlVector {
@@ -104,6 +92,7 @@
/**
* Instantiates a control vector.
+ *
* @param x The x dimension of the control vector.
* @param y The y dimension of the control vector.
*/
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/spline/SplineHelper.java b/wpimath/src/main/java/edu/wpi/first/math/spline/SplineHelper.java
similarity index 63%
rename from wpimath/src/main/java/edu/wpi/first/wpilibj/spline/SplineHelper.java
rename to wpimath/src/main/java/edu/wpi/first/math/spline/SplineHelper.java
index a2bf9dd..e5c67f8 100644
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/spline/SplineHelper.java
+++ b/wpimath/src/main/java/edu/wpi/first/math/spline/SplineHelper.java
@@ -1,37 +1,28 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
-package edu.wpi.first.wpilibj.spline;
+package edu.wpi.first.math.spline;
+import edu.wpi.first.math.geometry.Pose2d;
+import edu.wpi.first.math.geometry.Translation2d;
import java.util.Arrays;
import java.util.List;
-import edu.wpi.first.wpilibj.geometry.Pose2d;
-import edu.wpi.first.wpilibj.geometry.Translation2d;
-
public final class SplineHelper {
- /**
- * Private constructor because this is a utility class.
- */
- private SplineHelper() {
- }
+ /** Private constructor because this is a utility class. */
+ private SplineHelper() {}
/**
- * Returns 2 cubic control vectors from a set of exterior waypoints and
- * interior translations.
+ * Returns 2 cubic control vectors from a set of exterior waypoints and interior translations.
*
- * @param start The starting pose.
+ * @param start The starting pose.
* @param interiorWaypoints The interior waypoints.
- * @param end The ending pose.
+ * @param end The ending pose.
* @return 2 cubic control vectors.
*/
public static Spline.ControlVector[] getCubicControlVectorsFromWaypoints(
- Pose2d start, Translation2d[] interiorWaypoints, Pose2d end
- ) {
+ Pose2d start, Translation2d[] interiorWaypoints, Pose2d end) {
// Generate control vectors from poses.
Spline.ControlVector initialCV;
Spline.ControlVector endCV;
@@ -44,11 +35,11 @@
} else {
double scalar = start.getTranslation().getDistance(interiorWaypoints[0]) * 1.2;
initialCV = getCubicControlVector(scalar, start);
- scalar = end.getTranslation().getDistance(interiorWaypoints[interiorWaypoints.length - 1])
- * 1.2;
+ scalar =
+ end.getTranslation().getDistance(interiorWaypoints[interiorWaypoints.length - 1]) * 1.2;
endCV = getCubicControlVector(scalar, end);
}
- return new Spline.ControlVector[]{initialCV, endCV};
+ return new Spline.ControlVector[] {initialCV, endCV};
}
/**
@@ -57,7 +48,6 @@
* @param waypoints The waypoints
* @return List of splines.
*/
- @SuppressWarnings("PMD.AvoidInstantiatingObjectsInLoops")
public static QuinticHermiteSpline[] getQuinticSplinesFromWaypoints(List<Pose2d> waypoints) {
QuinticHermiteSpline[] splines = new QuinticHermiteSpline[waypoints.size() - 1];
for (int i = 0; i < waypoints.size() - 1; ++i) {
@@ -70,30 +60,27 @@
var controlVecA = getQuinticControlVector(scalar, p0);
var controlVecB = getQuinticControlVector(scalar, p1);
- splines[i]
- = new QuinticHermiteSpline(controlVecA.x, controlVecB.x, controlVecA.y, controlVecB.y);
+ splines[i] =
+ new QuinticHermiteSpline(controlVecA.x, controlVecB.x, controlVecA.y, controlVecB.y);
}
return splines;
}
/**
- * Returns a set of cubic splines corresponding to the provided control vectors. The
- * user is free to set the direction of the start and end point. The
- * directions for the middle waypoints are determined automatically to ensure
- * continuous curvature throughout the path.
+ * Returns a set of cubic splines corresponding to the provided control vectors. The user is free
+ * to set the direction of the start and end point. The directions for the middle waypoints are
+ * determined automatically to ensure continuous curvature throughout the path.
*
- * @param start The starting control vector.
- * @param waypoints The middle waypoints. This can be left blank if you only
- * wish to create a path with two waypoints.
- * @param end The ending control vector.
- * @return A vector of cubic hermite splines that interpolate through the
- * provided waypoints and control vectors.
+ * @param start The starting control vector.
+ * @param waypoints The middle waypoints. This can be left blank if you only wish to create a path
+ * with two waypoints.
+ * @param end The ending control vector.
+ * @return A vector of cubic hermite splines that interpolate through the provided waypoints and
+ * control vectors.
*/
- @SuppressWarnings({"LocalVariableName", "PMD.ExcessiveMethodLength",
- "PMD.AvoidInstantiatingObjectsInLoops"})
+ @SuppressWarnings("LocalVariableName")
public static CubicHermiteSpline[] getCubicSplinesFromControlVectors(
Spline.ControlVector start, Translation2d[] waypoints, Spline.ControlVector end) {
-
CubicHermiteSpline[] splines = new CubicHermiteSpline[waypoints.length + 1];
double[] xInitial = start.x;
@@ -154,10 +141,12 @@
}
}
- dx[dx.length - 1] = 3 * (newWaypts[newWaypts.length - 1].getX()
- - newWaypts[newWaypts.length - 3].getX()) - xFinal[1];
- dy[dy.length - 1] = 3 * (newWaypts[newWaypts.length - 1].getY()
- - newWaypts[newWaypts.length - 3].getY()) - yFinal[1];
+ dx[dx.length - 1] =
+ 3 * (newWaypts[newWaypts.length - 1].getX() - newWaypts[newWaypts.length - 3].getX())
+ - xFinal[1];
+ dy[dy.length - 1] =
+ 3 * (newWaypts[newWaypts.length - 1].getY() - newWaypts[newWaypts.length - 3].getY())
+ - yFinal[1];
// Compute solution to tridiagonal system
thomasAlgorithm(a, b, c, dx, fx);
@@ -174,47 +163,46 @@
newFy[newFy.length - 1] = yFinal[1];
for (int i = 0; i < newFx.length - 1; i++) {
- splines[i] = new CubicHermiteSpline(
- new double[]{newWaypts[i].getX(), newFx[i]},
- new double[]{newWaypts[i + 1].getX(), newFx[i + 1]},
- new double[]{newWaypts[i].getY(), newFy[i]},
- new double[]{newWaypts[i + 1].getY(), newFy[i + 1]}
- );
+ splines[i] =
+ new CubicHermiteSpline(
+ new double[] {newWaypts[i].getX(), newFx[i]},
+ new double[] {newWaypts[i + 1].getX(), newFx[i + 1]},
+ new double[] {newWaypts[i].getY(), newFy[i]},
+ new double[] {newWaypts[i + 1].getY(), newFy[i + 1]});
}
} else if (waypoints.length == 1) {
- final var xDeriv = (3 * (xFinal[0]
- - xInitial[0])
- - xFinal[1] - xInitial[1])
- / 4.0;
- final var yDeriv = (3 * (yFinal[0]
- - yInitial[0])
- - yFinal[1] - yInitial[1])
- / 4.0;
+ final var xDeriv = (3 * (xFinal[0] - xInitial[0]) - xFinal[1] - xInitial[1]) / 4.0;
+ final var yDeriv = (3 * (yFinal[0] - yInitial[0]) - yFinal[1] - yInitial[1]) / 4.0;
double[] midXControlVector = {waypoints[0].getX(), xDeriv};
double[] midYControlVector = {waypoints[0].getY(), yDeriv};
- splines[0] = new CubicHermiteSpline(xInitial, midXControlVector,
- yInitial, midYControlVector);
- splines[1] = new CubicHermiteSpline(midXControlVector, xFinal,
- midYControlVector, yFinal);
+ splines[0] =
+ new CubicHermiteSpline(
+ xInitial, midXControlVector,
+ yInitial, midYControlVector);
+ splines[1] =
+ new CubicHermiteSpline(
+ midXControlVector, xFinal,
+ midYControlVector, yFinal);
} else {
- splines[0] = new CubicHermiteSpline(xInitial, xFinal,
- yInitial, yFinal);
+ splines[0] =
+ new CubicHermiteSpline(
+ xInitial, xFinal,
+ yInitial, yFinal);
}
return splines;
}
/**
- * Returns a set of quintic splines corresponding to the provided control vectors.
- * The user is free to set the direction of all control vectors. Continuous
- * curvature is guaranteed throughout the path.
+ * Returns a set of quintic splines corresponding to the provided control vectors. The user is
+ * free to set the direction of all control vectors. Continuous curvature is guaranteed throughout
+ * the path.
*
* @param controlVectors The control vectors.
- * @return A vector of quintic hermite splines that interpolate through the
- * provided waypoints.
+ * @return A vector of quintic hermite splines that interpolate through the provided waypoints.
*/
- @SuppressWarnings({"LocalVariableName", "PMD.AvoidInstantiatingObjectsInLoops"})
+ @SuppressWarnings("LocalVariableName")
public static QuinticHermiteSpline[] getQuinticSplinesFromControlVectors(
Spline.ControlVector[] controlVectors) {
QuinticHermiteSpline[] splines = new QuinticHermiteSpline[controlVectors.length - 1];
@@ -223,8 +211,10 @@
var xFinal = controlVectors[i + 1].x;
var yInitial = controlVectors[i].y;
var yFinal = controlVectors[i + 1].y;
- splines[i] = new QuinticHermiteSpline(xInitial, xFinal,
- yInitial, yFinal);
+ splines[i] =
+ new QuinticHermiteSpline(
+ xInitial, xFinal,
+ yInitial, yFinal);
}
return splines;
}
@@ -232,15 +222,15 @@
/**
* Thomas algorithm for solving tridiagonal systems Af = d.
*
- * @param a the values of A above the diagonal
- * @param b the values of A on the diagonal
- * @param c the values of A below the diagonal
- * @param d the vector on the rhs
+ * @param a the values of A above the diagonal
+ * @param b the values of A on the diagonal
+ * @param c the values of A below the diagonal
+ * @param d the vector on the rhs
* @param solutionVector the unknown (solution) vector, modified in-place
*/
@SuppressWarnings({"ParameterName", "LocalVariableName"})
- private static void thomasAlgorithm(double[] a, double[] b,
- double[] c, double[] d, double[] solutionVector) {
+ private static void thomasAlgorithm(
+ double[] a, double[] b, double[] c, double[] d, double[] solutionVector) {
int N = d.length;
double[] cStar = new double[N];
@@ -266,15 +256,13 @@
private static Spline.ControlVector getCubicControlVector(double scalar, Pose2d point) {
return new Spline.ControlVector(
- new double[]{point.getX(), scalar * point.getRotation().getCos()},
- new double[]{point.getY(), scalar * point.getRotation().getSin()}
- );
+ new double[] {point.getX(), scalar * point.getRotation().getCos()},
+ new double[] {point.getY(), scalar * point.getRotation().getSin()});
}
private static Spline.ControlVector getQuinticControlVector(double scalar, Pose2d point) {
return new Spline.ControlVector(
- new double[]{point.getX(), scalar * point.getRotation().getCos(), 0.0},
- new double[]{point.getY(), scalar * point.getRotation().getSin(), 0.0}
- );
+ new double[] {point.getX(), scalar * point.getRotation().getCos(), 0.0},
+ new double[] {point.getY(), scalar * point.getRotation().getSin(), 0.0});
}
}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/spline/SplineParameterizer.java b/wpimath/src/main/java/edu/wpi/first/math/spline/SplineParameterizer.java
similarity index 67%
rename from wpimath/src/main/java/edu/wpi/first/wpilibj/spline/SplineParameterizer.java
rename to wpimath/src/main/java/edu/wpi/first/math/spline/SplineParameterizer.java
index 1585214..88afc6d 100644
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/spline/SplineParameterizer.java
+++ b/wpimath/src/main/java/edu/wpi/first/math/spline/SplineParameterizer.java
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
/*
* MIT License
@@ -29,25 +26,23 @@
* SOFTWARE.
*/
-package edu.wpi.first.wpilibj.spline;
+package edu.wpi.first.math.spline;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.List;
-/**
- * Class used to parameterize a spline by its arc length.
- */
+/** Class used to parameterize a spline by its arc length. */
public final class SplineParameterizer {
private static final double kMaxDx = 0.127;
private static final double kMaxDy = 0.00127;
private static final double kMaxDtheta = 0.0872;
/**
- * A malformed spline does not actually explode the LIFO stack size. Instead, the stack size
- * stays at a relatively small number (e.g. 30) and never decreases. Because of this, we must
- * count iterations. Even long, complex paths don't usually go over 300 iterations, so hitting
- * this maximum should definitely indicate something has gone wrong.
+ * A malformed spline does not actually explode the LIFO stack size. Instead, the stack size stays
+ * at a relatively small number (e.g. 30) and never decreases. Because of this, we must count
+ * iterations. Even long, complex paths don't usually go over 300 iterations, so hitting this
+ * maximum should definitely indicate something has gone wrong.
*/
private static final int kMaxIterations = 5000;
@@ -74,37 +69,33 @@
}
}
- /**
- * Private constructor because this is a utility class.
- */
- private SplineParameterizer() {
- }
+ /** Private constructor because this is a utility class. */
+ private SplineParameterizer() {}
/**
- * Parameterizes the spline. This method breaks up the spline into various
- * arcs until their dx, dy, and dtheta are within specific tolerances.
+ * Parameterizes the spline. This method breaks up the spline into various arcs until their dx,
+ * dy, and dtheta are within specific tolerances.
*
* @param spline The spline to parameterize.
* @return A list of poses and curvatures that represents various points on the spline.
* @throws MalformedSplineException When the spline is malformed (e.g. has close adjacent points
- * with approximately opposing headings)
+ * with approximately opposing headings)
*/
public static List<PoseWithCurvature> parameterize(Spline spline) {
return parameterize(spline, 0.0, 1.0);
}
/**
- * Parameterizes the spline. This method breaks up the spline into various
- * arcs until their dx, dy, and dtheta are within specific tolerances.
+ * Parameterizes the spline. This method breaks up the spline into various arcs until their dx,
+ * dy, and dtheta are within specific tolerances.
*
* @param spline The spline to parameterize.
- * @param t0 Starting internal spline parameter. It is recommended to use 0.0.
- * @param t1 Ending internal spline parameter. It is recommended to use 1.0.
- * @return A list of poses and curvatures that represents various points on the spline.
+ * @param t0 Starting internal spline parameter. It is recommended to use 0.0.
+ * @param t1 Ending internal spline parameter. It is recommended to use 1.0.
+ * @return A list of poses and curvatures that represents various points on the spline.
* @throws MalformedSplineException When the spline is malformed (e.g. has close adjacent points
- * with approximately opposing headings)
+ * with approximately opposing headings)
*/
- @SuppressWarnings("PMD.AvoidInstantiatingObjectsInLoops")
public static List<PoseWithCurvature> parameterize(Spline spline, double t0, double t1) {
var splinePoints = new ArrayList<PoseWithCurvature>();
@@ -127,11 +118,9 @@
end = spline.getPoint(current.t1);
final var twist = start.poseMeters.log(end.poseMeters);
- if (
- Math.abs(twist.dy) > kMaxDy
+ if (Math.abs(twist.dy) > kMaxDy
|| Math.abs(twist.dx) > kMaxDx
- || Math.abs(twist.dtheta) > kMaxDtheta
- ) {
+ || Math.abs(twist.dtheta) > kMaxDtheta) {
stack.addFirst(new StackContents((current.t0 + current.t1) / 2, current.t1));
stack.addFirst(new StackContents(current.t0, (current.t0 + current.t1) / 2));
} else {
@@ -141,10 +130,9 @@
iterations++;
if (iterations >= kMaxIterations) {
throw new MalformedSplineException(
- "Could not parameterize a malformed spline. "
- + "This means that you probably had two or more adjacent waypoints that were very close "
- + "together with headings in opposing directions."
- );
+ "Could not parameterize a malformed spline. This means that you probably had two or "
+ + "more adjacent waypoints that were very close together with headings in "
+ + "opposing directions.");
}
}
diff --git a/wpimath/src/main/java/edu/wpi/first/math/system/Discretization.java b/wpimath/src/main/java/edu/wpi/first/math/system/Discretization.java
new file mode 100644
index 0000000..ffbd99e
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/system/Discretization.java
@@ -0,0 +1,181 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.system;
+
+import edu.wpi.first.math.Matrix;
+import edu.wpi.first.math.Num;
+import edu.wpi.first.math.Pair;
+import org.ejml.simple.SimpleMatrix;
+
+/** Utilities for converting continuous-time state-space matrices to discrete time. */
+@SuppressWarnings({"ParameterName", "MethodTypeParameterName"})
+public final class Discretization {
+  private Discretization() {
+    // Utility class
+  }
+
+  /**
+   * Discretizes the given continuous A matrix.
+   *
+   * <p>The result is the matrix exponential e^(A·dt).
+   *
+   * @param <States> Num representing the number of states.
+   * @param contA Continuous system matrix.
+   * @param dtSeconds Discretization timestep.
+   * @return the discrete matrix system.
+   */
+  public static <States extends Num> Matrix<States, States> discretizeA(
+      Matrix<States, States> contA, double dtSeconds) {
+    return contA.times(dtSeconds).exp();
+  }
+
+  /**
+   * Discretizes the given continuous A and B matrices.
+   *
+   * <p>Builds the block matrix M = [[A·dt, B·dt], [0, 0]] and takes its matrix exponential; the
+   * top block row of e^M holds the discrete A (left block) and the discrete B (right block).
+   *
+   * @param <States> Num representing the states of the system.
+   * @param <Inputs> Num representing the inputs to the system.
+   * @param contA Continuous system matrix.
+   * @param contB Continuous input matrix.
+   * @param dtSeconds Discretization timestep.
+   * @return a Pair representing discA and discB.
+   */
+  @SuppressWarnings("LocalVariableName")
+  public static <States extends Num, Inputs extends Num>
+      Pair<Matrix<States, States>, Matrix<States, Inputs>> discretizeAB(
+          Matrix<States, States> contA, Matrix<States, Inputs> contB, double dtSeconds) {
+    var scaledA = contA.times(dtSeconds);
+    var scaledB = contB.times(dtSeconds);
+
+    int states = contA.getNumRows();
+    int inputs = contB.getNumCols();
+
+    // M = [[A·dt, B·dt], [0, 0]]; the bottom block row is left at zero.
+    var M = new Matrix<>(new SimpleMatrix(states + inputs, states + inputs));
+    M.assignBlock(0, 0, scaledA);
+    M.assignBlock(0, states, scaledB);
+    var phi = M.exp();
+
+    var discA = new Matrix<States, States>(new SimpleMatrix(states, states));
+    var discB = new Matrix<States, Inputs>(new SimpleMatrix(states, inputs));
+
+    // discA = phi[0:states, 0:states], discB = phi[0:states, states:states + inputs]
+    discA.extractFrom(0, 0, phi);
+    discB.extractFrom(0, states, phi);
+
+    return new Pair<>(discA, discB);
+  }
+
+  /**
+   * Discretizes the given continuous A and Q matrices.
+   *
+   * <p>Uses Van Loan's method: the block matrix M = [[-A, Q], [0, Aᵀ]] is scaled by the timestep
+   * and exponentiated, and the discrete A and Q are recovered from the right-hand blocks of e^M.
+   *
+   * @param <States> Num representing the number of states.
+   * @param contA Continuous system matrix.
+   * @param contQ Continuous process noise covariance matrix.
+   * @param dtSeconds Discretization timestep.
+   * @return a pair representing the discrete system matrix and process noise covariance matrix.
+   */
+  @SuppressWarnings("LocalVariableName")
+  public static <States extends Num>
+      Pair<Matrix<States, States>, Matrix<States, States>> discretizeAQ(
+          Matrix<States, States> contA, Matrix<States, States> contQ, double dtSeconds) {
+    int states = contA.getNumRows();
+
+    // Make continuous Q symmetric if it isn't already
+    Matrix<States, States> Q = contQ.plus(contQ.transpose()).div(2.0);
+
+    // Set up the matrix M = [[-A, Q], [0, A.T]]
+    final var M = new Matrix<>(new SimpleMatrix(2 * states, 2 * states));
+    M.assignBlock(0, 0, contA.times(-1.0));
+    M.assignBlock(0, states, Q);
+    M.assignBlock(states, 0, new Matrix<>(new SimpleMatrix(states, states)));
+    M.assignBlock(states, states, contA.transpose());
+
+    final var phi = M.times(dtSeconds).exp();
+
+    // Phi12 = phi[0:States, States:2*States]
+    // Phi22 = phi[States:2*States, States:2*States]
+    final Matrix<States, States> phi12 = phi.block(states, states, 0, states);
+    final Matrix<States, States> phi22 = phi.block(states, states, states, states);
+
+    final var discA = phi22.transpose();
+
+    Q = discA.times(phi12);
+
+    // Make discrete Q symmetric if it isn't already
+    final var discQ = Q.plus(Q.transpose()).div(2.0);
+
+    return new Pair<>(discA, discQ);
+  }
+
+  /**
+   * Discretizes the given continuous A and Q matrices.
+   *
+   * <p>Rather than solving a 2N x 2N matrix exponential like in discretizeAQ() (which is
+   * expensive), we take advantage of the structure of the block matrix of A and Q:
+   *
+   * <ol>
+   *   <li>The exponential of A*t, which is only N x N, is relatively cheap.
+   *   <li>The upper-right quarter of the 2N x 2N matrix can be approximated using a Taylor series
+   *       to several terms and still be substantially cheaper than taking the big exponential.
+   * </ol>
+   *
+   * @param <States> Num representing the number of states.
+   * @param contA Continuous system matrix.
+   * @param contQ Continuous process noise covariance matrix.
+   * @param dtSeconds Discretization timestep.
+   * @return a pair representing the discrete system matrix and process noise covariance matrix.
+   */
+  @SuppressWarnings("LocalVariableName")
+  public static <States extends Num>
+      Pair<Matrix<States, States>, Matrix<States, States>> discretizeAQTaylor(
+          Matrix<States, States> contA, Matrix<States, States> contQ, double dtSeconds) {
+    // Make continuous Q symmetric if it isn't already
+    Matrix<States, States> Q = contQ.plus(contQ.transpose()).div(2.0);
+
+    Matrix<States, States> lastTerm = Q.copy();
+    double lastCoeff = dtSeconds;
+
+    // Aᵀⁿ
+    Matrix<States, States> Atn = contA.transpose();
+    Matrix<States, States> phi12 = lastTerm.times(lastCoeff);
+
+    // Add the Taylor series terms of order 2 through 5; 5th order should be enough precision
+    for (int i = 2; i < 6; ++i) {
+      lastTerm = contA.times(-1).times(lastTerm).plus(Q.times(Atn));
+      lastCoeff *= dtSeconds / ((double) i);
+
+      phi12 = phi12.plus(lastTerm.times(lastCoeff));
+
+      Atn = Atn.times(contA.transpose());
+    }
+
+    var discA = discretizeA(contA, dtSeconds);
+    Q = discA.times(phi12);
+
+    // Make Q symmetric if it isn't already
+    var discQ = Q.plus(Q.transpose()).div(2.0);
+
+    return new Pair<>(discA, discQ);
+  }
+
+  /**
+   * Returns a discretized version of the provided continuous measurement noise covariance matrix.
+   * Note that dt=0.0 divides R by zero.
+   *
+   * @param <O> Num representing the number of outputs.
+   * @param R Continuous measurement noise covariance matrix.
+   * @param dtSeconds Discretization timestep.
+   * @return Discretized version of the provided continuous measurement noise covariance matrix.
+   */
+  public static <O extends Num> Matrix<O, O> discretizeR(Matrix<O, O> R, double dtSeconds) {
+    return R.div(dtSeconds);
+  }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/math/system/LinearSystem.java b/wpimath/src/main/java/edu/wpi/first/math/system/LinearSystem.java
new file mode 100644
index 0000000..9e35cfd
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/system/LinearSystem.java
@@ -0,0 +1,201 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.system;
+
+import edu.wpi.first.math.Matrix;
+import edu.wpi.first.math.Num;
+import edu.wpi.first.math.numbers.N1;
+
+/**
+ * A linear state-space model with continuous system and input matrices A and B, output matrix C,
+ * and feedthrough matrix D.
+ *
+ * @param <States> Num representing the number of states.
+ * @param <Inputs> Num representing the number of inputs.
+ * @param <Outputs> Num representing the number of outputs.
+ */
+@SuppressWarnings("ClassTypeParameterName")
+public class LinearSystem<States extends Num, Inputs extends Num, Outputs extends Num> {
+  /** Continuous system matrix. */
+  @SuppressWarnings("MemberName")
+  private final Matrix<States, States> m_A;
+
+  /** Continuous input matrix. */
+  @SuppressWarnings("MemberName")
+  private final Matrix<States, Inputs> m_B;
+
+  /** Output matrix. */
+  @SuppressWarnings("MemberName")
+  private final Matrix<Outputs, States> m_C;
+
+  /** Feedthrough matrix. */
+  @SuppressWarnings("MemberName")
+  private final Matrix<Outputs, Inputs> m_D;
+
+  /**
+   * Construct a new LinearSystem from the four system matrices.
+   *
+   * @param a The system matrix A.
+   * @param b The input matrix B.
+   * @param c The output matrix C.
+   * @param d The feedthrough matrix D.
+   * @throws IllegalArgumentException if any matrix element isn't finite.
+   */
+  @SuppressWarnings("ParameterName")
+  public LinearSystem(
+      Matrix<States, States> a,
+      Matrix<States, Inputs> b,
+      Matrix<Outputs, States> c,
+      Matrix<Outputs, Inputs> d) {
+    requireFinite(a, "A");
+    requireFinite(b, "B");
+    requireFinite(c, "C");
+    requireFinite(d, "D");
+
+    this.m_A = a;
+    this.m_B = b;
+    this.m_C = c;
+    this.m_D = d;
+  }
+
+  /**
+   * Throws if any element of the given matrix is NaN or infinite.
+   *
+   * @param matrix The matrix to check.
+   * @param name The matrix's name, used in the exception message.
+   * @throws IllegalArgumentException if any matrix element isn't finite.
+   */
+  private static void requireFinite(Matrix<?, ?> matrix, String name) {
+    for (int row = 0; row < matrix.getNumRows(); ++row) {
+      for (int col = 0; col < matrix.getNumCols(); ++col) {
+        if (!Double.isFinite(matrix.get(row, col))) {
+          throw new IllegalArgumentException(
+              "Elements of "
+                  + name
+                  + " aren't finite. This is usually due to model implementation errors.");
+        }
+      }
+    }
+  }
+
+  /**
+   * Returns the system matrix A.
+   *
+   * @return the system matrix A.
+   */
+  public Matrix<States, States> getA() {
+    return m_A;
+  }
+
+  /**
+   * Returns an element of the system matrix A.
+   *
+   * @param row Row of A.
+   * @param col Column of A.
+   * @return the element of the system matrix A at (row, col).
+   */
+  public double getA(int row, int col) {
+    return m_A.get(row, col);
+  }
+
+  /**
+   * Returns the input matrix B.
+   *
+   * @return the input matrix B.
+   */
+  public Matrix<States, Inputs> getB() {
+    return m_B;
+  }
+
+  /**
+   * Returns an element of the input matrix B.
+   *
+   * @param row Row of B.
+   * @param col Column of B.
+   * @return The value of the input matrix B at (row, col).
+   */
+  public double getB(int row, int col) {
+    return m_B.get(row, col);
+  }
+
+  /**
+   * Returns the output matrix C.
+   *
+   * @return Output matrix C.
+   */
+  public Matrix<Outputs, States> getC() {
+    return m_C;
+  }
+
+  /**
+   * Returns an element of the output matrix C.
+   *
+   * @param row Row of C.
+   * @param col Column of C.
+   * @return the double value of C at the given position.
+   */
+  public double getC(int row, int col) {
+    return m_C.get(row, col);
+  }
+
+  /**
+   * Returns the feedthrough matrix D.
+   *
+   * @return the feedthrough matrix D.
+   */
+  public Matrix<Outputs, Inputs> getD() {
+    return m_D;
+  }
+
+  /**
+   * Returns an element of the feedthrough matrix D.
+   *
+   * @param row Row of D.
+   * @param col Column of D.
+   * @return The element of the feedthrough matrix D at (row, col).
+   */
+  public double getD(int row, int col) {
+    return m_D.get(row, col);
+  }
+
+  /**
+   * Computes the new x given the old x and the control input.
+   *
+   * <p>This is used by state observers directly to run updates based on state estimate.
+   *
+   * @param x The current state.
+   * @param clampedU The control input.
+   * @param dtSeconds Timestep for model update.
+   * @return the updated x.
+   */
+  @SuppressWarnings("ParameterName")
+  public Matrix<States, N1> calculateX(
+      Matrix<States, N1> x, Matrix<Inputs, N1> clampedU, double dtSeconds) {
+    var discABpair = Discretization.discretizeAB(m_A, m_B, dtSeconds);
+
+    return (discABpair.getFirst().times(x)).plus(discABpair.getSecond().times(clampedU));
+  }
+
+  /**
+   * Computes the new y given the control input.
+   *
+   * <p>This is used by state observers directly to run updates based on state estimate.
+   *
+   * @param x The current state.
+   * @param clampedU The control input.
+   * @return the updated output matrix Y.
+   */
+  @SuppressWarnings("ParameterName")
+  public Matrix<Outputs, N1> calculateY(Matrix<States, N1> x, Matrix<Inputs, N1> clampedU) {
+    return m_C.times(x).plus(m_D.times(clampedU));
+  }
+
+  @Override
+  public String toString() {
+    return String.format(
+        "Linear System: A\n%s\n\nB:\n%s\n\nC:\n%s\n\nD:\n%s\n",
+        m_A.toString(), m_B.toString(), m_C.toString(), m_D.toString());
+  }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/math/system/LinearSystemLoop.java b/wpimath/src/main/java/edu/wpi/first/math/system/LinearSystemLoop.java
new file mode 100644
index 0000000..8f3da6a
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/system/LinearSystemLoop.java
@@ -0,0 +1,359 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.system;
+
+import edu.wpi.first.math.Matrix;
+import edu.wpi.first.math.Num;
+import edu.wpi.first.math.StateSpaceUtil;
+import edu.wpi.first.math.controller.LinearPlantInversionFeedforward;
+import edu.wpi.first.math.controller.LinearQuadraticRegulator;
+import edu.wpi.first.math.estimator.KalmanFilter;
+import edu.wpi.first.math.numbers.N1;
+import java.util.function.Function;
+import org.ejml.MatrixDimensionException;
+import org.ejml.simple.SimpleMatrix;
+
+/**
+ * Combines a controller, feedforward, and observer for controlling a mechanism with full state
+ * feedback.
+ *
+ * <p>For everything in this file, "inputs" and "outputs" are defined from the perspective of the
+ * plant. This means U is an input and Y is an output, because you give the plant U (powers) and it
+ * gives you back a Y (sensor values). This is the opposite of what they mean from the perspective
+ * of the controller (U is an output because that's what goes to the motors and Y is an input
+ * because that's what comes back from the sensors).
+ *
+ * <p>For more on the underlying math, read
+ * https://file.tavsys.net/control/controls-engineering-in-frc.pdf.
+ *
+ * @param <States> Num representing the number of states.
+ * @param <Inputs> Num representing the number of inputs.
+ * @param <Outputs> Num representing the number of outputs.
+ */
+@SuppressWarnings("ClassTypeParameterName")
+public class LinearSystemLoop<States extends Num, Inputs extends Num, Outputs extends Num> {
+  private final LinearQuadraticRegulator<States, Inputs, Outputs> m_controller;
+  private final LinearPlantInversionFeedforward<States, Inputs, Outputs> m_feedforward;
+  private final KalmanFilter<States, Inputs, Outputs> m_observer;
+  private Matrix<States, N1> m_nextR;
+  private Function<Matrix<Inputs, N1>, Matrix<Inputs, N1>> m_clampFunction;
+
+  /**
+   * Constructs a state-space loop with the given plant, controller, and observer. By default, the
+   * initial reference is all zeros. Users should call reset with the initial system state before
+   * enabling the loop. This constructor assumes that the input(s) to this system are voltage.
+   *
+   * @param plant State-space plant.
+   * @param controller State-space controller.
+   * @param observer State-space observer.
+   * @param maxVoltageVolts The maximum voltage that can be applied. Commonly 12.
+   * @param dtSeconds The nominal timestep.
+   */
+  public LinearSystemLoop(
+      LinearSystem<States, Inputs, Outputs> plant,
+      LinearQuadraticRegulator<States, Inputs, Outputs> controller,
+      KalmanFilter<States, Inputs, Outputs> observer,
+      double maxVoltageVolts,
+      double dtSeconds) {
+    this(
+        controller,
+        new LinearPlantInversionFeedforward<>(plant, dtSeconds),
+        observer,
+        u -> StateSpaceUtil.normalizeInputVector(u, maxVoltageVolts));
+  }
+
+  /**
+   * Constructs a state-space loop with the given plant, controller, and observer. By default, the
+   * initial reference is all zeros. Users should call reset with the initial system state before
+   * enabling the loop.
+   *
+   * @param plant State-space plant.
+   * @param controller State-space controller.
+   * @param observer State-space observer.
+   * @param clampFunction The function used to clamp the input U.
+   * @param dtSeconds The nominal timestep.
+   */
+  public LinearSystemLoop(
+      LinearSystem<States, Inputs, Outputs> plant,
+      LinearQuadraticRegulator<States, Inputs, Outputs> controller,
+      KalmanFilter<States, Inputs, Outputs> observer,
+      Function<Matrix<Inputs, N1>, Matrix<Inputs, N1>> clampFunction,
+      double dtSeconds) {
+    this(
+        controller,
+        new LinearPlantInversionFeedforward<>(plant, dtSeconds),
+        observer,
+        clampFunction);
+  }
+
+  /**
+   * Constructs a state-space loop with the given controller, feedforward and observer. By default,
+   * the initial reference is all zeros. Users should call reset with the initial system state
+   * before enabling the loop.
+   *
+   * @param controller State-space controller.
+   * @param feedforward Plant inversion feedforward.
+   * @param observer State-space observer.
+   * @param maxVoltageVolts The maximum voltage that can be applied. Assumes that the inputs are
+   *     voltages.
+   */
+  public LinearSystemLoop(
+      LinearQuadraticRegulator<States, Inputs, Outputs> controller,
+      LinearPlantInversionFeedforward<States, Inputs, Outputs> feedforward,
+      KalmanFilter<States, Inputs, Outputs> observer,
+      double maxVoltageVolts) {
+    this(
+        controller,
+        feedforward,
+        observer,
+        u -> StateSpaceUtil.normalizeInputVector(u, maxVoltageVolts));
+  }
+
+  /**
+   * Constructs a state-space loop with the given controller, feedforward, and observer. By default,
+   * the initial reference is all zeros. Users should call reset with the initial system state
+   * before enabling the loop.
+   *
+   * @param controller State-space controller.
+   * @param feedforward Plant inversion feedforward.
+   * @param observer State-space observer.
+   * @param clampFunction The function used to clamp the input U.
+   */
+  public LinearSystemLoop(
+      LinearQuadraticRegulator<States, Inputs, Outputs> controller,
+      LinearPlantInversionFeedforward<States, Inputs, Outputs> feedforward,
+      KalmanFilter<States, Inputs, Outputs> observer,
+      Function<Matrix<Inputs, N1>, Matrix<Inputs, N1>> clampFunction) {
+    this.m_controller = controller;
+    this.m_feedforward = feedforward;
+    this.m_observer = observer;
+    this.m_clampFunction = clampFunction;
+
+    m_nextR = new Matrix<>(new SimpleMatrix(controller.getK().getNumCols(), 1));
+    // Reset to a zero state held in its own vector instead of m_nextR itself: the feedforward
+    // and observer may keep a reference to the matrix they're given, and sharing m_nextR with
+    // them would let a write through one (e.g. setXHat(row, value)) silently change the others.
+    reset(new Matrix<>(new SimpleMatrix(m_nextR.getNumRows(), 1)));
+  }
+
+  /**
+   * Returns the observer's state estimate x-hat.
+   *
+   * @return the observer's state estimate x-hat.
+   */
+  public Matrix<States, N1> getXHat() {
+    return getObserver().getXhat();
+  }
+
+  /**
+   * Returns an element of the observer's state estimate x-hat.
+   *
+   * @param row Row of x-hat.
+   * @return the element at the given row of the observer's state estimate x-hat.
+   */
+  public double getXHat(int row) {
+    return getObserver().getXhat(row);
+  }
+
+  /**
+   * Set the initial state estimate x-hat.
+   *
+   * @param xhat The initial state estimate x-hat.
+   */
+  public void setXHat(Matrix<States, N1> xhat) {
+    getObserver().setXhat(xhat);
+  }
+
+  /**
+   * Set an element of the initial state estimate x-hat.
+   *
+   * @param row Row of x-hat.
+   * @param value Value for element of x-hat.
+   */
+  public void setXHat(int row, double value) {
+    getObserver().setXhat(row, value);
+  }
+
+  /**
+   * Returns an element of the controller's next reference r.
+   *
+   * @param row Row of r.
+   * @return the element i of the controller's next reference r.
+   */
+  public double getNextR(int row) {
+    return getNextR().get(row, 0);
+  }
+
+  /**
+   * Returns the controller's next reference r.
+   *
+   * @return the controller's next reference r.
+   */
+  public Matrix<States, N1> getNextR() {
+    return m_nextR;
+  }
+
+  /**
+   * Set the next reference r.
+   *
+   * @param nextR Next reference.
+   */
+  public void setNextR(Matrix<States, N1> nextR) {
+    m_nextR = nextR;
+  }
+
+  /**
+   * Set the next reference r.
+   *
+   * @param nextR Next reference.
+   * @throws MatrixDimensionException if nextR doesn't have exactly one entry per state.
+   */
+  public void setNextR(double... nextR) {
+    if (nextR.length != m_nextR.getNumRows()) {
+      throw new MatrixDimensionException(
+          String.format(
+              "The next reference does not have the "
+                  + "correct number of entries! Expected %s, but got %s.",
+              m_nextR.getNumRows(), nextR.length));
+    }
+    m_nextR = new Matrix<>(new SimpleMatrix(m_nextR.getNumRows(), 1, true, nextR));
+  }
+
+  /**
+   * Returns the controller's calculated control input u plus the calculated feedforward u_ff.
+   *
+   * @return the calculated control input u.
+   */
+  public Matrix<Inputs, N1> getU() {
+    return clampInput(m_controller.getU().plus(m_feedforward.getUff()));
+  }
+
+  /**
+   * Returns an element of the controller's calculated control input u.
+   *
+   * @param row Row of u.
+   * @return the calculated control input u at the row i.
+   */
+  public double getU(int row) {
+    return getU().get(row, 0);
+  }
+
+  /**
+   * Return the controller used internally.
+   *
+   * @return the controller used internally.
+   */
+  public LinearQuadraticRegulator<States, Inputs, Outputs> getController() {
+    return m_controller;
+  }
+
+  /**
+   * Return the feedforward used internally.
+   *
+   * @return the feedforward used internally.
+   */
+  public LinearPlantInversionFeedforward<States, Inputs, Outputs> getFeedforward() {
+    return m_feedforward;
+  }
+
+  /**
+   * Return the observer used internally.
+   *
+   * @return the observer used internally.
+   */
+  public KalmanFilter<States, Inputs, Outputs> getObserver() {
+    return m_observer;
+  }
+
+  /**
+   * Zeroes reference r and controller output u. The previous reference of the
+   * PlantInversionFeedforward and the initial state estimate of the KalmanFilter are set to the
+   * initial state provided.
+   *
+   * @param initialState The initial state.
+   */
+  public void reset(Matrix<States, N1> initialState) {
+    // Assign a fresh zero vector rather than fill(0.0) so a matrix the caller handed to
+    // setNextR(Matrix) isn't zeroed out from under them.
+    m_nextR = new Matrix<>(new SimpleMatrix(m_nextR.getNumRows(), 1));
+    m_controller.reset();
+    m_feedforward.reset(initialState);
+    m_observer.setXhat(initialState);
+  }
+
+  /**
+   * Returns difference between reference r and current state x-hat.
+   *
+   * @return The state error matrix.
+   */
+  public Matrix<States, N1> getError() {
+    return getController().getR().minus(m_observer.getXhat());
+  }
+
+  /**
+   * Returns difference between reference r and current state x-hat.
+   *
+   * @param index The index of the error matrix to return.
+   * @return The error at that index.
+   */
+  public double getError(int index) {
+    return getError().get(index, 0);
+  }
+
+  /**
+   * Get the function used to clamp the input u.
+   *
+   * @return The clamping function.
+   */
+  public Function<Matrix<Inputs, N1>, Matrix<Inputs, N1>> getClampFunction() {
+    return m_clampFunction;
+  }
+
+  /**
+   * Set the clamping function used to clamp inputs.
+   *
+   * @param clampFunction The clamping function.
+   */
+  public void setClampFunction(Function<Matrix<Inputs, N1>, Matrix<Inputs, N1>> clampFunction) {
+    this.m_clampFunction = clampFunction;
+  }
+
+  /**
+   * Correct the state estimate x-hat using the measurements in y.
+   *
+   * @param y Measurement vector.
+   */
+  @SuppressWarnings("ParameterName")
+  public void correct(Matrix<Outputs, N1> y) {
+    getObserver().correct(getU(), y);
+  }
+
+  /**
+   * Sets new controller output, projects model forward, and runs observer prediction.
+   *
+   * <p>After calling this, the user should send the elements of u to the actuators.
+   *
+   * @param dtSeconds Timestep for model update.
+   */
+  @SuppressWarnings("LocalVariableName")
+  public void predict(double dtSeconds) {
+    var u =
+        clampInput(
+            m_controller
+                .calculate(getObserver().getXhat(), m_nextR)
+                .plus(m_feedforward.calculate(m_nextR)));
+    getObserver().predict(u, dtSeconds);
+  }
+
+  /**
+   * Clamp the input u to the min and max.
+   *
+   * @param unclampedU The input to clamp.
+   * @return The clamped input.
+   */
+  public Matrix<Inputs, N1> clampInput(Matrix<Inputs, N1> unclampedU) {
+    return m_clampFunction.apply(unclampedU);
+  }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/math/system/NumericalIntegration.java b/wpimath/src/main/java/edu/wpi/first/math/system/NumericalIntegration.java
new file mode 100644
index 0000000..274b10a
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/system/NumericalIntegration.java
@@ -0,0 +1,383 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.system;
+
+import edu.wpi.first.math.Matrix;
+import edu.wpi.first.math.Num;
+import edu.wpi.first.math.numbers.N1;
+import java.util.function.BiFunction;
+import java.util.function.DoubleFunction;
+import java.util.function.Function;
+
+public final class NumericalIntegration {
+ private NumericalIntegration() {
+ // utility Class
+ }
+
+ /**
+ * Performs Runge Kutta integration (4th order).
+ *
+ * @param f The function to integrate, which takes one argument x.
+ * @param x The initial value of x.
+ * @param dtSeconds The time over which to integrate.
+ * @return the integration of dx/dt = f(x) for dt.
+ */
+ @SuppressWarnings("ParameterName")
+ public static double rk4(DoubleFunction<Double> f, double x, double dtSeconds) {
+ final var h = dtSeconds;
+ final var k1 = f.apply(x);
+ final var k2 = f.apply(x + h * k1 * 0.5);
+ final var k3 = f.apply(x + h * k2 * 0.5);
+ final var k4 = f.apply(x + h * k3);
+
+ return x + h / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4);
+ }
+
+ /**
+ * Performs Runge Kutta integration (4th order).
+ *
+ * @param f The function to integrate. It must take two arguments x and u.
+ * @param x The initial value of x.
+ * @param u The value u held constant over the integration period.
+ * @param dtSeconds The time over which to integrate.
+ * @return The result of Runge Kutta integration (4th order).
+ */
+ @SuppressWarnings("ParameterName")
+ public static double rk4(
+ BiFunction<Double, Double, Double> f, double x, Double u, double dtSeconds) {
+ final var h = dtSeconds;
+
+ final var k1 = f.apply(x, u);
+ final var k2 = f.apply(x + h * k1 * 0.5, u);
+ final var k3 = f.apply(x + h * k2 * 0.5, u);
+ final var k4 = f.apply(x + h * k3, u);
+
+ return x + h / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4);
+ }
+
+ /**
+ * Performs 4th order Runge-Kutta integration of dx/dt = f(x, u) for dt.
+ *
+ * @param <States> A Num representing the states of the system to integrate.
+ * @param <Inputs> A Num representing the inputs of the system to integrate.
+ * @param f The function to integrate. It must take two arguments x and u.
+ * @param x The initial value of x.
+ * @param u The value u held constant over the integration period.
+ * @param dtSeconds The time over which to integrate.
+ * @return the integration of dx/dt = f(x, u) for dt.
+ */
+ @SuppressWarnings({"ParameterName", "MethodTypeParameterName"})
+ public static <States extends Num, Inputs extends Num> Matrix<States, N1> rk4(
+ BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
+ Matrix<States, N1> x,
+ Matrix<Inputs, N1> u,
+ double dtSeconds) {
+ final var h = dtSeconds;
+
+ Matrix<States, N1> k1 = f.apply(x, u);
+ Matrix<States, N1> k2 = f.apply(x.plus(k1.times(h * 0.5)), u);
+ Matrix<States, N1> k3 = f.apply(x.plus(k2.times(h * 0.5)), u);
+ Matrix<States, N1> k4 = f.apply(x.plus(k3.times(h)), u);
+
+ return x.plus((k1.plus(k2.times(2.0)).plus(k3.times(2.0)).plus(k4)).times(h / 6.0));
+ }
+
+ /**
+ * Performs 4th order Runge-Kutta integration of dx/dt = f(x) for dt.
+ *
+ * @param <States> A Num prepresenting the states of the system.
+ * @param f The function to integrate. It must take one argument x.
+ * @param x The initial value of x.
+ * @param dtSeconds The time over which to integrate.
+ * @return 4th order Runge-Kutta integration of dx/dt = f(x) for dt.
+ */
+ @SuppressWarnings({"ParameterName", "MethodTypeParameterName"})
+ public static <States extends Num> Matrix<States, N1> rk4(
+ Function<Matrix<States, N1>, Matrix<States, N1>> f, Matrix<States, N1> x, double dtSeconds) {
+ final var h = dtSeconds;
+
+ Matrix<States, N1> k1 = f.apply(x);
+ Matrix<States, N1> k2 = f.apply(x.plus(k1.times(h * 0.5)));
+ Matrix<States, N1> k3 = f.apply(x.plus(k2.times(h * 0.5)));
+ Matrix<States, N1> k4 = f.apply(x.plus(k3.times(h)));
+
+ return x.plus((k1.plus(k2.times(2.0)).plus(k3.times(2.0)).plus(k4)).times(h / 6.0));
+ }
+
+ /**
+ * Performs adaptive RKF45 integration of dx/dt = f(x, u) for dt, as described in
+ * https://en.wikipedia.org/wiki/Runge%E2%80%93Kutta%E2%80%93Fehlberg_method. By default, the max
+ * error is 1e-6.
+ *
+ * @param <States> A Num representing the states of the system to integrate.
+ * @param <Inputs> A Num representing the inputs of the system to integrate.
+ * @param f The function to integrate. It must take two arguments x and u.
+ * @param x The initial value of x.
+ * @param u The value u held constant over the integration period.
+ * @param dtSeconds The time over which to integrate.
+ * @return the integration of dx/dt = f(x, u) for dt.
+ */
+ @SuppressWarnings("MethodTypeParameterName")
+ public static <States extends Num, Inputs extends Num> Matrix<States, N1> rkf45(
+ BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
+ Matrix<States, N1> x,
+ Matrix<Inputs, N1> u,
+ double dtSeconds) {
+ return rkf45(f, x, u, dtSeconds, 1e-6);
+ }
+
+ /**
+ * Performs adaptive RKF45 integration of dx/dt = f(x, u) for dt, as described in
+ * https://en.wikipedia.org/wiki/Runge%E2%80%93Kutta%E2%80%93Fehlberg_method
+ *
+ * @param <States> A Num representing the states of the system to integrate.
+ * @param <Inputs> A Num representing the inputs of the system to integrate.
+ * @param f The function to integrate. It must take two arguments x and u.
+ * @param x The initial value of x.
+ * @param u The value u held constant over the integration period.
+ * @param dtSeconds The time over which to integrate.
+ * @param maxError The maximum acceptable truncation error. Usually a small number like 1e-6.
+ * @return the integration of dx/dt = f(x, u) for dt.
+ */
+ @SuppressWarnings("MethodTypeParameterName")
+ public static <States extends Num, Inputs extends Num> Matrix<States, N1> rkf45(
+ BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
+ Matrix<States, N1> x,
+ Matrix<Inputs, N1> u,
+ double dtSeconds,
+ double maxError) {
+ // See
+ // https://en.wikipedia.org/wiki/Runge%E2%80%93Kutta%E2%80%93Fehlberg_method
+ // for the Butcher tableau the following arrays came from.
+
+ // final double[5][5]
+ final double[][] A = {
+ {1.0 / 4.0},
+ {3.0 / 32.0, 9.0 / 32.0},
+ {1932.0 / 2197.0, -7200.0 / 2197.0, 7296.0 / 2197.0},
+ {439.0 / 216.0, -8.0, 3680.0 / 513.0, -845.0 / 4104.0},
+ {-8.0 / 27.0, 2.0, -3544.0 / 2565.0, 1859.0 / 4104.0, -11.0 / 40.0}
+ };
+
+ // final double[6]
+ final double[] b1 = {
+ 16.0 / 135.0, 0.0, 6656.0 / 12825.0, 28561.0 / 56430.0, -9.0 / 50.0, 2.0 / 55.0
+ };
+
+ // final double[6]
+ final double[] b2 = {25.0 / 216.0, 0.0, 1408.0 / 2565.0, 2197.0 / 4104.0, -1.0 / 5.0, 0.0};
+
+ Matrix<States, N1> newX;
+ double truncationError;
+
+ double dtElapsed = 0.0;
+ double h = dtSeconds;
+
+ // Loop until we've gotten to our desired dt
+ while (dtElapsed < dtSeconds) {
+ do {
+ // Only allow us to advance up to the dt remaining
+ h = Math.min(h, dtSeconds - dtElapsed);
+
+ // Notice how the derivative in the Wikipedia notation is dy/dx.
+ // That means their y is our x and their x is our t
+ var k1 = f.apply(x, u);
+ var k2 = f.apply(x.plus(k1.times(A[0][0]).times(h)), u);
+ var k3 = f.apply(x.plus(k1.times(A[1][0]).plus(k2.times(A[1][1])).times(h)), u);
+ var k4 =
+ f.apply(
+ x.plus(k1.times(A[2][0]).plus(k2.times(A[2][1])).plus(k3.times(A[2][2])).times(h)),
+ u);
+ var k5 =
+ f.apply(
+ x.plus(
+ k1.times(A[3][0])
+ .plus(k2.times(A[3][1]))
+ .plus(k3.times(A[3][2]))
+ .plus(k4.times(A[3][3]))
+ .times(h)),
+ u);
+ var k6 =
+ f.apply(
+ x.plus(
+ k1.times(A[4][0])
+ .plus(k2.times(A[4][1]))
+ .plus(k3.times(A[4][2]))
+ .plus(k4.times(A[4][3]))
+ .plus(k5.times(A[4][4]))
+ .times(h)),
+ u);
+
+ newX =
+ x.plus(
+ k1.times(b1[0])
+ .plus(k2.times(b1[1]))
+ .plus(k3.times(b1[2]))
+ .plus(k4.times(b1[3]))
+ .plus(k5.times(b1[4]))
+ .plus(k6.times(b1[5]))
+ .times(h));
+ truncationError =
+ (k1.times(b1[0] - b2[0])
+ .plus(k2.times(b1[1] - b2[1]))
+ .plus(k3.times(b1[2] - b2[2]))
+ .plus(k4.times(b1[3] - b2[3]))
+ .plus(k5.times(b1[4] - b2[4]))
+ .plus(k6.times(b1[5] - b2[5]))
+ .times(h))
+ .normF();
+
+ h *= 0.9 * Math.pow(maxError / truncationError, 1.0 / 5.0);
+ } while (truncationError > maxError);
+
+ dtElapsed += h;
+ x = newX;
+ }
+
+ return x;
+ }
+
+ /**
+ * Performs adaptive Dormand-Prince integration of dx/dt = f(x, u) for dt. By default, the max
+ * error is 1e-6.
+ *
+ * @param <States> A Num representing the states of the system to integrate.
+ * @param <Inputs> A Num representing the inputs of the system to integrate.
+ * @param f The function to integrate. It must take two arguments x and u.
+ * @param x The initial value of x.
+ * @param u The value u held constant over the integration period.
+ * @param dtSeconds The time over which to integrate.
+ * @return the integration of dx/dt = f(x, u) for dt.
+ */
+ @SuppressWarnings("MethodTypeParameterName")
+ public static <States extends Num, Inputs extends Num> Matrix<States, N1> rkdp(
+ BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
+ Matrix<States, N1> x,
+ Matrix<Inputs, N1> u,
+ double dtSeconds) {
+ return rkdp(f, x, u, dtSeconds, 1e-6);
+ }
+
+ /**
+ * Performs adaptive Dormand-Prince integration of dx/dt = f(x, u) for dt.
+ *
+ * @param <States> A Num representing the states of the system to integrate.
+ * @param <Inputs> A Num representing the inputs of the system to integrate.
+ * @param f The function to integrate. It must take two arguments x and u.
+ * @param x The initial value of x.
+ * @param u The value u held constant over the integration period.
+ * @param dtSeconds The time over which to integrate.
+ * @param maxError The maximum acceptable truncation error. Usually a small number like 1e-6.
+ * @return the integration of dx/dt = f(x, u) for dt.
+ */
+ @SuppressWarnings("MethodTypeParameterName")
+ public static <States extends Num, Inputs extends Num> Matrix<States, N1> rkdp(
+ BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
+ Matrix<States, N1> x,
+ Matrix<Inputs, N1> u,
+ double dtSeconds,
+ double maxError) {
+ // See https://en.wikipedia.org/wiki/Dormand%E2%80%93Prince_method for the
+ // Butcher tableau the following arrays came from.
+
+ // final double[6][6]
+ final double[][] A = {
+ {1.0 / 5.0},
+ {3.0 / 40.0, 9.0 / 40.0},
+ {44.0 / 45.0, -56.0 / 15.0, 32.0 / 9.0},
+ {19372.0 / 6561.0, -25360.0 / 2187.0, 64448.0 / 6561.0, -212.0 / 729.0},
+ {9017.0 / 3168.0, -355.0 / 33.0, 46732.0 / 5247.0, 49.0 / 176.0, -5103.0 / 18656.0},
+ {35.0 / 384.0, 0.0, 500.0 / 1113.0, 125.0 / 192.0, -2187.0 / 6784.0, 11.0 / 84.0}
+ };
+
+ // final double[7]
+ final double[] b1 = {
+ 35.0 / 384.0, 0.0, 500.0 / 1113.0, 125.0 / 192.0, -2187.0 / 6784.0, 11.0 / 84.0, 0.0
+ };
+
+ // final double[7]
+ final double[] b2 = {
+ 5179.0 / 57600.0,
+ 0.0,
+ 7571.0 / 16695.0,
+ 393.0 / 640.0,
+ -92097.0 / 339200.0,
+ 187.0 / 2100.0,
+ 1.0 / 40.0
+ };
+
+ Matrix<States, N1> newX;
+ double truncationError;
+
+ double dtElapsed = 0.0;
+ double h = dtSeconds;
+
+ // Loop until we've gotten to our desired dt
+ while (dtElapsed < dtSeconds) {
+ do {
+ // Only allow us to advance up to the dt remaining
+ h = Math.min(h, dtSeconds - dtElapsed);
+
+ var k1 = f.apply(x, u);
+ var k2 = f.apply(x.plus(k1.times(A[0][0]).times(h)), u);
+ var k3 = f.apply(x.plus(k1.times(A[1][0]).plus(k2.times(A[1][1])).times(h)), u);
+ var k4 =
+ f.apply(
+ x.plus(k1.times(A[2][0]).plus(k2.times(A[2][1])).plus(k3.times(A[2][2])).times(h)),
+ u);
+ var k5 =
+ f.apply(
+ x.plus(
+ k1.times(A[3][0])
+ .plus(k2.times(A[3][1]))
+ .plus(k3.times(A[3][2]))
+ .plus(k4.times(A[3][3]))
+ .times(h)),
+ u);
+ var k6 =
+ f.apply(
+ x.plus(
+ k1.times(A[4][0])
+ .plus(k2.times(A[4][1]))
+ .plus(k3.times(A[4][2]))
+ .plus(k4.times(A[4][3]))
+ .plus(k5.times(A[4][4]))
+ .times(h)),
+ u);
+
+ // Since the final row of A and the array b1 have the same coefficients
+ // and k7 has no effect on newX, we can reuse the calculation.
+ newX =
+ x.plus(
+ k1.times(A[5][0])
+ .plus(k2.times(A[5][1]))
+ .plus(k3.times(A[5][2]))
+ .plus(k4.times(A[5][3]))
+ .plus(k5.times(A[5][4]))
+ .plus(k6.times(A[5][5]))
+ .times(h));
+ var k7 = f.apply(newX, u);
+
+ truncationError =
+ (k1.times(b1[0] - b2[0])
+ .plus(k2.times(b1[1] - b2[1]))
+ .plus(k3.times(b1[2] - b2[2]))
+ .plus(k4.times(b1[3] - b2[3]))
+ .plus(k5.times(b1[4] - b2[4]))
+ .plus(k6.times(b1[5] - b2[5]))
+ .plus(k7.times(b1[6] - b2[6]))
+ .times(h))
+ .normF();
+
+ h *= 0.9 * Math.pow(maxError / truncationError, 1.0 / 5.0);
+ } while (truncationError > maxError);
+
+ dtElapsed += h;
+ x = newX;
+ }
+
+ return x;
+ }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/math/system/NumericalJacobian.java b/wpimath/src/main/java/edu/wpi/first/math/system/NumericalJacobian.java
new file mode 100644
index 0000000..6c2c896
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/system/NumericalJacobian.java
@@ -0,0 +1,104 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.system;
+
+import edu.wpi.first.math.Matrix;
+import edu.wpi.first.math.Nat;
+import edu.wpi.first.math.Num;
+import edu.wpi.first.math.numbers.N1;
+import java.util.function.BiFunction;
+import java.util.function.Function;
+
+/** Central finite-difference approximations of Jacobians of vector-valued functions. */
+public final class NumericalJacobian {
+ /** Perturbation added to and subtracted from each component of the argument vector. */
+ private static final double kEpsilon = 1e-5;
+
+ private NumericalJacobian() {
+ // Static helpers only; never instantiated.
+ }
+
+ /**
+ * Computes the Jacobian of f at x using central finite differences.
+ *
+ * <p>Column i of the result is (f(x + ε·eᵢ) − f(x − ε·eᵢ)) / (2ε), where eᵢ is the i-th unit
+ * vector and ε = 1e-5. Each column is copied into the result via Matrix.changeBoundsUnchecked,
+ * so the length of f's output is not checked against rows here.
+ *
+ * @param <Rows> Num representing the number of rows in the result.
+ * @param <Cols> Num representing the number of columns in the result (the length of x).
+ * @param <States> Num representing the number of rows in the output of f.
+ * @param rows Number of rows in the result.
+ * @param cols Number of columns in the result; one column is computed per element of x.
+ * @param f Vector-valued function from which to compute the Jacobian.
+ * @param x Point at which the Jacobian is evaluated.
+ * @return The numerical Jacobian of f with respect to x.
+ */
+ @SuppressWarnings({"ParameterName", "MethodTypeParameterName"})
+ public static <Rows extends Num, Cols extends Num, States extends Num>
+ Matrix<Rows, Cols> numericalJacobian(
+ Nat<Rows> rows,
+ Nat<Cols> cols,
+ Function<Matrix<Cols, N1>, Matrix<States, N1>> f,
+ Matrix<Cols, N1> x) {
+ final Matrix<Rows, Cols> jacobian = new Matrix<>(rows, cols);
+ final double twoEpsilon = 2 * kEpsilon;
+
+ for (int col = 0; col < cols.getNum(); ++col) {
+ final double xi = x.get(col, 0);
+
+ final var xPlus = x.copy();
+ final var xMinus = x.copy();
+ xPlus.set(col, 0, xi + kEpsilon);
+ xMinus.set(col, 0, xi - kEpsilon);
+
+ final var fPlus = f.apply(xPlus);
+ final var fMinus = f.apply(xMinus);
+ final var slope = fPlus.minus(fMinus).div(twoEpsilon);
+
+ jacobian.setColumn(col, Matrix.changeBoundsUnchecked(slope));
+ }
+
+ return jacobian;
+ }
+
+ /**
+ * Computes the numerical Jacobian of f(x, u) with respect to x, holding u fixed.
+ *
+ * @param <Rows> Number of rows in the result of f(x, u).
+ * @param <States> Number of rows in x.
+ * @param <Inputs> Number of rows in the second input to f.
+ * @param <Outputs> Num representing the rows in the output of f.
+ * @param rows Number of rows in the result of f(x, u).
+ * @param states Number of rows in x.
+ * @param f Vector-valued function from which to compute Jacobian.
+ * @param x State vector at which the Jacobian is evaluated.
+ * @param u Input vector, held constant.
+ * @return The numerical Jacobian with respect to x for f(x, u).
+ */
+ @SuppressWarnings({"LambdaParameterName", "MethodTypeParameterName"})
+ public static <Rows extends Num, States extends Num, Inputs extends Num, Outputs extends Num>
+ Matrix<Rows, States> numericalJacobianX(
+ Nat<Rows> rows,
+ Nat<States> states,
+ BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> f,
+ Matrix<States, N1> x,
+ Matrix<Inputs, N1> u) {
+ final Function<Matrix<States, N1>, Matrix<Outputs, N1>> fOfX = state -> f.apply(state, u);
+ return numericalJacobian(rows, states, fOfX, x);
+ }
+
+ /**
+ * Computes the numerical Jacobian of f(x, u) with respect to u, holding x fixed.
+ *
+ * @param <States> The states of the system.
+ * @param <Inputs> The inputs to the system.
+ * @param <Rows> Number of rows in the result of f(x, u).
+ * @param rows Number of rows in the result of f(x, u).
+ * @param inputs Number of rows in u.
+ * @param f Vector-valued function from which to compute the Jacobian.
+ * @param x State vector, held constant.
+ * @param u Input vector at which the Jacobian is evaluated.
+ * @return the numerical Jacobian with respect to u for f(x, u).
+ */
+ @SuppressWarnings({"LambdaParameterName", "MethodTypeParameterName"})
+ public static <Rows extends Num, States extends Num, Inputs extends Num>
+ Matrix<Rows, Inputs> numericalJacobianU(
+ Nat<Rows> rows,
+ Nat<Inputs> inputs,
+ BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
+ Matrix<States, N1> x,
+ Matrix<Inputs, N1> u) {
+ final Function<Matrix<Inputs, N1>, Matrix<States, N1>> fOfU = input -> f.apply(x, input);
+ return numericalJacobian(rows, inputs, fOfU, u);
+ }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/math/system/plant/DCMotor.java b/wpimath/src/main/java/edu/wpi/first/math/system/plant/DCMotor.java
new file mode 100644
index 0000000..94c117f
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/system/plant/DCMotor.java
@@ -0,0 +1,207 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.system.plant;
+
+import edu.wpi.first.math.util.Units;
+
+/** Holds the constants for a DC motor (or a gearbox of identical DC motors). */
+public class DCMotor {
+ /** Voltage at which the motor constants were measured, in volts. */
+ @SuppressWarnings("MemberName")
+ public final double nominalVoltageVolts;
+
+ /** Torque when stalled, in newton-meters, summed over all motors in the gearbox. */
+ @SuppressWarnings("MemberName")
+ public final double stallTorqueNewtonMeters;
+
+ /** Current draw when stalled, in amps, summed over all motors in the gearbox. */
+ @SuppressWarnings("MemberName")
+ public final double stallCurrentAmps;
+
+ /** Current draw under no load, in amps, summed over all motors in the gearbox. */
+ @SuppressWarnings("MemberName")
+ public final double freeCurrentAmps;
+
+ /** Angular velocity under no load, in rad/s; not scaled by the number of motors. */
+ @SuppressWarnings("MemberName")
+ public final double freeSpeedRadPerSec;
+
+ /** Resistance in ohms, computed as nominalVoltageVolts / stallCurrentAmps. */
+ @SuppressWarnings("MemberName")
+ public final double rOhms;
+
+ /**
+ * Velocity constant in (rad/s)/V, computed as freeSpeedRadPerSec / (nominalVoltageVolts -
+ * rOhms * freeCurrentAmps).
+ */
+ @SuppressWarnings("MemberName")
+ public final double KvRadPerSecPerVolt;
+
+ /** Torque constant in N·m/A, computed as stallTorqueNewtonMeters / stallCurrentAmps. */
+ @SuppressWarnings("MemberName")
+ public final double KtNMPerAmp;
+
+ /**
+ * Constructs a DC motor.
+ *
+ * <p>Stall torque, stall current, and free current are multiplied by {@code numMotors}; the free
+ * speed is not. The resistance and the Kv/Kt constants are derived from these totals.
+ *
+ * @param nominalVoltageVolts Voltage at which the motor constants were measured.
+ * @param stallTorqueNewtonMeters Torque when stalled, for a single motor.
+ * @param stallCurrentAmps Current draw when stalled, for a single motor.
+ * @param freeCurrentAmps Current draw under no load, for a single motor.
+ * @param freeSpeedRadPerSec Angular velocity under no load.
+ * @param numMotors Number of motors in a gearbox. Should be positive; zero makes the total stall
+ * current zero, so the derived constants become infinite or NaN.
+ */
+ public DCMotor(
+ double nominalVoltageVolts,
+ double stallTorqueNewtonMeters,
+ double stallCurrentAmps,
+ double freeCurrentAmps,
+ double freeSpeedRadPerSec,
+ int numMotors) {
+ this.nominalVoltageVolts = nominalVoltageVolts;
+ this.stallTorqueNewtonMeters = stallTorqueNewtonMeters * numMotors;
+ this.stallCurrentAmps = stallCurrentAmps * numMotors;
+ this.freeCurrentAmps = freeCurrentAmps * numMotors;
+ this.freeSpeedRadPerSec = freeSpeedRadPerSec;
+
+ this.rOhms = nominalVoltageVolts / this.stallCurrentAmps;
+ this.KvRadPerSecPerVolt =
+ freeSpeedRadPerSec / (nominalVoltageVolts - rOhms * this.freeCurrentAmps);
+ this.KtNMPerAmp = this.stallTorqueNewtonMeters / this.stallCurrentAmps;
+ }
+
+ /**
+ * Estimate the current being drawn by this motor.
+ *
+ * <p>Algebraically equivalent to (voltageInputVolts - speedRadiansPerSec / Kv) / R.
+ *
+ * @param speedRadiansPerSec The speed of the rotor.
+ * @param voltageInputVolts The input voltage.
+ * @return The estimated current.
+ */
+ public double getCurrent(double speedRadiansPerSec, double voltageInputVolts) {
+ return -1.0 / KvRadPerSecPerVolt / rOhms * speedRadiansPerSec + 1.0 / rOhms * voltageInputVolts;
+ }
+
+ /**
+ * Return a gearbox of CIM motors.
+ *
+ * @param numMotors Number of motors in the gearbox.
+ * @return A gearbox of CIM motors.
+ */
+ public static DCMotor getCIM(int numMotors) {
+ return new DCMotor(
+ 12, 2.42, 133, 2.7, Units.rotationsPerMinuteToRadiansPerSecond(5310), numMotors);
+ }
+
+ /**
+ * Return a gearbox of 775Pro motors.
+ *
+ * @param numMotors Number of motors in the gearbox.
+ * @return A gearbox of 775Pro motors.
+ */
+ public static DCMotor getVex775Pro(int numMotors) {
+ return new DCMotor(
+ 12, 0.71, 134, 0.7, Units.rotationsPerMinuteToRadiansPerSecond(18730), numMotors);
+ }
+
+ /**
+ * Return a gearbox of NEO motors.
+ *
+ * @param numMotors Number of motors in the gearbox.
+ * @return A gearbox of NEO motors.
+ */
+ public static DCMotor getNEO(int numMotors) {
+ return new DCMotor(
+ 12, 2.6, 105, 1.8, Units.rotationsPerMinuteToRadiansPerSecond(5676), numMotors);
+ }
+
+ /**
+ * Return a gearbox of MiniCIM motors.
+ *
+ * @param numMotors Number of motors in the gearbox.
+ * @return A gearbox of MiniCIM motors.
+ */
+ public static DCMotor getMiniCIM(int numMotors) {
+ return new DCMotor(
+ 12, 1.41, 89, 3, Units.rotationsPerMinuteToRadiansPerSecond(5840), numMotors);
+ }
+
+ /**
+ * Return a gearbox of Bag motors.
+ *
+ * @param numMotors Number of motors in the gearbox.
+ * @return A gearbox of Bag motors.
+ */
+ public static DCMotor getBag(int numMotors) {
+ return new DCMotor(
+ 12, 0.43, 53, 1.8, Units.rotationsPerMinuteToRadiansPerSecond(13180), numMotors);
+ }
+
+ /**
+ * Return a gearbox of Andymark RS775-125 motors.
+ *
+ * @param numMotors Number of motors in the gearbox.
+ * @return A gearbox of Andymark RS775-125 motors.
+ */
+ public static DCMotor getAndymarkRs775_125(int numMotors) {
+ return new DCMotor(
+ 12, 0.28, 18, 1.6, Units.rotationsPerMinuteToRadiansPerSecond(5800.0), numMotors);
+ }
+
+ /**
+ * Return a gearbox of Banebots RS775 motors.
+ *
+ * @param numMotors Number of motors in the gearbox.
+ * @return A gearbox of Banebots RS775 motors.
+ */
+ public static DCMotor getBanebotsRs775(int numMotors) {
+ return new DCMotor(
+ 12, 0.72, 97, 2.7, Units.rotationsPerMinuteToRadiansPerSecond(13050.0), numMotors);
+ }
+
+ /**
+ * Return a gearbox of Andymark 9015 motors.
+ *
+ * @param numMotors Number of motors in the gearbox.
+ * @return A gearbox of Andymark 9015 motors.
+ */
+ public static DCMotor getAndymark9015(int numMotors) {
+ return new DCMotor(
+ 12, 0.36, 71, 3.7, Units.rotationsPerMinuteToRadiansPerSecond(14270.0), numMotors);
+ }
+
+ /**
+ * Return a gearbox of Banebots RS 550 motors.
+ *
+ * @param numMotors Number of motors in the gearbox.
+ * @return A gearbox of Banebots RS 550 motors.
+ */
+ public static DCMotor getBanebotsRs550(int numMotors) {
+ return new DCMotor(
+ 12, 0.38, 84, 0.4, Units.rotationsPerMinuteToRadiansPerSecond(19000.0), numMotors);
+ }
+
+ /**
+ * Return a gearbox of NEO 550 motors.
+ *
+ * @param numMotors Number of motors in the gearbox.
+ * @return A gearbox of NEO 550 motors.
+ */
+ public static DCMotor getNeo550(int numMotors) {
+ return new DCMotor(
+ 12, 0.97, 100, 1.4, Units.rotationsPerMinuteToRadiansPerSecond(11000.0), numMotors);
+ }
+
+ /**
+ * Return a gearbox of Falcon 500 motors.
+ *
+ * @param numMotors Number of motors in the gearbox.
+ * @return A gearbox of Falcon 500 motors.
+ */
+ public static DCMotor getFalcon500(int numMotors) {
+ return new DCMotor(
+ 12, 4.69, 257, 1.5, Units.rotationsPerMinuteToRadiansPerSecond(6380.0), numMotors);
+ }
+
+ /**
+ * Return a gearbox of Romi/TI_RSLK MAX motors.
+ *
+ * @param numMotors Number of motors in the gearbox.
+ * @return A gearbox of Romi/TI_RSLK MAX motors.
+ */
+ public static DCMotor getRomiBuiltIn(int numMotors) {
+ // From https://www.pololu.com/product/1520/specs
+ return new DCMotor(
+ 4.5, 0.1765, 1.25, 0.13, Units.rotationsPerMinuteToRadiansPerSecond(150.0), numMotors);
+ }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/math/system/plant/LinearSystemId.java b/wpimath/src/main/java/edu/wpi/first/math/system/plant/LinearSystemId.java
new file mode 100644
index 0000000..2933e93
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/system/plant/LinearSystemId.java
@@ -0,0 +1,331 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.system.plant;
+
+import edu.wpi.first.math.Matrix;
+import edu.wpi.first.math.Nat;
+import edu.wpi.first.math.VecBuilder;
+import edu.wpi.first.math.numbers.N1;
+import edu.wpi.first.math.numbers.N2;
+import edu.wpi.first.math.system.LinearSystem;
+
+public final class LinearSystemId {
+ /** Not instantiable; this class only exposes static factory methods. */
+ private LinearSystemId() {}
+
+ /**
+ * Create a state-space model of an elevator system. The states of the system are [position,
+ * velocity]ᵀ, inputs are [voltage], and outputs are [position].
+ *
+ * @param motor The motor (or gearbox) driving the elevator drum.
+ * @param massKg The mass of the elevator carriage, in kilograms.
+ * @param radiusMeters The radius of the driving drum of the elevator, in meters.
+ * @param G The reduction between motor and drum, as a ratio of output to input.
+ * @return A LinearSystem representing the given characterized constants.
+ * @throws IllegalArgumentException if massKg <= 0, radiusMeters <= 0, or G <= 0.
+ */
+ @SuppressWarnings("ParameterName")
+ public static LinearSystem<N2, N1, N1> createElevatorSystem(
+ DCMotor motor, double massKg, double radiusMeters, double G) {
+ if (massKg <= 0.0) {
+ throw new IllegalArgumentException("massKg must be greater than zero.");
+ }
+ if (radiusMeters <= 0.0) {
+ throw new IllegalArgumentException("radiusMeters must be greater than zero.");
+ }
+ if (G <= 0.0) {
+ throw new IllegalArgumentException("G must be greater than zero.");
+ }
+
+ // A = [[0, 1], [0, -G²·Kt / (R·r²·m·Kv)]], B = [0, G·Kt / (R·r·m)]ᵀ, C = [1, 0], D = [0].
+ return new LinearSystem<>(
+ Matrix.mat(Nat.N2(), Nat.N2())
+ .fill(
+ 0,
+ 1,
+ 0,
+ -Math.pow(G, 2)
+ * motor.KtNMPerAmp
+ / (motor.rOhms
+ * radiusMeters
+ * radiusMeters
+ * massKg
+ * motor.KvRadPerSecPerVolt)),
+ VecBuilder.fill(0, G * motor.KtNMPerAmp / (motor.rOhms * radiusMeters * massKg)),
+ Matrix.mat(Nat.N1(), Nat.N2()).fill(1, 0),
+ new Matrix<>(Nat.N1(), Nat.N1()));
+ }
+
+ /**
+ * Create a state-space model of a flywheel system. The states of the system are [angular
+ * velocity], inputs are [voltage], and outputs are [angular velocity].
+ *
+ * @param motor The motor (or gearbox) driving the flywheel.
+ * @param jKgMetersSquared The moment of inertia J of the flywheel, in kg·m².
+ * @param G The reduction between motor and flywheel, as a ratio of output to input.
+ * @return A LinearSystem representing the given characterized constants.
+ * @throws IllegalArgumentException if jKgMetersSquared <= 0 or G <= 0.
+ */
+ @SuppressWarnings("ParameterName")
+ public static LinearSystem<N1, N1, N1> createFlywheelSystem(
+ DCMotor motor, double jKgMetersSquared, double G) {
+ if (jKgMetersSquared <= 0.0) {
+ throw new IllegalArgumentException("J must be greater than zero.");
+ }
+ if (G <= 0.0) {
+ throw new IllegalArgumentException("G must be greater than zero.");
+ }
+
+ // A = [-G²·Kt / (Kv·R·J)], B = [G·Kt / (R·J)], C = [1], D = [0].
+ return new LinearSystem<>(
+ VecBuilder.fill(
+ -G
+ * G
+ * motor.KtNMPerAmp
+ / (motor.KvRadPerSecPerVolt * motor.rOhms * jKgMetersSquared)),
+ VecBuilder.fill(G * motor.KtNMPerAmp / (motor.rOhms * jKgMetersSquared)),
+ Matrix.eye(Nat.N1()),
+ new Matrix<>(Nat.N1(), Nat.N1()));
+ }
+
+ /**
+ * Create a state-space model of a differential drive drivetrain. In this model, the states are
+ * [v_left, v_right]ᵀ, inputs are [V_left, V_right]ᵀ and outputs are [v_left, v_right]ᵀ.
+ *
+ * @param motor the gearbox representing the motors driving the drivetrain.
+ * @param massKg the mass of the robot, in kilograms.
+ * @param rMeters the radius of the wheels in meters.
+ * @param rbMeters the radius of the base (half the track width) in meters.
+ * @param JKgMetersSquared the moment of inertia of the robot, in kg·m².
+ * @param G the gearing reduction as output over input.
+ * @return A LinearSystem representing a differential drivetrain.
+ * @throws IllegalArgumentException if massKg <= 0, rMeters <= 0, rbMeters <= 0,
+ * JKgMetersSquared <= 0, or G <= 0.
+ */
+ @SuppressWarnings({"LocalVariableName", "ParameterName"})
+ public static LinearSystem<N2, N2, N2> createDrivetrainVelocitySystem(
+ DCMotor motor,
+ double massKg,
+ double rMeters,
+ double rbMeters,
+ double JKgMetersSquared,
+ double G) {
+ if (massKg <= 0.0) {
+ throw new IllegalArgumentException("massKg must be greater than zero.");
+ }
+ if (rMeters <= 0.0) {
+ throw new IllegalArgumentException("rMeters must be greater than zero.");
+ }
+ if (rbMeters <= 0.0) {
+ throw new IllegalArgumentException("rbMeters must be greater than zero.");
+ }
+ if (JKgMetersSquared <= 0.0) {
+ throw new IllegalArgumentException("JKgMetersSquared must be greater than zero.");
+ }
+ if (G <= 0.0) {
+ throw new IllegalArgumentException("G must be greater than zero.");
+ }
+
+ // C1 scales the velocity terms of A and C2 the voltage terms of B for a single side.
+ var C1 =
+ -(G * G) * motor.KtNMPerAmp / (motor.KvRadPerSecPerVolt * motor.rOhms * rMeters * rMeters);
+ var C2 = G * motor.KtNMPerAmp / (motor.rOhms * rMeters);
+
+ // C3 / C4 weight the same-side and opposite-side contributions through the robot's mass and
+ // moment of inertia.
+ final double C3 = 1 / massKg + rbMeters * rbMeters / JKgMetersSquared;
+ final double C4 = 1 / massKg - rbMeters * rbMeters / JKgMetersSquared;
+ var A = Matrix.mat(Nat.N2(), Nat.N2()).fill(C3 * C1, C4 * C1, C4 * C1, C3 * C1);
+ var B = Matrix.mat(Nat.N2(), Nat.N2()).fill(C3 * C2, C4 * C2, C4 * C2, C3 * C2);
+ var C = Matrix.mat(Nat.N2(), Nat.N2()).fill(1.0, 0.0, 0.0, 1.0);
+ var D = Matrix.mat(Nat.N2(), Nat.N2()).fill(0.0, 0.0, 0.0, 0.0);
+
+ return new LinearSystem<>(A, B, C, D);
+ }
+
+ /**
+ * Create a state-space model of a single jointed arm system. The states of the system are [angle,
+ * angular velocity]ᵀ, inputs are [voltage], and outputs are [angle].
+ *
+ * @param motor The motor (or gearbox) attached to the arm.
+ * @param jKgSquaredMeters The moment of inertia J of the arm.
+ * @param G The gearing between the motor and arm, in output over input. Most of the time this
+ * will be greater than 1.
+ * @return A LinearSystem representing the given characterized constants.
+ * @throws IllegalArgumentException if jKgSquaredMeters <= 0 or G <= 0.
+ */
+ @SuppressWarnings("ParameterName")
+ public static LinearSystem<N2, N1, N1> createSingleJointedArmSystem(
+ DCMotor motor, double jKgSquaredMeters, double G) {
+ if (jKgSquaredMeters <= 0.0) {
+ throw new IllegalArgumentException("jKgSquaredMeters must be greater than zero.");
+ }
+ if (G <= 0.0) {
+ throw new IllegalArgumentException("G must be greater than zero.");
+ }
+
+ // A = [[0, 1], [0, -G²·Kt / (Kv·R·J)]], B = [0, G·Kt / (R·J)]ᵀ, C = [1, 0], D = [0].
+ return new LinearSystem<>(
+ Matrix.mat(Nat.N2(), Nat.N2())
+ .fill(
+ 0,
+ 1,
+ 0,
+ -Math.pow(G, 2)
+ * motor.KtNMPerAmp
+ / (motor.KvRadPerSecPerVolt * motor.rOhms * jKgSquaredMeters)),
+ VecBuilder.fill(0, G * motor.KtNMPerAmp / (motor.rOhms * jKgSquaredMeters)),
+ Matrix.mat(Nat.N1(), Nat.N2()).fill(1, 0),
+ new Matrix<>(Nat.N1(), Nat.N1()));
+ }
+
+ /**
+ * Identify a velocity system from its kV (volts/(unit/sec)) and kA (volts/(unit/sec^2)). These
+ * constants can be found using SysId. The states of the system are [velocity], inputs are
+ * [voltage], and outputs are [velocity].
+ *
+ * <p>The distance unit you choose MUST be an SI unit (i.e. meters or radians). You can use the
+ * {@link edu.wpi.first.math.util.Units} class for converting between unit types.
+ *
+ * @param kV The velocity gain, in volts per (units per second)
+ * @param kA The acceleration gain, in volts per (units per second squared)
+ * @return A LinearSystem representing the given characterized constants.
+ * @throws IllegalArgumentException if kV <= 0 or kA <= 0.
+ * @see <a href="https://github.com/wpilibsuite/sysid">https://github.com/wpilibsuite/sysid</a>
+ */
+ @SuppressWarnings("ParameterName")
+ public static LinearSystem<N1, N1, N1> identifyVelocitySystem(double kV, double kA) {
+ if (kV <= 0.0) {
+ throw new IllegalArgumentException("Kv must be greater than zero.");
+ }
+ if (kA <= 0.0) {
+ throw new IllegalArgumentException("Ka must be greater than zero.");
+ }
+
+ // A = [-kV/kA], B = [1/kA], C = [1], D = [0].
+ return new LinearSystem<>(
+ VecBuilder.fill(-kV / kA),
+ VecBuilder.fill(1.0 / kA),
+ VecBuilder.fill(1.0),
+ VecBuilder.fill(0.0));
+ }
+
+ /**
+ * Identify a position system from its kV (volts/(unit/sec)) and kA (volts/(unit/sec^2)). These
+ * constants can be found using SysId. The states of the system are [position, velocity]ᵀ, inputs
+ * are [voltage], and outputs are [position].
+ *
+ * <p>The distance unit you choose MUST be an SI unit (i.e. meters or radians). You can use the
+ * {@link edu.wpi.first.math.util.Units} class for converting between unit types.
+ *
+ * @param kV The velocity gain, in volts per (units per second)
+ * @param kA The acceleration gain, in volts per (units per second squared)
+ * @return A LinearSystem representing the given characterized constants.
+ * @throws IllegalArgumentException if kV <= 0 or kA <= 0.
+ * @see <a href="https://github.com/wpilibsuite/sysid">https://github.com/wpilibsuite/sysid</a>
+ */
+ @SuppressWarnings("ParameterName")
+ public static LinearSystem<N2, N1, N1> identifyPositionSystem(double kV, double kA) {
+ if (kV <= 0.0) {
+ throw new IllegalArgumentException("Kv must be greater than zero.");
+ }
+ if (kA <= 0.0) {
+ throw new IllegalArgumentException("Ka must be greater than zero.");
+ }
+
+ // A = [[0, 1], [0, -kV/kA]], B = [0, 1/kA]ᵀ, C = [1, 0], D = [0].
+ return new LinearSystem<>(
+ Matrix.mat(Nat.N2(), Nat.N2()).fill(0.0, 1.0, 0.0, -kV / kA),
+ VecBuilder.fill(0.0, 1.0 / kA),
+ Matrix.mat(Nat.N1(), Nat.N2()).fill(1.0, 0.0),
+ VecBuilder.fill(0.0));
+ }
+
+  /**
+   * Identify a standard differential drive drivetrain, given the drivetrain's kV and kA in both
+   * linear (volts/(meter/sec) and volts/(meter/sec^2)) and angular (volts/(meter/sec) and
+   * volts/(meter/sec^2)) cases. This can be found using SysId. The states of the system are [left
+   * velocity, right velocity]ᵀ, inputs are [left voltage, right voltage]ᵀ, and outputs are [left
+   * velocity, right velocity]ᵀ.
+   *
+   * @param kVLinear The linear velocity gain, volts per (meter per second).
+   * @param kALinear The linear acceleration gain, volts per (meter per second squared).
+   * @param kVAngular The angular velocity gain, volts per (meter per second).
+   * @param kAAngular The angular acceleration gain, volts per (meter per second squared).
+   * @return A LinearSystem representing the given characterized constants.
+   * @throws IllegalArgumentException if kVLinear &lt;= 0, kALinear &lt;= 0, kVAngular &lt;= 0, or
+   *     kAAngular &lt;= 0.
+   * @see <a href="https://github.com/wpilibsuite/sysid">https://github.com/wpilibsuite/sysid</a>
+   */
+  @SuppressWarnings("ParameterName")
+  public static LinearSystem<N2, N2, N2> identifyDrivetrainSystem(
+      double kVLinear, double kALinear, double kVAngular, double kAAngular) {
+    if (kVLinear <= 0.0) {
+      throw new IllegalArgumentException("Kv,linear must be greater than zero.");
+    }
+    if (kALinear <= 0.0) {
+      throw new IllegalArgumentException("Ka,linear must be greater than zero.");
+    }
+    if (kVAngular <= 0.0) {
+      throw new IllegalArgumentException("Kv,angular must be greater than zero.");
+    }
+    if (kAAngular <= 0.0) {
+      throw new IllegalArgumentException("Ka,angular must be greater than zero.");
+    }
+    // For A and B: X1 is the mean and X2 half the difference of the linear and angular terms.
+    final double A1 = 0.5 * -(kVLinear / kALinear + kVAngular / kAAngular);
+    final double A2 = 0.5 * -(kVLinear / kALinear - kVAngular / kAAngular);
+    final double B1 = 0.5 * (1.0 / kALinear + 1.0 / kAAngular);
+    final double B2 = 0.5 * (1.0 / kALinear - 1.0 / kAAngular);
+
+    return new LinearSystem<>(
+        Matrix.mat(Nat.N2(), Nat.N2()).fill(A1, A2, A2, A1),
+        Matrix.mat(Nat.N2(), Nat.N2()).fill(B1, B2, B2, B1),
+        Matrix.mat(Nat.N2(), Nat.N2()).fill(1, 0, 0, 1),
+        Matrix.mat(Nat.N2(), Nat.N2()).fill(0, 0, 0, 0));
+  }
+
+  /**
+   * Identify a standard differential drive drivetrain, given the drivetrain's kV and kA in both
+   * linear (volts/(meter/sec) and volts/(meter/sec^2)) and angular (volts/(radian/sec) and
+   * volts/(radian/sec^2)) cases. This can be found using SysId. The states of the system are [left
+   * velocity, right velocity]ᵀ, inputs are [left voltage, right voltage]ᵀ, and outputs are [left
+   * velocity, right velocity]ᵀ.
+   *
+   * @param kVLinear The linear velocity gain, volts per (meter per second).
+   * @param kALinear The linear acceleration gain, volts per (meter per second squared).
+   * @param kVAngular The angular velocity gain, volts per (radians per second).
+   * @param kAAngular The angular acceleration gain, volts per (radians per second squared).
+   * @param trackwidth The width of the drivetrain in meters.
+   * @return A LinearSystem representing the given characterized constants.
+   * @throws IllegalArgumentException if kVLinear &lt;= 0, kALinear &lt;= 0, kVAngular &lt;= 0,
+   *     kAAngular &lt;= 0, or trackwidth &lt;= 0.
+   * @see <a href="https://github.com/wpilibsuite/sysid">https://github.com/wpilibsuite/sysid</a>
+   */
+  @SuppressWarnings("ParameterName")
+  public static LinearSystem<N2, N2, N2> identifyDrivetrainSystem(
+      double kVLinear, double kALinear, double kVAngular, double kAAngular, double trackwidth) {
+    if (kVLinear <= 0.0) {
+      throw new IllegalArgumentException("Kv,linear must be greater than zero.");
+    }
+    if (kALinear <= 0.0) {
+      throw new IllegalArgumentException("Ka,linear must be greater than zero.");
+    }
+    if (kVAngular <= 0.0) {
+      throw new IllegalArgumentException("Kv,angular must be greater than zero.");
+    }
+    if (kAAngular <= 0.0) {
+      throw new IllegalArgumentException("Ka,angular must be greater than zero.");
+    }
+    if (trackwidth <= 0.0) {
+      throw new IllegalArgumentException("trackwidth must be greater than zero.");
+    }
+
+    // We want to find a factor to include in Kv,angular that will convert
+    // `u = Kv,angular omega` to `u = Kv,angular v`.
+    //
+    // v = omega r
+    // omega = v/r
+    // omega = 1/r v
+    // omega = 1/(trackwidth/2) v
+    // omega = 2/trackwidth v
+    //
+    // So multiplying by 2/trackwidth converts the angular gains from V/(rad/s)
+    // to V/(m/s), and likewise Ka,angular from V/(rad/s²) to V/(m/s²).
+    return identifyDrivetrainSystem(
+        kVLinear, kALinear, kVAngular * 2.0 / trackwidth, kAAngular * 2.0 / trackwidth);
+  }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/math/trajectory/Trajectory.java b/wpimath/src/main/java/edu/wpi/first/math/trajectory/Trajectory.java
new file mode 100644
index 0000000..2ec4244
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/trajectory/Trajectory.java
@@ -0,0 +1,427 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.trajectory;
+
+import com.fasterxml.jackson.annotation.JsonProperty;
+import edu.wpi.first.math.geometry.Pose2d;
+import edu.wpi.first.math.geometry.Transform2d;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Objects;
+import java.util.stream.Collectors;
+
+/**
+ * Represents a time-parameterized trajectory. The trajectory consists of various States that
+ * represent the pose, curvature, time elapsed, velocity, and acceleration at that point.
+ */
+public class Trajectory {
+  private final double m_totalTimeSeconds;
+  private final List<State> m_states;
+
+  /** Constructs an empty trajectory. */
+  public Trajectory() {
+    m_states = new ArrayList<>();
+    m_totalTimeSeconds = 0.0;
+  }
+
+  /**
+   * Constructs a trajectory from a list of states.
+   *
+   * <p>An empty list yields an empty trajectory whose total time is zero, the same as the
+   * no-argument constructor. The list is stored by reference, not copied.
+   *
+   * @param states A list of states in non-decreasing time order (sample() relies on this).
+   */
+  public Trajectory(final List<State> states) {
+    m_states = states;
+    m_totalTimeSeconds = m_states.isEmpty() ? 0.0 : m_states.get(m_states.size() - 1).timeSeconds;
+  }
+
+  /**
+   * Linearly interpolates between two values.
+   *
+   * @param startValue The start value.
+   * @param endValue The end value.
+   * @param t The fraction for interpolation.
+   * @return The interpolated value.
+   */
+  @SuppressWarnings("ParameterName")
+  private static double lerp(double startValue, double endValue, double t) {
+    return startValue + (endValue - startValue) * t;
+  }
+
+  /**
+   * Linearly interpolates between two poses.
+   *
+   * @param startValue The start pose.
+   * @param endValue The end pose.
+   * @param t The fraction for interpolation.
+   * @return The interpolated pose.
+   */
+  @SuppressWarnings("ParameterName")
+  private static Pose2d lerp(Pose2d startValue, Pose2d endValue, double t) {
+    return startValue.plus((endValue.minus(startValue)).times(t));
+  }
+
+  /**
+   * Returns the initial pose of the trajectory.
+   *
+   * @return The initial pose of the trajectory.
+   * @throws IllegalStateException if the trajectory has no states.
+   */
+  public Pose2d getInitialPose() {
+    return sample(0).poseMeters;
+  }
+
+  /**
+   * Returns the overall duration of the trajectory.
+   *
+   * @return The duration of the trajectory.
+   */
+  public double getTotalTimeSeconds() {
+    return m_totalTimeSeconds;
+  }
+
+  /**
+   * Return the states of the trajectory.
+   *
+   * @return The states of the trajectory.
+   */
+  public List<State> getStates() {
+    return m_states;
+  }
+
+  /**
+   * Sample the trajectory at a point in time.
+   *
+   * <p>Times at or before the first state return the first state; times at or after the end of
+   * the trajectory return the last state. Times in between are interpolated.
+   *
+   * @param timeSeconds The point in time since the beginning of the trajectory to sample.
+   * @return The state at that point in time.
+   * @throws IllegalStateException if the trajectory has no states.
+   */
+  public State sample(double timeSeconds) {
+    if (m_states.isEmpty()) {
+      throw new IllegalStateException("Trajectory cannot be sampled if it has no states.");
+    }
+
+    if (timeSeconds <= m_states.get(0).timeSeconds) {
+      return m_states.get(0);
+    }
+    if (timeSeconds >= m_totalTimeSeconds) {
+      return m_states.get(m_states.size() - 1);
+    }
+
+    // To get the element that we want, we will use a binary search algorithm
+    // instead of iterating over a for-loop. A binary search is O(log(n))
+    // whereas searching using a loop is O(n).
+
+    // This starts at 1 because we use the previous state later on for
+    // interpolation.
+    int low = 1;
+    int high = m_states.size() - 1;
+
+    while (low != high) {
+      int mid = (low + high) / 2;
+      if (m_states.get(mid).timeSeconds < timeSeconds) {
+        // This index and everything under it are less than the requested
+        // timestamp. Therefore, we can discard them.
+        low = mid + 1;
+      } else {
+        // The element at this index is at least as large as timeSeconds, so
+        // anything after it cannot be what we are looking for.
+        high = mid;
+      }
+    }
+
+    // High and Low should be the same.
+
+    // The sample's timestamp is now greater than or equal to the requested
+    // timestamp. If it is greater, we need to interpolate between the
+    // previous state and the current state to get the exact state that we
+    // want.
+    final State sample = m_states.get(low);
+    final State prevSample = m_states.get(low - 1);
+
+    // If the difference in states is negligible, then we are spot on!
+    if (Math.abs(sample.timeSeconds - prevSample.timeSeconds) < 1E-9) {
+      return sample;
+    }
+    // Interpolate between the two states for the state that we want.
+    return prevSample.interpolate(
+        sample,
+        (timeSeconds - prevSample.timeSeconds) / (sample.timeSeconds - prevSample.timeSeconds));
+  }
+
+  /**
+   * Transforms all poses in the trajectory by the given transform. This is useful for converting a
+   * robot-relative trajectory into a field-relative trajectory. This works with respect to the
+   * first pose in the trajectory.
+   *
+   * @param transform The transform to transform the trajectory by.
+   * @return The transformed trajectory, or an empty trajectory if this one has no states.
+   */
+  public Trajectory transformBy(Transform2d transform) {
+    // With no states there is no first pose to anchor the transform to.
+    if (m_states.isEmpty()) {
+      return new Trajectory();
+    }
+
+    var firstState = m_states.get(0);
+    var firstPose = firstState.poseMeters;
+
+    // Calculate the transformed first pose.
+    var newFirstPose = firstPose.plus(transform);
+    List<State> newStates = new ArrayList<>(m_states.size());
+
+    newStates.add(
+        new State(
+            firstState.timeSeconds,
+            firstState.velocityMetersPerSecond,
+            firstState.accelerationMetersPerSecondSq,
+            newFirstPose,
+            firstState.curvatureRadPerMeter));
+
+    for (int i = 1; i < m_states.size(); i++) {
+      var state = m_states.get(i);
+      // We are transforming relative to the coordinate frame of the new initial pose.
+      newStates.add(
+          new State(
+              state.timeSeconds,
+              state.velocityMetersPerSecond,
+              state.accelerationMetersPerSecondSq,
+              newFirstPose.plus(state.poseMeters.minus(firstPose)),
+              state.curvatureRadPerMeter));
+    }
+
+    return new Trajectory(newStates);
+  }
+
+  /**
+   * Transforms all poses in the trajectory so that they are relative to the given pose. This is
+   * useful for converting a field-relative trajectory into a robot-relative trajectory.
+   *
+   * @param pose The pose that is the origin of the coordinate frame that the current trajectory
+   *     will be transformed into.
+   * @return The transformed trajectory.
+   */
+  public Trajectory relativeTo(Pose2d pose) {
+    return new Trajectory(
+        m_states.stream()
+            .map(
+                state ->
+                    new State(
+                        state.timeSeconds,
+                        state.velocityMetersPerSecond,
+                        state.accelerationMetersPerSecondSq,
+                        state.poseMeters.relativeTo(pose),
+                        state.curvatureRadPerMeter))
+            .collect(Collectors.toList()));
+  }
+
+  /**
+   * Concatenates another trajectory to the current trajectory. The user is responsible for making
+   * sure that the end pose of this trajectory and the start pose of the other trajectory match (if
+   * that is the desired behavior).
+   *
+   * @param other The trajectory to concatenate.
+   * @return The concatenated trajectory.
+   */
+  public Trajectory concatenate(Trajectory other) {
+    // If this trajectory has no states (e.g. it was default constructed), then
+    // we can simply return the other trajectory.
+    if (m_states.isEmpty()) {
+      return other;
+    }
+
+    // Copy each of the current states (the pose objects are shared, not cloned).
+    List<State> states =
+        m_states.stream()
+            .map(
+                state ->
+                    new State(
+                        state.timeSeconds,
+                        state.velocityMetersPerSecond,
+                        state.accelerationMetersPerSecondSq,
+                        state.poseMeters,
+                        state.curvatureRadPerMeter))
+            .collect(Collectors.toList());
+
+    // Here we omit the first state of the other trajectory because we don't want
+    // two time points with different states. sample() will automatically
+    // interpolate between the end of this trajectory and the second state of the
+    // other trajectory.
+    for (int i = 1; i < other.getStates().size(); ++i) {
+      var s = other.getStates().get(i);
+      states.add(
+          new State(
+              s.timeSeconds + m_totalTimeSeconds,
+              s.velocityMetersPerSecond,
+              s.accelerationMetersPerSecondSq,
+              s.poseMeters,
+              s.curvatureRadPerMeter));
+    }
+    return new Trajectory(states);
+  }
+
+  /**
+   * Represents a single point on a time-parameterized trajectory: the pose, curvature, time
+   * elapsed, velocity, and acceleration at that point.
+   */
+  @SuppressWarnings("MemberName")
+  public static class State {
+    /** The time elapsed since the beginning of the trajectory, in seconds. */
+    @JsonProperty("time")
+    public double timeSeconds;
+
+    /** The speed at that point of the trajectory, in meters per second. */
+    @JsonProperty("velocity")
+    public double velocityMetersPerSecond;
+
+    /** The acceleration at that point of the trajectory, in meters per second squared. */
+    @JsonProperty("acceleration")
+    public double accelerationMetersPerSecondSq;
+
+    /** The pose at that point of the trajectory (translation in meters). */
+    @JsonProperty("pose")
+    public Pose2d poseMeters;
+
+    /** The curvature at that point of the trajectory, in radians per meter. */
+    @JsonProperty("curvature")
+    public double curvatureRadPerMeter;
+
+    /** Constructs a zero-valued State with a default-constructed Pose2d. */
+    public State() {
+      poseMeters = new Pose2d();
+    }
+
+    /**
+     * Constructs a State with the specified parameters.
+     *
+     * @param timeSeconds The time elapsed since the beginning of the trajectory.
+     * @param velocityMetersPerSecond The speed at that point of the trajectory.
+     * @param accelerationMetersPerSecondSq The acceleration at that point of the trajectory.
+     * @param poseMeters The pose at that point of the trajectory.
+     * @param curvatureRadPerMeter The curvature at that point of the trajectory.
+     */
+    public State(
+        double timeSeconds,
+        double velocityMetersPerSecond,
+        double accelerationMetersPerSecondSq,
+        Pose2d poseMeters,
+        double curvatureRadPerMeter) {
+      this.timeSeconds = timeSeconds;
+      this.velocityMetersPerSecond = velocityMetersPerSecond;
+      this.accelerationMetersPerSecondSq = accelerationMetersPerSecondSq;
+      this.poseMeters = poseMeters;
+      this.curvatureRadPerMeter = curvatureRadPerMeter;
+    }
+
+    /**
+     * Interpolates between two States.
+     *
+     * @param endValue The end value for the interpolation.
+     * @param i The interpolant (fraction).
+     * @return The interpolated state.
+     */
+    @SuppressWarnings("ParameterName")
+    State interpolate(State endValue, double i) {
+      // Find the new t value.
+      final double newT = lerp(timeSeconds, endValue.timeSeconds, i);
+
+      // Find the delta time between the current state and the interpolated state.
+      final double deltaT = newT - timeSeconds;
+
+      // If delta time is negative, flip the order of interpolation.
+      if (deltaT < 0) {
+        return endValue.interpolate(this, 1 - i);
+      }
+
+      // Check whether the robot is reversing at this stage.
+      final boolean reversing =
+          velocityMetersPerSecond < 0
+              || Math.abs(velocityMetersPerSecond) < 1E-9 && accelerationMetersPerSecondSq < 0;
+
+      // Calculate the new velocity
+      // v_f = v_0 + at
+      final double newV = velocityMetersPerSecond + (accelerationMetersPerSecondSq * deltaT);
+
+      // Calculate the change in position.
+      // delta_s = v_0 t + 0.5 at^2
+      final double newS =
+          (velocityMetersPerSecond * deltaT
+                  + 0.5 * accelerationMetersPerSecondSq * Math.pow(deltaT, 2))
+              * (reversing ? -1.0 : 1.0);
+
+      // Return the new state. To find the new position for the new state, we need
+      // to interpolate between the two endpoint poses. The fraction for
+      // interpolation is the change in position (delta s) divided by the total
+      // distance between the two endpoints.
+      final double interpolationFrac =
+          newS / endValue.poseMeters.getTranslation().getDistance(poseMeters.getTranslation());
+
+      return new State(
+          newT,
+          newV,
+          accelerationMetersPerSecondSq,
+          lerp(poseMeters, endValue.poseMeters, interpolationFrac),
+          lerp(curvatureRadPerMeter, endValue.curvatureRadPerMeter, interpolationFrac));
+    }
+
+    @Override
+    public String toString() {
+      return String.format(
+          "State(Sec: %.2f, Vel m/s: %.2f, Accel m/s/s: %.2f, Pose: %s, Curvature: %.2f)",
+          timeSeconds,
+          velocityMetersPerSecond,
+          accelerationMetersPerSecondSq,
+          poseMeters,
+          curvatureRadPerMeter);
+    }
+
+    @Override
+    public boolean equals(Object obj) {
+      if (this == obj) {
+        return true;
+      }
+      if (!(obj instanceof State)) {
+        return false;
+      }
+      State state = (State) obj;
+      return Double.compare(state.timeSeconds, timeSeconds) == 0
+          && Double.compare(state.velocityMetersPerSecond, velocityMetersPerSecond) == 0
+          && Double.compare(state.accelerationMetersPerSecondSq, accelerationMetersPerSecondSq) == 0
+          && Double.compare(state.curvatureRadPerMeter, curvatureRadPerMeter) == 0
+          && Objects.equals(poseMeters, state.poseMeters);
+    }
+
+    @Override
+    public int hashCode() {
+      return Objects.hash(
+          timeSeconds,
+          velocityMetersPerSecond,
+          accelerationMetersPerSecondSq,
+          poseMeters,
+          curvatureRadPerMeter);
+    }
+  }
+
+  @Override
+  public String toString() {
+    String stateList = m_states.stream().map(State::toString).collect(Collectors.joining(", \n"));
+    return String.format("Trajectory - Seconds: %.2f, States:\n%s", m_totalTimeSeconds, stateList);
+  }
+
+  @Override
+  public int hashCode() {
+    return m_states.hashCode();
+  }
+
+  @Override
+  public boolean equals(Object obj) {
+    return obj instanceof Trajectory && m_states.equals(((Trajectory) obj).getStates());
+  }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/trajectory/TrajectoryConfig.java b/wpimath/src/main/java/edu/wpi/first/math/trajectory/TrajectoryConfig.java
similarity index 68%
rename from wpimath/src/main/java/edu/wpi/first/wpilibj/trajectory/TrajectoryConfig.java
rename to wpimath/src/main/java/edu/wpi/first/math/trajectory/TrajectoryConfig.java
index 6c9b56a..fbf734f 100644
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/trajectory/TrajectoryConfig.java
+++ b/wpimath/src/main/java/edu/wpi/first/math/trajectory/TrajectoryConfig.java
@@ -1,30 +1,26 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
-package edu.wpi.first.wpilibj.trajectory;
+package edu.wpi.first.math.trajectory;
+import edu.wpi.first.math.kinematics.DifferentialDriveKinematics;
+import edu.wpi.first.math.kinematics.MecanumDriveKinematics;
+import edu.wpi.first.math.kinematics.SwerveDriveKinematics;
+import edu.wpi.first.math.trajectory.constraint.DifferentialDriveKinematicsConstraint;
+import edu.wpi.first.math.trajectory.constraint.MecanumDriveKinematicsConstraint;
+import edu.wpi.first.math.trajectory.constraint.SwerveDriveKinematicsConstraint;
+import edu.wpi.first.math.trajectory.constraint.TrajectoryConstraint;
import java.util.ArrayList;
import java.util.List;
-import edu.wpi.first.wpilibj.kinematics.DifferentialDriveKinematics;
-import edu.wpi.first.wpilibj.kinematics.MecanumDriveKinematics;
-import edu.wpi.first.wpilibj.kinematics.SwerveDriveKinematics;
-import edu.wpi.first.wpilibj.trajectory.constraint.DifferentialDriveKinematicsConstraint;
-import edu.wpi.first.wpilibj.trajectory.constraint.MecanumDriveKinematicsConstraint;
-import edu.wpi.first.wpilibj.trajectory.constraint.SwerveDriveKinematicsConstraint;
-import edu.wpi.first.wpilibj.trajectory.constraint.TrajectoryConstraint;
-
/**
* Represents the configuration for generating a trajectory. This class stores the start velocity,
* end velocity, max velocity, max acceleration, custom constraints, and the reversed flag.
*
* <p>The class must be constructed with a max velocity and max acceleration. The other parameters
- * (start velocity, end velocity, constraints, reversed) have been defaulted to reasonable
- * values (0, 0, {}, false). These values can be changed via the setXXX methods.
+ * (start velocity, end velocity, constraints, reversed) have been defaulted to reasonable values
+ * (0, 0, {}, false). These values can be changed via the setXXX methods.
*/
public class TrajectoryConfig {
private final double m_maxVelocity;
@@ -37,11 +33,11 @@
/**
* Constructs the trajectory configuration class.
*
- * @param maxVelocityMetersPerSecond The max velocity for the trajectory.
+ * @param maxVelocityMetersPerSecond The max velocity for the trajectory.
* @param maxAccelerationMetersPerSecondSq The max acceleration for the trajectory.
*/
- public TrajectoryConfig(double maxVelocityMetersPerSecond,
- double maxAccelerationMetersPerSecondSq) {
+ public TrajectoryConfig(
+ double maxVelocityMetersPerSecond, double maxAccelerationMetersPerSecondSq) {
m_maxVelocity = maxVelocityMetersPerSecond;
m_maxAcceleration = maxAccelerationMetersPerSecondSq;
m_constraints = new ArrayList<>();
@@ -60,6 +56,7 @@
/**
* Adds all user-defined constraints from a list to the trajectory.
+ *
* @param constraints List of user-defined constraints.
* @return Instance of the current config object.
*/
@@ -69,8 +66,8 @@
}
/**
- * Adds a differential drive kinematics constraint to ensure that
- * no wheel velocity of a differential drive goes above the max velocity.
+ * Adds a differential drive kinematics constraint to ensure that no wheel velocity of a
+ * differential drive goes above the max velocity.
*
* @param kinematics The differential drive kinematics.
* @return Instance of the current config object.
@@ -81,34 +78,34 @@
}
/**
- * Adds a mecanum drive kinematics constraint to ensure that
- * no wheel velocity of a mecanum drive goes above the max velocity.
- *
- * @param kinematics The mecanum drive kinematics.
- * @return Instance of the current config object.
- */
+ * Adds a mecanum drive kinematics constraint to ensure that no wheel velocity of a mecanum drive
+ * goes above the max velocity.
+ *
+ * @param kinematics The mecanum drive kinematics.
+ * @return Instance of the current config object.
+ */
public TrajectoryConfig setKinematics(MecanumDriveKinematics kinematics) {
addConstraint(new MecanumDriveKinematicsConstraint(kinematics, m_maxVelocity));
return this;
}
/**
- * Adds a swerve drive kinematics constraint to ensure that
- * no wheel velocity of a swerve drive goes above the max velocity.
- *
- * @param kinematics The swerve drive kinematics.
- * @return Instance of the current config object.
- */
+ * Adds a swerve drive kinematics constraint to ensure that no wheel velocity of a swerve drive
+ * goes above the max velocity.
+ *
+ * @param kinematics The swerve drive kinematics.
+ * @return Instance of the current config object.
+ */
public TrajectoryConfig setKinematics(SwerveDriveKinematics kinematics) {
addConstraint(new SwerveDriveKinematicsConstraint(kinematics, m_maxVelocity));
return this;
}
/**
- * Returns the starting velocity of the trajectory.
- *
- * @return The starting velocity of the trajectory.
- */
+ * Returns the starting velocity of the trajectory.
+ *
+ * @return The starting velocity of the trajectory.
+ */
public double getStartVelocity() {
return m_startVelocity;
}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/trajectory/TrajectoryGenerator.java b/wpimath/src/main/java/edu/wpi/first/math/trajectory/TrajectoryGenerator.java
similarity index 63%
rename from wpimath/src/main/java/edu/wpi/first/wpilibj/trajectory/TrajectoryGenerator.java
rename to wpimath/src/main/java/edu/wpi/first/math/trajectory/TrajectoryGenerator.java
index 5e55c50..0827c15 100644
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/trajectory/TrajectoryGenerator.java
+++ b/wpimath/src/main/java/edu/wpi/first/math/trajectory/TrajectoryGenerator.java
@@ -1,39 +1,32 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
-package edu.wpi.first.wpilibj.trajectory;
+package edu.wpi.first.math.trajectory;
+import edu.wpi.first.math.MathSharedStore;
+import edu.wpi.first.math.geometry.Pose2d;
+import edu.wpi.first.math.geometry.Rotation2d;
+import edu.wpi.first.math.geometry.Transform2d;
+import edu.wpi.first.math.geometry.Translation2d;
+import edu.wpi.first.math.spline.PoseWithCurvature;
+import edu.wpi.first.math.spline.Spline;
+import edu.wpi.first.math.spline.SplineHelper;
+import edu.wpi.first.math.spline.SplineParameterizer;
+import edu.wpi.first.math.spline.SplineParameterizer.MalformedSplineException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.function.BiConsumer;
-import edu.wpi.first.math.MathSharedStore;
-import edu.wpi.first.wpilibj.geometry.Pose2d;
-import edu.wpi.first.wpilibj.geometry.Rotation2d;
-import edu.wpi.first.wpilibj.geometry.Transform2d;
-import edu.wpi.first.wpilibj.geometry.Translation2d;
-import edu.wpi.first.wpilibj.spline.PoseWithCurvature;
-import edu.wpi.first.wpilibj.spline.Spline;
-import edu.wpi.first.wpilibj.spline.SplineHelper;
-import edu.wpi.first.wpilibj.spline.SplineParameterizer;
-import edu.wpi.first.wpilibj.spline.SplineParameterizer.MalformedSplineException;
-
public final class TrajectoryGenerator {
private static final Trajectory kDoNothingTrajectory =
new Trajectory(Arrays.asList(new Trajectory.State()));
private static BiConsumer<String, StackTraceElement[]> errorFunc;
- /**
- * Private constructor because this is a utility class.
- */
- private TrajectoryGenerator() {
- }
+ /** Private constructor because this is a utility class. */
+ private TrajectoryGenerator() {}
private static void reportError(String error, StackTraceElement[] stackTrace) {
if (errorFunc != null) {
@@ -54,22 +47,21 @@
/**
* Generates a trajectory from the given control vectors and config. This method uses clamped
- * cubic splines -- a method in which the exterior control vectors and interior waypoints
- * are provided. The headings are automatically determined at the interior points to
- * ensure continuous curvature.
+ * cubic splines -- a method in which the exterior control vectors and interior waypoints are
+ * provided. The headings are automatically determined at the interior points to ensure continuous
+ * curvature.
*
- * @param initial The initial control vector.
+ * @param initial The initial control vector.
* @param interiorWaypoints The interior waypoints.
- * @param end The ending control vector.
- * @param config The configuration for the trajectory.
+ * @param end The ending control vector.
+ * @param config The configuration for the trajectory.
* @return The generated trajectory.
*/
public static Trajectory generateTrajectory(
Spline.ControlVector initial,
List<Translation2d> interiorWaypoints,
Spline.ControlVector end,
- TrajectoryConfig config
- ) {
+ TrajectoryConfig config) {
final var flip = new Transform2d(new Translation2d(), Rotation2d.fromDegrees(180.0));
// Clone the control vectors.
@@ -87,8 +79,10 @@
// Get the spline points
List<PoseWithCurvature> points;
try {
- points = splinePointsFromSplines(SplineHelper.getCubicSplinesFromControlVectors(newInitial,
- interiorWaypoints.toArray(new Translation2d[0]), newEnd));
+ points =
+ splinePointsFromSplines(
+ SplineHelper.getCubicSplinesFromControlVectors(
+ newInitial, interiorWaypoints.toArray(new Translation2d[0]), newEnd));
} catch (MalformedSplineException ex) {
reportError(ex.getMessage(), ex.getStackTrace());
return kDoNothingTrajectory;
@@ -103,49 +97,49 @@
}
// Generate and return trajectory.
- return TrajectoryParameterizer.timeParameterizeTrajectory(points, config.getConstraints(),
- config.getStartVelocity(), config.getEndVelocity(), config.getMaxVelocity(),
- config.getMaxAcceleration(), config.isReversed());
+ return TrajectoryParameterizer.timeParameterizeTrajectory(
+ points,
+ config.getConstraints(),
+ config.getStartVelocity(),
+ config.getEndVelocity(),
+ config.getMaxVelocity(),
+ config.getMaxAcceleration(),
+ config.isReversed());
}
/**
- * Generates a trajectory from the given waypoints and config. This method uses clamped
- * cubic splines -- a method in which the initial pose, final pose, and interior waypoints
- * are provided. The headings are automatically determined at the interior points to
- * ensure continuous curvature.
+ * Generates a trajectory from the given waypoints and config. This method uses clamped cubic
+ * splines -- a method in which the initial pose, final pose, and interior waypoints are provided.
+ * The headings are automatically determined at the interior points to ensure continuous
+ * curvature.
*
- * @param start The starting pose.
+ * @param start The starting pose.
* @param interiorWaypoints The interior waypoints.
- * @param end The ending pose.
- * @param config The configuration for the trajectory.
+ * @param end The ending pose.
+ * @param config The configuration for the trajectory.
* @return The generated trajectory.
*/
public static Trajectory generateTrajectory(
- Pose2d start, List<Translation2d> interiorWaypoints, Pose2d end,
- TrajectoryConfig config
- ) {
- var controlVectors = SplineHelper.getCubicControlVectorsFromWaypoints(
- start, interiorWaypoints.toArray(new Translation2d[0]), end
- );
+ Pose2d start, List<Translation2d> interiorWaypoints, Pose2d end, TrajectoryConfig config) {
+ var controlVectors =
+ SplineHelper.getCubicControlVectorsFromWaypoints(
+ start, interiorWaypoints.toArray(new Translation2d[0]), end);
// Return the generated trajectory.
return generateTrajectory(controlVectors[0], interiorWaypoints, controlVectors[1], config);
}
/**
- * Generates a trajectory from the given quintic control vectors and config. This method
- * uses quintic hermite splines -- therefore, all points must be represented by control
- * vectors. Continuous curvature is guaranteed in this method.
+ * Generates a trajectory from the given quintic control vectors and config. This method uses
+ * quintic hermite splines -- therefore, all points must be represented by control vectors.
+ * Continuous curvature is guaranteed in this method.
*
* @param controlVectors List of quintic control vectors.
- * @param config The configuration for the trajectory.
+ * @param config The configuration for the trajectory.
* @return The generated trajectory.
*/
- @SuppressWarnings("PMD.AvoidInstantiatingObjectsInLoops")
public static Trajectory generateTrajectory(
- ControlVectorList controlVectors,
- TrajectoryConfig config
- ) {
+ ControlVectorList controlVectors, TrajectoryConfig config) {
final var flip = new Transform2d(new Translation2d(), Rotation2d.fromDegrees(180.0));
final var newControlVectors = new ArrayList<Spline.ControlVector>(controlVectors.size());
@@ -162,9 +156,10 @@
// Get the spline points
List<PoseWithCurvature> points;
try {
- points = splinePointsFromSplines(SplineHelper.getQuinticSplinesFromControlVectors(
- newControlVectors.toArray(new Spline.ControlVector[]{})
- ));
+ points =
+ splinePointsFromSplines(
+ SplineHelper.getQuinticSplinesFromControlVectors(
+ newControlVectors.toArray(new Spline.ControlVector[] {})));
} catch (MalformedSplineException ex) {
reportError(ex.getMessage(), ex.getStackTrace());
return kDoNothingTrajectory;
@@ -179,19 +174,23 @@
}
// Generate and return trajectory.
- return TrajectoryParameterizer.timeParameterizeTrajectory(points, config.getConstraints(),
- config.getStartVelocity(), config.getEndVelocity(), config.getMaxVelocity(),
- config.getMaxAcceleration(), config.isReversed());
-
+ return TrajectoryParameterizer.timeParameterizeTrajectory(
+ points,
+ config.getConstraints(),
+ config.getStartVelocity(),
+ config.getEndVelocity(),
+ config.getMaxVelocity(),
+ config.getMaxAcceleration(),
+ config.isReversed());
}
/**
- * Generates a trajectory from the given waypoints and config. This method
- * uses quintic hermite splines -- therefore, all points must be represented by Pose2d
- * objects. Continuous curvature is guaranteed in this method.
+ * Generates a trajectory from the given waypoints and config. This method uses quintic hermite
+ * splines -- therefore, all points must be represented by Pose2d objects. Continuous curvature is
+ * guaranteed in this method.
*
* @param waypoints List of waypoints..
- * @param config The configuration for the trajectory.
+ * @param config The configuration for the trajectory.
* @return The generated trajectory.
*/
@SuppressWarnings("LocalVariableName")
@@ -225,22 +224,25 @@
}
// Generate and return trajectory.
- return TrajectoryParameterizer.timeParameterizeTrajectory(points, config.getConstraints(),
- config.getStartVelocity(), config.getEndVelocity(), config.getMaxVelocity(),
- config.getMaxAcceleration(), config.isReversed());
+ return TrajectoryParameterizer.timeParameterizeTrajectory(
+ points,
+ config.getConstraints(),
+ config.getStartVelocity(),
+ config.getEndVelocity(),
+ config.getMaxVelocity(),
+ config.getMaxAcceleration(),
+ config.isReversed());
}
/**
- * Generate spline points from a vector of splines by parameterizing the
- * splines.
+ * Generate spline points from a vector of splines by parameterizing the splines.
*
* @param splines The splines to parameterize.
* @return The spline points for use in time parameterization of a trajectory.
* @throws MalformedSplineException When the spline is malformed (e.g. has close adjacent points
- * with approximately opposing headings)
+ * with approximately opposing headings)
*/
- public static List<PoseWithCurvature> splinePointsFromSplines(
- Spline[] splines) {
+ public static List<PoseWithCurvature> splinePointsFromSplines(Spline[] splines) {
// Create the vector of spline points.
var splinePoints = new ArrayList<PoseWithCurvature>();
diff --git a/wpimath/src/main/java/edu/wpi/first/math/trajectory/TrajectoryParameterizer.java b/wpimath/src/main/java/edu/wpi/first/math/trajectory/TrajectoryParameterizer.java
new file mode 100644
index 0000000..d7fbf59
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/trajectory/TrajectoryParameterizer.java
@@ -0,0 +1,329 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+/*
+ * MIT License
+ *
+ * Copyright (c) 2018 Team 254
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+package edu.wpi.first.math.trajectory;
+
+import edu.wpi.first.math.spline.PoseWithCurvature;
+import edu.wpi.first.math.trajectory.constraint.TrajectoryConstraint;
+import java.util.ArrayList;
+import java.util.List;
+
+/** Class used to parameterize a trajectory by time. */
+public final class TrajectoryParameterizer {
+  /** Private constructor because this is a utility class. */
+  private TrajectoryParameterizer() {}
+
+  /**
+   * Parameterize the trajectory by time. This is where the velocity profile is generated.
+   *
+   * <p>The derivation of the algorithm used can be found <a
+   * href="http://www2.informatik.uni-freiburg.de/~lau/students/Sprunk2008.pdf">here</a>.
+   *
+   * @param points The spline points to parameterize; must not be empty.
+   * @param constraints A list of velocity and acceleration constraints to enforce.
+   * @param startVelocityMetersPerSecond The start velocity for the trajectory.
+   * @param endVelocityMetersPerSecond The end velocity for the trajectory.
+   * @param maxVelocityMetersPerSecond The max velocity for the trajectory.
+   * @param maxAccelerationMetersPerSecondSq The max acceleration for the trajectory.
+   * @param reversed Whether the robot should move backwards. Note that the robot will still move
+   *     from a -> b -> ... -> z as defined in the waypoints.
+   * @return The time-parameterized trajectory.
+   */
+  public static Trajectory timeParameterizeTrajectory(
+      List<PoseWithCurvature> points,
+      List<TrajectoryConstraint> constraints,
+      double startVelocityMetersPerSecond,
+      double endVelocityMetersPerSecond,
+      double maxVelocityMetersPerSecond,
+      double maxAccelerationMetersPerSecondSq,
+      boolean reversed) {
+    var constrainedStates = new ArrayList<ConstrainedState>(points.size());
+    var predecessor =
+        new ConstrainedState(
+            points.get(0),
+            0,
+            startVelocityMetersPerSecond,
+            -maxAccelerationMetersPerSecondSq,
+            maxAccelerationMetersPerSecondSq);
+
+    // Forward pass: limit each state's velocity to what is reachable from its predecessor.
+    for (int i = 0; i < points.size(); i++) {
+      constrainedStates.add(new ConstrainedState());
+      var constrainedState = constrainedStates.get(i);
+      constrainedState.pose = points.get(i);
+
+      // Begin constraining based on predecessor.
+      double ds =
+          constrainedState
+              .pose
+              .poseMeters
+              .getTranslation()
+              .getDistance(predecessor.pose.poseMeters.getTranslation());
+      constrainedState.distanceMeters = predecessor.distanceMeters + ds;
+
+      // We may need to iterate to find the maximum end velocity and common
+      // acceleration, since acceleration limits may be a function of velocity.
+      while (true) {
+        // Enforce global max velocity and max reachable velocity by global
+        // acceleration limit. vf = sqrt(vi^2 + 2*a*d).
+        constrainedState.maxVelocityMetersPerSecond =
+            Math.min(
+                maxVelocityMetersPerSecond,
+                Math.sqrt(
+                    predecessor.maxVelocityMetersPerSecond * predecessor.maxVelocityMetersPerSecond
+                        + predecessor.maxAccelerationMetersPerSecondSq * ds * 2.0));
+
+        constrainedState.minAccelerationMetersPerSecondSq = -maxAccelerationMetersPerSecondSq;
+        constrainedState.maxAccelerationMetersPerSecondSq = maxAccelerationMetersPerSecondSq;
+
+        // At this point, the constrained state is fully constructed apart from
+        // all the custom-defined user constraints.
+        for (final var constraint : constraints) {
+          constrainedState.maxVelocityMetersPerSecond =
+              Math.min(
+                  constrainedState.maxVelocityMetersPerSecond,
+                  constraint.getMaxVelocityMetersPerSecond(
+                      constrainedState.pose.poseMeters,
+                      constrainedState.pose.curvatureRadPerMeter,
+                      constrainedState.maxVelocityMetersPerSecond));
+        }
+
+        // Now enforce all acceleration limits.
+        enforceAccelerationLimits(reversed, constraints, constrainedState);
+        // (Nearly) coincident points: skip the check below, since it divides by ds.
+        if (ds < 1E-6) {
+          break;
+        }
+
+        // If the actual acceleration for this state is higher than the max
+        // acceleration that we applied, then we need to reduce the max
+        // acceleration of the predecessor and try again.
+        double actualAcceleration =
+            (constrainedState.maxVelocityMetersPerSecond
+                        * constrainedState.maxVelocityMetersPerSecond
+                    - predecessor.maxVelocityMetersPerSecond
+                        * predecessor.maxVelocityMetersPerSecond)
+                / (ds * 2.0);
+
+        // If we violate the max acceleration constraint, let's modify the
+        // predecessor.
+        if (constrainedState.maxAccelerationMetersPerSecondSq < actualAcceleration - 1E-6) {
+          predecessor.maxAccelerationMetersPerSecondSq =
+              constrainedState.maxAccelerationMetersPerSecondSq;
+        } else {
+          // Constrain the predecessor's max acceleration to the current
+          // acceleration.
+          if (actualAcceleration > predecessor.minAccelerationMetersPerSecondSq) {
+            predecessor.maxAccelerationMetersPerSecondSq = actualAcceleration;
+          }
+          // If the actual acceleration is less than the predecessor's min
+          // acceleration, it will be repaired in the backward pass.
+          break;
+        }
+      }
+      predecessor = constrainedState;
+    }
+
+    var successor =
+        new ConstrainedState(
+            points.get(points.size() - 1),
+            constrainedStates.get(constrainedStates.size() - 1).distanceMeters,
+            endVelocityMetersPerSecond,
+            -maxAccelerationMetersPerSecondSq,
+            maxAccelerationMetersPerSecondSq);
+
+    // Backward pass: lower velocities so each state can decelerate into its successor.
+    for (int i = points.size() - 1; i >= 0; i--) {
+      var constrainedState = constrainedStates.get(i);
+      double ds = constrainedState.distanceMeters - successor.distanceMeters; // non-positive
+
+      while (true) {
+        // Enforce max velocity limit (reverse)
+        // vf = sqrt(vi^2 + 2*a*d), where vi = successor velocity, a = successor min accel.
+        double newMaxVelocity =
+            Math.sqrt(
+                successor.maxVelocityMetersPerSecond * successor.maxVelocityMetersPerSecond
+                    + successor.minAccelerationMetersPerSecondSq * ds * 2.0);
+
+        // No more limits to impose! This state can be finalized.
+        if (newMaxVelocity >= constrainedState.maxVelocityMetersPerSecond) {
+          break;
+        }
+
+        constrainedState.maxVelocityMetersPerSecond = newMaxVelocity;
+
+        // Check all acceleration constraints with the new max velocity.
+        enforceAccelerationLimits(reversed, constraints, constrainedState);
+        // (Nearly) coincident points: skip the check below, since it divides by ds.
+        if (ds > -1E-6) {
+          break;
+        }
+
+        // If the actual acceleration for this state is lower than the min
+        // acceleration, then we need to lower the min acceleration of the
+        // successor and try again.
+        double actualAcceleration =
+            (constrainedState.maxVelocityMetersPerSecond
+                        * constrainedState.maxVelocityMetersPerSecond
+                    - successor.maxVelocityMetersPerSecond * successor.maxVelocityMetersPerSecond)
+                / (ds * 2.0);
+
+        if (constrainedState.minAccelerationMetersPerSecondSq > actualAcceleration + 1E-6) {
+          successor.minAccelerationMetersPerSecondSq =
+              constrainedState.minAccelerationMetersPerSecondSq;
+        } else {
+          successor.minAccelerationMetersPerSecondSq = actualAcceleration;
+          break;
+        }
+      }
+      successor = constrainedState;
+    }
+
+    // Now we can integrate the constrained states forward in time to obtain our
+    // trajectory states.
+    var states = new ArrayList<Trajectory.State>(points.size());
+    double timeSeconds = 0.0;
+    double distanceMeters = 0.0;
+    double velocityMetersPerSecond = 0.0;
+
+    for (int i = 0; i < constrainedStates.size(); i++) {
+      final var state = constrainedStates.get(i);
+
+      // Calculate the change in position between the current state and the previous
+      // state.
+      double ds = state.distanceMeters - distanceMeters;
+
+      // Acceleration from the previous state to this one. At i == 0, ds is 0, so this divides by
+      // zero; the value stored for state 0 is overwritten on the next iteration, if any.
+      double accel =
+          (state.maxVelocityMetersPerSecond * state.maxVelocityMetersPerSecond
+                  - velocityMetersPerSecond * velocityMetersPerSecond)
+              / (ds * 2);
+
+      // Calculate dt; the first state stays at t = 0.
+      double dt = 0.0;
+      if (i > 0) {
+        states.get(i - 1).accelerationMetersPerSecondSq = reversed ? -accel : accel;
+        if (Math.abs(accel) > 1E-6) {
+          // v_f = v_0 + a * t
+          dt = (state.maxVelocityMetersPerSecond - velocityMetersPerSecond) / accel;
+        } else if (Math.abs(velocityMetersPerSecond) > 1E-6) {
+          // delta_x = v * t
+          dt = ds / velocityMetersPerSecond;
+        } else {
+          throw new TrajectoryGenerationException(
+              "Something went wrong at iteration " + i + " of time parameterization.");
+        }
+      }
+
+      velocityMetersPerSecond = state.maxVelocityMetersPerSecond;
+      distanceMeters = state.distanceMeters;
+
+      timeSeconds += dt;
+
+      states.add(
+          new Trajectory.State(
+              timeSeconds,
+              reversed ? -velocityMetersPerSecond : velocityMetersPerSecond,
+              reversed ? -accel : accel,
+              state.pose.poseMeters,
+              state.pose.curvatureRadPerMeter));
+    }
+
+    return new Trajectory(states);
+  }
+
+  private static void enforceAccelerationLimits(
+      boolean reverse, List<TrajectoryConstraint> constraints, ConstrainedState state) {
+    for (final var constraint : constraints) {
+      double factor = reverse ? -1.0 : 1.0; // constraints receive the signed velocity
+      final var minMaxAccel =
+          constraint.getMinMaxAccelerationMetersPerSecondSq(
+              state.pose.poseMeters,
+              state.pose.curvatureRadPerMeter,
+              state.maxVelocityMetersPerSecond * factor);
+      // Reject constraints that report an empty (min > max) acceleration range.
+      if (minMaxAccel.minAccelerationMetersPerSecondSq
+          > minMaxAccel.maxAccelerationMetersPerSecondSq) {
+        throw new TrajectoryGenerationException(
+            "The constraint's min acceleration "
+                + "was greater than its max acceleration.\n Offending Constraint: "
+                + constraint.getClass().getName()
+                + "\n If the offending constraint was packaged with WPILib, please file a bug"
+                + " report.");
+      }
+      // Intersect with the state's bounds; when reversed, [min, max] maps to [-max, -min].
+      state.minAccelerationMetersPerSecondSq =
+          Math.max(
+              state.minAccelerationMetersPerSecondSq,
+              reverse
+                  ? -minMaxAccel.maxAccelerationMetersPerSecondSq
+                  : minMaxAccel.minAccelerationMetersPerSecondSq);
+
+      state.maxAccelerationMetersPerSecondSq =
+          Math.min(
+              state.maxAccelerationMetersPerSecondSq,
+              reverse
+                  ? -minMaxAccel.minAccelerationMetersPerSecondSq
+                  : minMaxAccel.maxAccelerationMetersPerSecondSq);
+    }
+  }
+
+  @SuppressWarnings("MemberName")
+  private static class ConstrainedState {
+    PoseWithCurvature pose;
+    double distanceMeters;
+    double maxVelocityMetersPerSecond;
+    double minAccelerationMetersPerSecondSq;
+    double maxAccelerationMetersPerSecondSq;
+
+    ConstrainedState(
+        PoseWithCurvature pose,
+        double distanceMeters,
+        double maxVelocityMetersPerSecond,
+        double minAccelerationMetersPerSecondSq,
+        double maxAccelerationMetersPerSecondSq) {
+      this.pose = pose;
+      this.distanceMeters = distanceMeters;
+      this.maxVelocityMetersPerSecond = maxVelocityMetersPerSecond;
+      this.minAccelerationMetersPerSecondSq = minAccelerationMetersPerSecondSq;
+      this.maxAccelerationMetersPerSecondSq = maxAccelerationMetersPerSecondSq;
+    }
+
+    ConstrainedState() {
+      pose = new PoseWithCurvature();
+    }
+  }
+
+  @SuppressWarnings("serial")
+  public static class TrajectoryGenerationException extends RuntimeException {
+    public TrajectoryGenerationException(String message) {
+      super(message);
+    }
+  }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/math/trajectory/TrajectoryUtil.java b/wpimath/src/main/java/edu/wpi/first/math/trajectory/TrajectoryUtil.java
new file mode 100644
index 0000000..5fb2c34
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/trajectory/TrajectoryUtil.java
@@ -0,0 +1,119 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.trajectory;
+
+import edu.wpi.first.math.WPIMathJNI;
+import edu.wpi.first.math.geometry.Pose2d;
+import edu.wpi.first.math.geometry.Rotation2d;
+import java.io.IOException;
+import java.nio.file.Path;
+import java.util.ArrayList;
+import java.util.List;
+
+public final class TrajectoryUtil {
+  private TrajectoryUtil() {
+    throw new UnsupportedOperationException("This is a utility class!");
+  }
+
+  /**
+   * Creates a trajectory from a double[] of elements.
+   *
+   * @param elements A double[] containing the raw elements of the trajectory.
+   * @return A trajectory created from the elements.
+   */
+  private static Trajectory createTrajectoryFromElements(double[] elements) {
+    // Each state is serialized as 7 doubles, so the length must be a multiple of 7.
+    if (elements.length % 7 != 0) {
+      throw new TrajectorySerializationException(
+          "An error occurred when converting trajectory elements into a trajectory.");
+    }
+
+    // Per-state layout: time, velocity, acceleration, x, y, heading (rad), curvature.
+    List<Trajectory.State> states = new ArrayList<>();
+    for (int i = 0; i < elements.length; i += 7) {
+      states.add(
+          new Trajectory.State(
+              elements[i],
+              elements[i + 1],
+              elements[i + 2],
+              new Pose2d(elements[i + 3], elements[i + 4], new Rotation2d(elements[i + 5])),
+              elements[i + 6]));
+    }
+    return new Trajectory(states);
+  }
+
+  /**
+   * Returns a double[] of elements from the given trajectory.
+   *
+   * @param trajectory The trajectory to retrieve raw elements from.
+   * @return A double[] of elements from the given trajectory.
+   */
+  private static double[] getElementsFromTrajectory(Trajectory trajectory) {
+    // Flatten each state into 7 doubles, in the layout read by createTrajectoryFromElements.
+    double[] elements = new double[trajectory.getStates().size() * 7];
+
+    for (int i = 0; i < trajectory.getStates().size() * 7; i += 7) {
+      var state = trajectory.getStates().get(i / 7);
+      elements[i] = state.timeSeconds;
+      elements[i + 1] = state.velocityMetersPerSecond;
+      elements[i + 2] = state.accelerationMetersPerSecondSq;
+      elements[i + 3] = state.poseMeters.getX();
+      elements[i + 4] = state.poseMeters.getY();
+      elements[i + 5] = state.poseMeters.getRotation().getRadians();
+      elements[i + 6] = state.curvatureRadPerMeter;
+    }
+    return elements;
+  }
+
+  /**
+   * Imports a Trajectory from a PathWeaver-style JSON file.
+   *
+   * @param path The path of the JSON file to import from.
+   * @return The trajectory represented by the file.
+   * @throws IOException if reading from the file fails.
+   */
+  public static Trajectory fromPathweaverJson(Path path) throws IOException {
+    return createTrajectoryFromElements(WPIMathJNI.fromPathweaverJson(path.toString()));
+  }
+
+  /**
+   * Exports a Trajectory to a PathWeaver-style JSON file.
+   *
+   * @param trajectory The trajectory to export.
+   * @param path The path of the file to export to.
+   * @throws IOException if writing to the file fails.
+   */
+  public static void toPathweaverJson(Trajectory trajectory, Path path) throws IOException {
+    WPIMathJNI.toPathweaverJson(getElementsFromTrajectory(trajectory), path.toString());
+  }
+
+  /**
+   * Deserializes a Trajectory from PathWeaver-style JSON.
+   *
+   * @param json The string containing the serialized JSON.
+   * @return The trajectory represented by the JSON.
+   * @throws TrajectorySerializationException if deserialization of the string fails.
+   */
+  public static Trajectory deserializeTrajectory(String json) {
+    return createTrajectoryFromElements(WPIMathJNI.deserializeTrajectory(json));
+  }
+
+  /**
+   * Serializes a Trajectory to PathWeaver-style JSON.
+   *
+   * @param trajectory The trajectory to export.
+   * @return The string containing the serialized JSON.
+   * @throws TrajectorySerializationException if serialization of the trajectory fails.
+   */
+  public static String serializeTrajectory(Trajectory trajectory) {
+    return WPIMathJNI.serializeTrajectory(getElementsFromTrajectory(trajectory));
+  }
+
+  public static class TrajectorySerializationException extends RuntimeException {
+    public TrajectorySerializationException(String message) {
+      super(message);
+    }
+  }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/trajectory/TrapezoidProfile.java b/wpimath/src/main/java/edu/wpi/first/math/trajectory/TrapezoidProfile.java
similarity index 75%
rename from wpimath/src/main/java/edu/wpi/first/wpilibj/trajectory/TrapezoidProfile.java
rename to wpimath/src/main/java/edu/wpi/first/math/trajectory/TrapezoidProfile.java
index 8289d4d..35745b5 100644
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/trajectory/TrapezoidProfile.java
+++ b/wpimath/src/main/java/edu/wpi/first/math/trajectory/TrapezoidProfile.java
@@ -1,26 +1,22 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
-package edu.wpi.first.wpilibj.trajectory;
-
-import java.util.Objects;
+package edu.wpi.first.math.trajectory;
import edu.wpi.first.math.MathSharedStore;
import edu.wpi.first.math.MathUsageId;
+import java.util.Objects;
/**
* A trapezoid-shaped velocity profile.
*
- * <p>While this class can be used for a profiled movement from start to finish,
- * the intended usage is to filter a reference's dynamics based on trapezoidal
- * velocity constraints. To compute the reference obeying this constraint, do
- * the following.
+ * <p>While this class can be used for a profiled movement from start to finish, the intended usage
+ * is to filter a reference's dynamics based on trapezoidal velocity constraints. To compute the
+ * reference obeying this constraint, do the following.
*
* <p>Initialization:
+ *
* <pre><code>
* TrapezoidProfile.Constraints constraints =
* new TrapezoidProfile.Constraints(kMaxV, kMaxA);
@@ -29,19 +25,18 @@
* </code></pre>
*
* <p>Run on update:
+ *
* <pre><code>
* TrapezoidProfile profile =
* new TrapezoidProfile(constraints, unprofiledReference, previousProfiledReference);
* previousProfiledReference = profile.calculate(timeSincePreviousUpdate);
* </code></pre>
*
- * <p>where `unprofiledReference` is free to change between calls. Note that when
- * the unprofiled reference is within the constraints, `calculate()` returns the
- * unprofiled reference unchanged.
+ * <p>where `unprofiledReference` is free to change between calls. Note that when the unprofiled
+ * reference is within the constraints, `calculate()` returns the unprofiled reference unchanged.
*
- * <p>Otherwise, a timer can be started to provide monotonic values for
- * `calculate()` and to determine when the profile has completed via
- * `isFinished()`.
+ * <p>Otherwise, a timer can be started to provide monotonic values for `calculate()` and to
+ * determine when the profile has completed via `isFinished()`.
*/
public class TrapezoidProfile {
// The direction of the profile, either 1 for forwards or -1 for inverted
@@ -57,13 +52,10 @@
public static class Constraints {
@SuppressWarnings("MemberName")
- public double maxVelocity;
- @SuppressWarnings("MemberName")
- public double maxAcceleration;
+ public final double maxVelocity;
- public Constraints() {
- MathSharedStore.reportUsage(MathUsageId.kTrajectory_TrapezoidProfile, 1);
- }
+ @SuppressWarnings("MemberName")
+ public final double maxAcceleration;
/**
* Construct constraints for a TrapezoidProfile.
@@ -81,11 +73,11 @@
public static class State {
@SuppressWarnings("MemberName")
public double position;
+
@SuppressWarnings("MemberName")
public double velocity;
- public State() {
- }
+ public State() {}
public State(double position, double velocity) {
this.position = position;
@@ -112,8 +104,8 @@
* Construct a TrapezoidProfile.
*
* @param constraints The constraints on the profile, like maximum velocity.
- * @param goal The desired state when the profile is complete.
- * @param initial The initial state (usually the current state).
+ * @param goal The desired state when the profile is complete.
+ * @param initial The initial state (usually the current state).
*/
public TrapezoidProfile(Constraints constraints, State goal, State initial) {
m_direction = shouldFlipAcceleration(initial, goal) ? -1 : 1;
@@ -137,12 +129,12 @@
// Now we can calculate the parameters as if it was a full trapezoid instead
// of a truncated one
- double fullTrapezoidDist = cutoffDistBegin + (m_goal.position - m_initial.position)
- + cutoffDistEnd;
+ double fullTrapezoidDist =
+ cutoffDistBegin + (m_goal.position - m_initial.position) + cutoffDistEnd;
double accelerationTime = m_constraints.maxVelocity / m_constraints.maxAcceleration;
- double fullSpeedDist = fullTrapezoidDist - accelerationTime * accelerationTime
- * m_constraints.maxAcceleration;
+ double fullSpeedDist =
+ fullTrapezoidDist - accelerationTime * accelerationTime * m_constraints.maxAcceleration;
// Handle the case where the profile never reaches full speed
if (fullSpeedDist < 0) {
@@ -159,19 +151,19 @@
* Construct a TrapezoidProfile.
*
* @param constraints The constraints on the profile, like maximum velocity.
- * @param goal The desired state when the profile is complete.
+ * @param goal The desired state when the profile is complete.
*/
public TrapezoidProfile(Constraints constraints, State goal) {
this(constraints, goal, new State(0, 0));
}
/**
- * Calculate the correct position and velocity for the profile at a time t
- * where the beginning of the profile was at time t = 0.
+ * Calculate the correct position and velocity for the profile at a time t where the beginning of
+ * the profile was at time t = 0.
*
* @param t The time since the beginning of the profile.
+ * @return The position and velocity of the profile at time t.
*/
- @SuppressWarnings("ParameterName")
public State calculate(double t) {
State result = new State(m_initial.position, m_initial.velocity);
@@ -180,13 +172,15 @@
result.position += (m_initial.velocity + t * m_constraints.maxAcceleration / 2.0) * t;
} else if (t < m_endFullSpeed) {
result.velocity = m_constraints.maxVelocity;
- result.position += (m_initial.velocity + m_endAccel * m_constraints.maxAcceleration
- / 2.0) * m_endAccel + m_constraints.maxVelocity * (t - m_endAccel);
+ result.position +=
+ (m_initial.velocity + m_endAccel * m_constraints.maxAcceleration / 2.0) * m_endAccel
+ + m_constraints.maxVelocity * (t - m_endAccel);
} else if (t <= m_endDeccel) {
result.velocity = m_goal.velocity + (m_endDeccel - t) * m_constraints.maxAcceleration;
double timeLeft = m_endDeccel - t;
- result.position = m_goal.position - (m_goal.velocity + timeLeft
- * m_constraints.maxAcceleration / 2.0) * timeLeft;
+ result.position =
+ m_goal.position
+ - (m_goal.velocity + timeLeft * m_constraints.maxAcceleration / 2.0) * timeLeft;
} else {
result = m_goal;
}
@@ -198,6 +192,7 @@
* Returns the time left until a target distance in the profile is reached.
*
* @param target The target distance.
+ * @return The time left until a target distance in the profile is reached.
*/
public double timeLeftUntil(double target) {
double position = m_initial.position * m_direction;
@@ -251,11 +246,15 @@
deccelDist = distToTarget - fullSpeedDist - accelDist;
}
- double accelTime = (-velocity + Math.sqrt(Math.abs(velocity * velocity + 2 * acceleration
- * accelDist))) / acceleration;
+ double accelTime =
+ (-velocity + Math.sqrt(Math.abs(velocity * velocity + 2 * acceleration * accelDist)))
+ / acceleration;
- double deccelTime = (-deccelVelocity + Math.sqrt(Math.abs(deccelVelocity * deccelVelocity
- + 2 * decceleration * deccelDist))) / decceleration;
+ double deccelTime =
+ (-deccelVelocity
+ + Math.sqrt(
+ Math.abs(deccelVelocity * deccelVelocity + 2 * decceleration * deccelDist)))
+ / decceleration;
double fullSpeedTime = fullSpeedDist / m_constraints.maxVelocity;
@@ -264,6 +263,8 @@
/**
* Returns the total time the profile takes to reach the goal.
+ *
+ * @return The total time the profile takes to reach the goal.
*/
public double totalTime() {
return m_endDeccel;
@@ -272,12 +273,12 @@
/**
* Returns true if the profile has reached the goal.
*
- * <p>The profile has reached the goal if the time since the profile started
- * has exceeded the profile's total time.
+ * <p>The profile has reached the goal if the time since the profile started has exceeded the
+ * profile's total time.
*
* @param t The time since the beginning of the profile.
+ * @return True if the profile has reached the goal.
*/
- @SuppressWarnings("ParameterName")
public boolean isFinished(double t) {
return t >= totalTime();
}
@@ -287,8 +288,8 @@
*
* <p>The profile is inverted if goal position is less than the initial position.
*
- * @param initial The initial state (usually the current state).
- * @param goal The desired state when the profile is complete.
+ * @param initial The initial state (usually the current state).
+ * @param goal The desired state when the profile is complete.
*/
private static boolean shouldFlipAcceleration(State initial, State goal) {
return initial.position > goal.position;
diff --git a/wpimath/src/main/java/edu/wpi/first/math/trajectory/constraint/CentripetalAccelerationConstraint.java b/wpimath/src/main/java/edu/wpi/first/math/trajectory/constraint/CentripetalAccelerationConstraint.java
new file mode 100644
index 0000000..13138e8
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/trajectory/constraint/CentripetalAccelerationConstraint.java
@@ -0,0 +1,68 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.trajectory.constraint;
+
+import edu.wpi.first.math.geometry.Pose2d;
+
+/**
+ * A constraint on the maximum absolute centripetal acceleration allowed when traversing a
+ * trajectory. The centripetal acceleration of a robot is defined as the velocity squared divided by
+ * the radius of curvature.
+ *
+ * <p>Effectively, limiting the maximum centripetal acceleration will cause the robot to slow down
+ * around tight turns, making it easier to track trajectories with sharp turns.
+ */
+public class CentripetalAccelerationConstraint implements TrajectoryConstraint {
+  private final double m_maxCentripetalAccelerationMetersPerSecondSq;
+
+  /**
+   * Constructs a centripetal acceleration constraint.
+   *
+   * @param maxCentripetalAccelerationMetersPerSecondSq The max centripetal acceleration.
+   */
+  public CentripetalAccelerationConstraint(double maxCentripetalAccelerationMetersPerSecondSq) {
+    m_maxCentripetalAccelerationMetersPerSecondSq = maxCentripetalAccelerationMetersPerSecondSq;
+  }
+
+  /**
+   * Returns the max velocity given the current pose and curvature.
+   *
+   * @param poseMeters The pose at the current point in the trajectory.
+   * @param curvatureRadPerMeter The curvature at the current point in the trajectory.
+   * @param velocityMetersPerSecond The velocity at the current point in the trajectory before
+   *     constraints are applied.
+   * @return The absolute maximum velocity.
+   */
+  @Override
+  public double getMaxVelocityMetersPerSecond(
+      Pose2d poseMeters, double curvatureRadPerMeter, double velocityMetersPerSecond) {
+    // ac = v^2 / r
+    // k (curvature) = 1 / r
+
+    // therefore, ac = v^2 * k
+    // ac / k = v^2
+    // v = sqrt(ac / |k|), using |k| so the sign of the curvature does not matter.
+    // For a positive limit and zero curvature this is +Infinity, i.e. no velocity limit.
+    return Math.sqrt(
+        m_maxCentripetalAccelerationMetersPerSecondSq / Math.abs(curvatureRadPerMeter));
+  }
+
+  /**
+   * Returns the minimum and maximum allowable acceleration for the trajectory given pose,
+   * curvature, and speed.
+   *
+   * @param poseMeters The pose at the current point in the trajectory.
+   * @param curvatureRadPerMeter The curvature at the current point in the trajectory.
+   * @param velocityMetersPerSecond The speed at the current point in the trajectory.
+   * @return The min and max acceleration bounds.
+   */
+  @Override
+  public MinMax getMinMaxAccelerationMetersPerSecondSq(
+      Pose2d poseMeters, double curvatureRadPerMeter, double velocityMetersPerSecond) {
+    // The acceleration of the robot has no impact on the centripetal acceleration
+    // of the robot.
+    return new MinMax();
+  }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/math/trajectory/constraint/DifferentialDriveKinematicsConstraint.java b/wpimath/src/main/java/edu/wpi/first/math/trajectory/constraint/DifferentialDriveKinematicsConstraint.java
new file mode 100644
index 0000000..37d8d68
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/trajectory/constraint/DifferentialDriveKinematicsConstraint.java
@@ -0,0 +1,71 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.trajectory.constraint;
+
+import edu.wpi.first.math.geometry.Pose2d;
+import edu.wpi.first.math.kinematics.ChassisSpeeds;
+import edu.wpi.first.math.kinematics.DifferentialDriveKinematics;
+
+/**
+ * A class that enforces constraints on the differential drive kinematics. This can be used to
+ * ensure that the trajectory is constructed so that the commanded velocities for both sides of the
+ * drivetrain stay below a certain limit.
+ */
+public class DifferentialDriveKinematicsConstraint implements TrajectoryConstraint {
+ private final double m_maxSpeedMetersPerSecond;
+ private final DifferentialDriveKinematics m_kinematics;
+
+ /**
+ * Constructs a differential drive dynamics constraint.
+ *
+ * @param kinematics A kinematics component describing the drive geometry.
+ * @param maxSpeedMetersPerSecond The max speed that a side of the robot can travel at.
+ */
+ public DifferentialDriveKinematicsConstraint(
+ final DifferentialDriveKinematics kinematics, double maxSpeedMetersPerSecond) {
+ m_maxSpeedMetersPerSecond = maxSpeedMetersPerSecond;
+ m_kinematics = kinematics;
+ }
+
+ /**
+ * Returns the max velocity given the current pose and curvature.
+ *
+ * @param poseMeters The pose at the current point in the trajectory.
+ * @param curvatureRadPerMeter The curvature at the current point in the trajectory.
+ * @param velocityMetersPerSecond The velocity at the current point in the trajectory before
+ * constraints are applied.
+ * @return The absolute maximum velocity.
+ */
+ @Override
+ public double getMaxVelocityMetersPerSecond(
+ Pose2d poseMeters, double curvatureRadPerMeter, double velocityMetersPerSecond) {
+ // Create an object to represent the current chassis speeds.
+ var chassisSpeeds =
+ new ChassisSpeeds(
+ velocityMetersPerSecond, 0, velocityMetersPerSecond * curvatureRadPerMeter);
+
+ // Get the wheel speeds and normalize them to within the max velocity.
+ var wheelSpeeds = m_kinematics.toWheelSpeeds(chassisSpeeds);
+ wheelSpeeds.normalize(m_maxSpeedMetersPerSecond);
+
+ // Return the new linear chassis speed.
+ return m_kinematics.toChassisSpeeds(wheelSpeeds).vxMetersPerSecond;
+ }
+
+ /**
+ * Returns the minimum and maximum allowable acceleration for the trajectory given pose,
+ * curvature, and speed.
+ *
+ * @param poseMeters The pose at the current point in the trajectory.
+ * @param curvatureRadPerMeter The curvature at the current point in the trajectory.
+ * @param velocityMetersPerSecond The speed at the current point in the trajectory.
+ * @return The min and max acceleration bounds.
+ */
+ @Override
+ public MinMax getMinMaxAccelerationMetersPerSecondSq(
+ Pose2d poseMeters, double curvatureRadPerMeter, double velocityMetersPerSecond) {
+ return new MinMax();
+ }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/math/trajectory/constraint/DifferentialDriveVoltageConstraint.java b/wpimath/src/main/java/edu/wpi/first/math/trajectory/constraint/DifferentialDriveVoltageConstraint.java
new file mode 100644
index 0000000..4c7e814
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/trajectory/constraint/DifferentialDriveVoltageConstraint.java
@@ -0,0 +1,129 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.trajectory.constraint;
+
+import static edu.wpi.first.util.ErrorMessages.requireNonNullParam;
+
+import edu.wpi.first.math.controller.SimpleMotorFeedforward;
+import edu.wpi.first.math.geometry.Pose2d;
+import edu.wpi.first.math.kinematics.ChassisSpeeds;
+import edu.wpi.first.math.kinematics.DifferentialDriveKinematics;
+
+/**
+ * A class that enforces constraints on differential drive voltage expenditure based on the motor
+ * dynamics and the drive kinematics. Ensures that the acceleration of any wheel of the robot while
+ * following the trajectory is never higher than what can be achieved with the given maximum
+ * voltage.
+ */
+public class DifferentialDriveVoltageConstraint implements TrajectoryConstraint {
+ private final SimpleMotorFeedforward m_feedforward;
+ private final DifferentialDriveKinematics m_kinematics;
+ private final double m_maxVoltage;
+
+ /**
+ * Creates a new DifferentialDriveVoltageConstraint.
+ *
+ * @param feedforward A feedforward component describing the behavior of the drive.
+ * @param kinematics A kinematics component describing the drive geometry.
+ * @param maxVoltage The maximum voltage available to the motors while following the path. Should
+ * be somewhat less than the nominal battery voltage (12V) to account for "voltage sag" due to
+ * current draw.
+ */
+ public DifferentialDriveVoltageConstraint(
+ SimpleMotorFeedforward feedforward,
+ DifferentialDriveKinematics kinematics,
+ double maxVoltage) {
+ m_feedforward =
+ requireNonNullParam(feedforward, "feedforward", "DifferentialDriveVoltageConstraint");
+ m_kinematics =
+ requireNonNullParam(kinematics, "kinematics", "DifferentialDriveVoltageConstraint");
+ m_maxVoltage = maxVoltage;
+ }
+
+ @Override
+ public double getMaxVelocityMetersPerSecond(
+ Pose2d poseMeters, double curvatureRadPerMeter, double velocityMetersPerSecond) {
+ return Double.POSITIVE_INFINITY;
+ }
+
+ @Override
+ public MinMax getMinMaxAccelerationMetersPerSecondSq(
+ Pose2d poseMeters, double curvatureRadPerMeter, double velocityMetersPerSecond) {
+ var wheelSpeeds =
+ m_kinematics.toWheelSpeeds(
+ new ChassisSpeeds(
+ velocityMetersPerSecond, 0, velocityMetersPerSecond * curvatureRadPerMeter));
+
+ double maxWheelSpeed =
+ Math.max(wheelSpeeds.leftMetersPerSecond, wheelSpeeds.rightMetersPerSecond);
+ double minWheelSpeed =
+ Math.min(wheelSpeeds.leftMetersPerSecond, wheelSpeeds.rightMetersPerSecond);
+
+ // Calculate maximum/minimum possible accelerations from motor dynamics
+ // and max/min wheel speeds
+ double maxWheelAcceleration =
+ m_feedforward.maxAchievableAcceleration(m_maxVoltage, maxWheelSpeed);
+ double minWheelAcceleration =
+ m_feedforward.minAchievableAcceleration(m_maxVoltage, minWheelSpeed);
+
+ // Robot chassis turning on radius = 1/|curvature|. Outer wheel has radius
+ // increased by half of the trackwidth T. Inner wheel has radius decreased
+ // by half of the trackwidth. Achassis / radius = Aouter / (radius + T/2), so
+ // Achassis = Aouter * radius / (radius + T/2) = Aouter / (1 + |curvature|T/2).
+ // Inner wheel is similar.
+
+ // sgn(speed) term added to correctly account for which wheel is on
+ // outside of turn:
+ // If moving forward, max acceleration constraint corresponds to wheel on outside of turn
+ // If moving backward, max acceleration constraint corresponds to wheel on inside of turn
+
+ // When velocity is zero, then wheel velocities are uniformly zero (robot cannot be
+ // turning on its center) - we have to treat this as a special case, as it breaks
+ // the signum function. Both max and min acceleration are *reduced in magnitude*
+ // in this case.
+
+ double maxChassisAcceleration;
+ double minChassisAcceleration;
+
+ if (velocityMetersPerSecond == 0) {
+ maxChassisAcceleration =
+ maxWheelAcceleration
+ / (1 + m_kinematics.trackWidthMeters * Math.abs(curvatureRadPerMeter) / 2);
+ minChassisAcceleration =
+ minWheelAcceleration
+ / (1 + m_kinematics.trackWidthMeters * Math.abs(curvatureRadPerMeter) / 2);
+ } else {
+ maxChassisAcceleration =
+ maxWheelAcceleration
+ / (1
+ + m_kinematics.trackWidthMeters
+ * Math.abs(curvatureRadPerMeter)
+ * Math.signum(velocityMetersPerSecond)
+ / 2);
+ minChassisAcceleration =
+ minWheelAcceleration
+ / (1
+ - m_kinematics.trackWidthMeters
+ * Math.abs(curvatureRadPerMeter)
+ * Math.signum(velocityMetersPerSecond)
+ / 2);
+ }
+
+ // When turning about a point inside of the wheelbase (i.e. radius less than half
+ // the trackwidth), the inner wheel's direction changes, but the magnitude remains
+ // the same. The formula above changes sign for the inner wheel when this happens.
+ // We can accurately account for this by simply negating the inner wheel.
+
+ if ((m_kinematics.trackWidthMeters / 2) > (1 / Math.abs(curvatureRadPerMeter))) {
+ if (velocityMetersPerSecond > 0) {
+ minChassisAcceleration = -minChassisAcceleration;
+ } else if (velocityMetersPerSecond < 0) {
+ maxChassisAcceleration = -maxChassisAcceleration;
+ }
+ }
+
+ return new MinMax(minChassisAcceleration, maxChassisAcceleration);
+ }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/math/trajectory/constraint/EllipticalRegionConstraint.java b/wpimath/src/main/java/edu/wpi/first/math/trajectory/constraint/EllipticalRegionConstraint.java
new file mode 100644
index 0000000..c3bd226
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/trajectory/constraint/EllipticalRegionConstraint.java
@@ -0,0 +1,77 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.trajectory.constraint;
+
+import edu.wpi.first.math.geometry.Pose2d;
+import edu.wpi.first.math.geometry.Rotation2d;
+import edu.wpi.first.math.geometry.Translation2d;
+
+/** Enforces a particular constraint only within an elliptical region. */
+public class EllipticalRegionConstraint implements TrajectoryConstraint {
+  private final Translation2d m_center;
+  private final Translation2d m_radii;
+  private final Rotation2d m_inverseRotation;
+  private final TrajectoryConstraint m_constraint;
+
+  /**
+   * Constructs a new EllipticalRegionConstraint.
+   *
+   * @param center The center of the ellipse in which to enforce the constraint.
+   * @param xWidth The width of the ellipse in which to enforce the constraint.
+   * @param yWidth The height of the ellipse in which to enforce the constraint.
+   * @param rotation The rotation of the ellipse's axes about its center.
+   * @param constraint The constraint to enforce when the robot is within the region.
+   */
+  @SuppressWarnings("ParameterName")
+  public EllipticalRegionConstraint(
+      Translation2d center,
+      double xWidth,
+      double yWidth,
+      Rotation2d rotation,
+      TrajectoryConstraint constraint) {
+    m_center = center;
+    m_radii = new Translation2d(xWidth / 2.0, yWidth / 2.0);
+    m_inverseRotation = rotation.unaryMinus();
+    m_constraint = constraint;
+  }
+
+  @Override
+  public double getMaxVelocityMetersPerSecond(
+      Pose2d poseMeters, double curvatureRadPerMeter, double velocityMetersPerSecond) {
+    if (isPoseInRegion(poseMeters)) {
+      return m_constraint.getMaxVelocityMetersPerSecond(
+          poseMeters, curvatureRadPerMeter, velocityMetersPerSecond);
+    } else {
+      return Double.POSITIVE_INFINITY;
+    }
+  }
+
+  @Override
+  public MinMax getMinMaxAccelerationMetersPerSecondSq(
+      Pose2d poseMeters, double curvatureRadPerMeter, double velocityMetersPerSecond) {
+    if (isPoseInRegion(poseMeters)) {
+      return m_constraint.getMinMaxAccelerationMetersPerSecondSq(
+          poseMeters, curvatureRadPerMeter, velocityMetersPerSecond);
+    } else {
+      return new MinMax();
+    }
+  }
+
+  /**
+   * Returns whether the specified robot pose is within the region the constraint is enforced in.
+   *
+   * @param robotPose The robot pose.
+   * @return Whether the robot pose is within the constraint region.
+   */
+  public boolean isPoseInRegion(Pose2d robotPose) {
+    // Express the robot position relative to the center, in the ellipse's unrotated frame.
+    Translation2d offset = robotPose.getTranslation().minus(m_center).rotateBy(m_inverseRotation);
+    // Inside iff x^2/Rx^2 + y^2/Ry^2 <= 1, scaled by Rx^2 * Ry^2 to avoid dividing by zero.
+    double rxSq = m_radii.getX() * m_radii.getX();
+    double rySq = m_radii.getY() * m_radii.getY();
+    return offset.getX() * offset.getX() * rySq + offset.getY() * offset.getY() * rxSq
+        <= rxSq * rySq;
+  }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/math/trajectory/constraint/MaxVelocityConstraint.java b/wpimath/src/main/java/edu/wpi/first/math/trajectory/constraint/MaxVelocityConstraint.java
new file mode 100644
index 0000000..d672295
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/trajectory/constraint/MaxVelocityConstraint.java
@@ -0,0 +1,37 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.trajectory.constraint;
+
+import edu.wpi.first.math.geometry.Pose2d;
+
+/**
+ * Represents a constraint that enforces a max velocity. This can be composed with the {@link
+ * EllipticalRegionConstraint} or {@link RectangularRegionConstraint} to enforce a max velocity in a
+ * region.
+ */
+public class MaxVelocityConstraint implements TrajectoryConstraint {
+ private final double m_maxVelocity;
+
+ /**
+ * Constructs a new MaxVelocityConstraint.
+ *
+ * @param maxVelocityMetersPerSecond The max velocity.
+ */
+ public MaxVelocityConstraint(double maxVelocityMetersPerSecond) {
+ m_maxVelocity = maxVelocityMetersPerSecond;
+ }
+
+ @Override
+ public double getMaxVelocityMetersPerSecond(
+ Pose2d poseMeters, double curvatureRadPerMeter, double velocityMetersPerSecond) {
+ return m_maxVelocity;
+ }
+
+ @Override
+ public TrajectoryConstraint.MinMax getMinMaxAccelerationMetersPerSecondSq(
+ Pose2d poseMeters, double curvatureRadPerMeter, double velocityMetersPerSecond) {
+ return new MinMax();
+ }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/math/trajectory/constraint/MecanumDriveKinematicsConstraint.java b/wpimath/src/main/java/edu/wpi/first/math/trajectory/constraint/MecanumDriveKinematicsConstraint.java
new file mode 100644
index 0000000..b26cdcf
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/trajectory/constraint/MecanumDriveKinematicsConstraint.java
@@ -0,0 +1,79 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.trajectory.constraint;
+
+import edu.wpi.first.math.geometry.Pose2d;
+import edu.wpi.first.math.kinematics.ChassisSpeeds;
+import edu.wpi.first.math.kinematics.MecanumDriveKinematics;
+
+/**
+ * A class that enforces constraints on the mecanum drive kinematics. This can be used to ensure
+ * that the trajectory is constructed so that the commanded velocities for all 4 wheels of the
+ * drivetrain stay below a certain limit.
+ */
+public class MecanumDriveKinematicsConstraint implements TrajectoryConstraint {
+ private final double m_maxSpeedMetersPerSecond;
+ private final MecanumDriveKinematics m_kinematics;
+
+ /**
+ * Constructs a mecanum drive kinematics constraint.
+ *
+ * @param kinematics Mecanum drive kinematics.
+ * @param maxSpeedMetersPerSecond The max speed that a side of the robot can travel at.
+ */
+ public MecanumDriveKinematicsConstraint(
+ final MecanumDriveKinematics kinematics, double maxSpeedMetersPerSecond) {
+ m_maxSpeedMetersPerSecond = maxSpeedMetersPerSecond;
+ m_kinematics = kinematics;
+ }
+
+ /**
+ * Returns the max velocity given the current pose and curvature.
+ *
+ * @param poseMeters The pose at the current point in the trajectory.
+ * @param curvatureRadPerMeter The curvature at the current point in the trajectory.
+ * @param velocityMetersPerSecond The velocity at the current point in the trajectory before
+ * constraints are applied.
+ * @return The absolute maximum velocity.
+ */
+ @Override
+ public double getMaxVelocityMetersPerSecond(
+ Pose2d poseMeters, double curvatureRadPerMeter, double velocityMetersPerSecond) {
+ // Represents the velocity of the chassis in the x direction
+ var xdVelocity = velocityMetersPerSecond * poseMeters.getRotation().getCos();
+
+ // Represents the velocity of the chassis in the y direction
+ var ydVelocity = velocityMetersPerSecond * poseMeters.getRotation().getSin();
+
+ // Create an object to represent the current chassis speeds.
+ var chassisSpeeds =
+ new ChassisSpeeds(xdVelocity, ydVelocity, velocityMetersPerSecond * curvatureRadPerMeter);
+
+ // Get the wheel speeds and normalize them to within the max velocity.
+ var wheelSpeeds = m_kinematics.toWheelSpeeds(chassisSpeeds);
+ wheelSpeeds.normalize(m_maxSpeedMetersPerSecond);
+
+ // Convert normalized wheel speeds back to chassis speeds
+ var normSpeeds = m_kinematics.toChassisSpeeds(wheelSpeeds);
+
+ // Return the new linear chassis speed.
+ return Math.hypot(normSpeeds.vxMetersPerSecond, normSpeeds.vyMetersPerSecond);
+ }
+
+ /**
+ * Returns the minimum and maximum allowable acceleration for the trajectory given pose,
+ * curvature, and speed.
+ *
+ * @param poseMeters The pose at the current point in the trajectory.
+ * @param curvatureRadPerMeter The curvature at the current point in the trajectory.
+ * @param velocityMetersPerSecond The speed at the current point in the trajectory.
+ * @return The min and max acceleration bounds.
+ */
+ @Override
+ public MinMax getMinMaxAccelerationMetersPerSecondSq(
+ Pose2d poseMeters, double curvatureRadPerMeter, double velocityMetersPerSecond) {
+ return new MinMax();
+ }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/math/trajectory/constraint/RectangularRegionConstraint.java b/wpimath/src/main/java/edu/wpi/first/math/trajectory/constraint/RectangularRegionConstraint.java
new file mode 100644
index 0000000..b29df5e
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/trajectory/constraint/RectangularRegionConstraint.java
@@ -0,0 +1,67 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.trajectory.constraint;
+
+import edu.wpi.first.math.geometry.Pose2d;
+import edu.wpi.first.math.geometry.Translation2d;
+
+/** Enforces a particular constraint only within a rectangular region. */
+public class RectangularRegionConstraint implements TrajectoryConstraint {
+  private final Translation2d m_bottomLeftPoint;
+  private final Translation2d m_topRightPoint;
+  private final TrajectoryConstraint m_constraint;
+
+  /**
+   * Constructs a new RectangularRegionConstraint.
+   *
+   * @param bottomLeftPoint The bottom left point of the rectangular region in which to enforce the
+   *     constraint.
+   * @param topRightPoint The top right point of the rectangular region in which to enforce the
+   *     constraint. The two points may be given as any pair of opposite corners.
+   * @param constraint The constraint to enforce when the robot is within the region.
+   */
+  public RectangularRegionConstraint(
+      Translation2d bottomLeftPoint, Translation2d topRightPoint, TrajectoryConstraint constraint) {
+    // Normalize the corners so that any pair of opposite corners describes the same region.
+    double minX = Math.min(bottomLeftPoint.getX(), topRightPoint.getX());
+    double minY = Math.min(bottomLeftPoint.getY(), topRightPoint.getY());
+    double maxX = Math.max(bottomLeftPoint.getX(), topRightPoint.getX());
+    double maxY = Math.max(bottomLeftPoint.getY(), topRightPoint.getY());
+    m_bottomLeftPoint = new Translation2d(minX, minY);
+    m_topRightPoint = new Translation2d(maxX, maxY);
+    m_constraint = constraint;
+  }
+
+  @Override
+  public double getMaxVelocityMetersPerSecond(
+      Pose2d poseMeters, double curvatureRadPerMeter, double velocityMetersPerSecond) {
+    return isPoseInRegion(poseMeters)
+        ? m_constraint.getMaxVelocityMetersPerSecond(
+            poseMeters, curvatureRadPerMeter, velocityMetersPerSecond)
+        : Double.POSITIVE_INFINITY;
+  }
+
+  @Override
+  public MinMax getMinMaxAccelerationMetersPerSecondSq(
+      Pose2d poseMeters, double curvatureRadPerMeter, double velocityMetersPerSecond) {
+    return isPoseInRegion(poseMeters)
+        ? m_constraint.getMinMaxAccelerationMetersPerSecondSq(
+            poseMeters, curvatureRadPerMeter, velocityMetersPerSecond)
+        : new MinMax();
+  }
+
+  /**
+   * Returns whether the specified robot pose is within the region the constraint is enforced in.
+   *
+   * @param robotPose The robot pose.
+   * @return Whether the robot pose is within the constraint region (boundary inclusive).
+   */
+  public boolean isPoseInRegion(Pose2d robotPose) {
+    return robotPose.getX() >= m_bottomLeftPoint.getX()
+        && robotPose.getX() <= m_topRightPoint.getX()
+        && robotPose.getY() >= m_bottomLeftPoint.getY()
+        && robotPose.getY() <= m_topRightPoint.getY();
+  }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/math/trajectory/constraint/SwerveDriveKinematicsConstraint.java b/wpimath/src/main/java/edu/wpi/first/math/trajectory/constraint/SwerveDriveKinematicsConstraint.java
new file mode 100644
index 0000000..5d95290
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/trajectory/constraint/SwerveDriveKinematicsConstraint.java
@@ -0,0 +1,79 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.trajectory.constraint;
+
+import edu.wpi.first.math.geometry.Pose2d;
+import edu.wpi.first.math.kinematics.ChassisSpeeds;
+import edu.wpi.first.math.kinematics.SwerveDriveKinematics;
+
+/**
+ * A class that enforces constraints on the swerve drive kinematics. This can be used to ensure that
+ * the trajectory is constructed so that the commanded velocities for all 4 wheels of the drivetrain
+ * stay below a certain limit.
+ */
+public class SwerveDriveKinematicsConstraint implements TrajectoryConstraint {
+ private final double m_maxSpeedMetersPerSecond;
+ private final SwerveDriveKinematics m_kinematics;
+
+ /**
+ * Constructs a swerve drive kinematics constraint.
+ *
+ * @param kinematics Swerve drive kinematics.
+ * @param maxSpeedMetersPerSecond The max speed that a side of the robot can travel at.
+ */
+ public SwerveDriveKinematicsConstraint(
+ final SwerveDriveKinematics kinematics, double maxSpeedMetersPerSecond) {
+ m_maxSpeedMetersPerSecond = maxSpeedMetersPerSecond;
+ m_kinematics = kinematics;
+ }
+
+ /**
+ * Returns the max velocity given the current pose and curvature.
+ *
+ * @param poseMeters The pose at the current point in the trajectory.
+ * @param curvatureRadPerMeter The curvature at the current point in the trajectory.
+ * @param velocityMetersPerSecond The velocity at the current point in the trajectory before
+ * constraints are applied.
+ * @return The absolute maximum velocity.
+ */
+ @Override
+ public double getMaxVelocityMetersPerSecond(
+ Pose2d poseMeters, double curvatureRadPerMeter, double velocityMetersPerSecond) {
+ // Represents the velocity of the chassis in the x direction
+ var xdVelocity = velocityMetersPerSecond * poseMeters.getRotation().getCos();
+
+ // Represents the velocity of the chassis in the y direction
+ var ydVelocity = velocityMetersPerSecond * poseMeters.getRotation().getSin();
+
+ // Create an object to represent the current chassis speeds.
+ var chassisSpeeds =
+ new ChassisSpeeds(xdVelocity, ydVelocity, velocityMetersPerSecond * curvatureRadPerMeter);
+
+ // Get the wheel speeds and normalize them to within the max velocity.
+ var wheelSpeeds = m_kinematics.toSwerveModuleStates(chassisSpeeds);
+ SwerveDriveKinematics.normalizeWheelSpeeds(wheelSpeeds, m_maxSpeedMetersPerSecond);
+
+ // Convert normalized wheel speeds back to chassis speeds
+ var normSpeeds = m_kinematics.toChassisSpeeds(wheelSpeeds);
+
+ // Return the new linear chassis speed.
+ return Math.hypot(normSpeeds.vxMetersPerSecond, normSpeeds.vyMetersPerSecond);
+ }
+
+ /**
+ * Returns the minimum and maximum allowable acceleration for the trajectory given pose,
+ * curvature, and speed.
+ *
+ * @param poseMeters The pose at the current point in the trajectory.
+ * @param curvatureRadPerMeter The curvature at the current point in the trajectory.
+ * @param velocityMetersPerSecond The speed at the current point in the trajectory.
+ * @return The min and max acceleration bounds.
+ */
+ @Override
+ public MinMax getMinMaxAccelerationMetersPerSecondSq(
+ Pose2d poseMeters, double curvatureRadPerMeter, double velocityMetersPerSecond) {
+ return new MinMax();
+ }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/math/trajectory/constraint/TrajectoryConstraint.java b/wpimath/src/main/java/edu/wpi/first/math/trajectory/constraint/TrajectoryConstraint.java
new file mode 100644
index 0000000..bbf30f7
--- /dev/null
+++ b/wpimath/src/main/java/edu/wpi/first/math/trajectory/constraint/TrajectoryConstraint.java
@@ -0,0 +1,59 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.trajectory.constraint;
+
+import edu.wpi.first.math.geometry.Pose2d;
+
+/**
+ * An interface for defining user-defined velocity and acceleration constraints while generating
+ * trajectories.
+ */
+public interface TrajectoryConstraint {
+ /**
+ * Returns the max velocity given the current pose and curvature.
+ *
+ * @param poseMeters The pose at the current point in the trajectory.
+ * @param curvatureRadPerMeter The curvature at the current point in the trajectory.
+ * @param velocityMetersPerSecond The velocity at the current point in the trajectory before
+ * constraints are applied.
+ * @return The absolute maximum velocity.
+ */
+ double getMaxVelocityMetersPerSecond(
+ Pose2d poseMeters, double curvatureRadPerMeter, double velocityMetersPerSecond);
+
+ /**
+ * Returns the minimum and maximum allowable acceleration for the trajectory given pose,
+ * curvature, and speed.
+ *
+ * @param poseMeters The pose at the current point in the trajectory.
+ * @param curvatureRadPerMeter The curvature at the current point in the trajectory.
+ * @param velocityMetersPerSecond The speed at the current point in the trajectory.
+ * @return The min and max acceleration bounds.
+ */
+ MinMax getMinMaxAccelerationMetersPerSecondSq(
+ Pose2d poseMeters, double curvatureRadPerMeter, double velocityMetersPerSecond);
+
+ /** Represents a minimum and maximum acceleration. */
+ @SuppressWarnings("MemberName")
+ class MinMax {
+ public double minAccelerationMetersPerSecondSq = -Double.MAX_VALUE;
+ public double maxAccelerationMetersPerSecondSq = +Double.MAX_VALUE;
+
+ /**
+ * Constructs a MinMax.
+ *
+ * @param minAccelerationMetersPerSecondSq The minimum acceleration.
+ * @param maxAccelerationMetersPerSecondSq The maximum acceleration.
+ */
+ public MinMax(
+ double minAccelerationMetersPerSecondSq, double maxAccelerationMetersPerSecondSq) {
+ this.minAccelerationMetersPerSecondSq = minAccelerationMetersPerSecondSq;
+ this.maxAccelerationMetersPerSecondSq = maxAccelerationMetersPerSecondSq;
+ }
+
+ /** Constructs a MinMax with default values. */
+ public MinMax() {}
+ }
+}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/util/Units.java b/wpimath/src/main/java/edu/wpi/first/math/util/Units.java
similarity index 61%
rename from wpimath/src/main/java/edu/wpi/first/wpilibj/util/Units.java
rename to wpimath/src/main/java/edu/wpi/first/math/util/Units.java
index 3f48306..13adcd8 100644
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/util/Units.java
+++ b/wpimath/src/main/java/edu/wpi/first/math/util/Units.java
@@ -1,23 +1,18 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
-package edu.wpi.first.wpilibj.util;
+package edu.wpi.first.math.util;
-/**
- * Utility class that converts between commonly used units in FRC.
- */
+/** Utility class that converts between commonly used units in FRC. */
public final class Units {
private static final double kInchesPerFoot = 12.0;
private static final double kMetersPerInch = 0.0254;
private static final double kSecondsPerMinute = 60;
+  private static final double kMillisecondsPerSecond = 1000;
+  private static final double kKilogramsPerLb = 0.45359237; // Exact, by definition of the lb.
- /**
- * Utility class, so constructor is private.
- */
+ /** Utility class, so constructor is private. */
private Units() {
throw new UnsupportedOperationException("This is a utility class!");
}
@@ -101,4 +96,44 @@
public static double radiansPerSecondToRotationsPerMinute(double radiansPerSecond) {
return radiansPerSecond * (kSecondsPerMinute / 2) / Math.PI;
}
+
+  /**
+   * Converts a duration from milliseconds to seconds.
+   *
+   * @param milliseconds The milliseconds to convert to seconds.
+   * @return Seconds converted from milliseconds.
+   */
+  public static double millisecondsToSeconds(double milliseconds) {
+    return milliseconds / Units.kMillisecondsPerSecond;
+  }
+
+  /**
+   * Converts a duration from seconds to milliseconds.
+   *
+   * @param seconds The seconds to convert to milliseconds.
+   * @return Milliseconds converted from seconds.
+   */
+  public static double secondsToMilliseconds(double seconds) {
+    return Units.kMillisecondsPerSecond * seconds;
+  }
+
+  /**
+   * Converts a mass from kilograms to lbs (pound-mass).
+   *
+   * @param kilograms The kilograms to convert to lbs (pound-mass).
+   * @return Lbs (pound-mass) converted from kilograms.
+   */
+  public static double kilogramsToLbs(double kilograms) {
+    return kilograms / Units.kKilogramsPerLb;
+  }
+
+  /**
+   * Converts a mass from lbs (pound-mass) to kilograms.
+   *
+   * @param lbs The lbs (pound-mass) to convert to kilograms.
+   * @return Kilograms converted from lbs (pound-mass).
+   */
+  public static double lbsToKilograms(double lbs) {
+    return Units.kKilogramsPerLb * lbs;
+  }
}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/LinearFilter.java b/wpimath/src/main/java/edu/wpi/first/wpilibj/LinearFilter.java
deleted file mode 100644
index 10897e8..0000000
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/LinearFilter.java
+++ /dev/null
@@ -1,171 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2015-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj;
-
-import java.util.Arrays;
-
-import edu.wpi.first.math.MathSharedStore;
-import edu.wpi.first.math.MathUsageId;
-import edu.wpi.first.wpiutil.CircularBuffer;
-
-/**
- * This class implements a linear, digital filter. All types of FIR and IIR filters are supported.
- * Static factory methods are provided to create commonly used types of filters.
- *
- * <p>Filters are of the form: y[n] = (b0*x[n] + b1*x[n-1] + ... + bP*x[n-P]) - (a0*y[n-1] +
- * a2*y[n-2] + ... + aQ*y[n-Q])
- *
- * <p>Where: y[n] is the output at time "n" x[n] is the input at time "n" y[n-1] is the output from
- * the LAST time step ("n-1") x[n-1] is the input from the LAST time step ("n-1") b0...bP are the
- * "feedforward" (FIR) gains a0...aQ are the "feedback" (IIR) gains IMPORTANT! Note the "-" sign in
- * front of the feedback term! This is a common convention in signal processing.
- *
- * <p>What can linear filters do? Basically, they can filter, or diminish, the effects of
- * undesirable input frequencies. High frequencies, or rapid changes, can be indicative of sensor
- * noise or be otherwise undesirable. A "low pass" filter smooths out the signal, reducing the
- * impact of these high frequency components. Likewise, a "high pass" filter gets rid of
- * slow-moving signal components, letting you detect large changes more easily.
- *
- * <p>Example FRC applications of filters: - Getting rid of noise from an analog sensor input (note:
- * the roboRIO's FPGA can do this faster in hardware) - Smoothing out joystick input to prevent the
- * wheels from slipping or the robot from tipping - Smoothing motor commands so that unnecessary
- * strain isn't put on electrical or mechanical components - If you use clever gains, you can make a
- * PID controller out of this class!
- *
- * <p>For more on filters, we highly recommend the following articles:<br>
- * https://en.wikipedia.org/wiki/Linear_filter<br>
- * https://en.wikipedia.org/wiki/Iir_filter<br>
- * https://en.wikipedia.org/wiki/Fir_filter<br>
- *
- * <p>Note 1: calculate() should be called by the user on a known, regular period. You can use a
- * Notifier for this or do it "inline" with code in a periodic function.
- *
- * <p>Note 2: For ALL filters, gains are necessarily a function of frequency. If you make a filter
- * that works well for you at, say, 100Hz, you will most definitely need to adjust the gains if you
- * then want to run it at 200Hz! Combining this with Note 1 - the impetus is on YOU as a developer
- * to make sure calculate() gets called at the desired, constant frequency!
- */
-public class LinearFilter {
- private final CircularBuffer m_inputs;
- private final CircularBuffer m_outputs;
- private final double[] m_inputGains;
- private final double[] m_outputGains;
-
- private static int instances;
-
- /**
- * Create a linear FIR or IIR filter.
- *
- * @param ffGains The "feed forward" or FIR gains.
- * @param fbGains The "feed back" or IIR gains.
- */
- public LinearFilter(double[] ffGains, double[] fbGains) {
- m_inputs = new CircularBuffer(ffGains.length);
- m_outputs = new CircularBuffer(fbGains.length);
- m_inputGains = Arrays.copyOf(ffGains, ffGains.length);
- m_outputGains = Arrays.copyOf(fbGains, fbGains.length);
-
- instances++;
- MathSharedStore.reportUsage(MathUsageId.kFilter_Linear, instances);
- }
-
- /**
- * Creates a one-pole IIR low-pass filter of the form: y[n] = (1-gain)*x[n] + gain*y[n-1] where
- * gain = e^(-dt / T), T is the time constant in seconds.
- *
- * <p>This filter is stable for time constants greater than zero.
- *
- * @param timeConstant The discrete-time time constant in seconds.
- * @param period The period in seconds between samples taken by the user.
- */
- public static LinearFilter singlePoleIIR(double timeConstant,
- double period) {
- double gain = Math.exp(-period / timeConstant);
- double[] ffGains = {1.0 - gain};
- double[] fbGains = {-gain};
-
- return new LinearFilter(ffGains, fbGains);
- }
-
- /**
- * Creates a first-order high-pass filter of the form: y[n] = gain*x[n] + (-gain)*x[n-1] +
- * gain*y[n-1] where gain = e^(-dt / T), T is the time constant in seconds.
- *
- * <p>This filter is stable for time constants greater than zero.
- *
- * @param timeConstant The discrete-time time constant in seconds.
- * @param period The period in seconds between samples taken by the user.
- */
- public static LinearFilter highPass(double timeConstant,
- double period) {
- double gain = Math.exp(-period / timeConstant);
- double[] ffGains = {gain, -gain};
- double[] fbGains = {-gain};
-
- return new LinearFilter(ffGains, fbGains);
- }
-
- /**
- * Creates a K-tap FIR moving average filter of the form: y[n] = 1/k * (x[k] + x[k-1] + ... +
- * x[0]).
- *
- * <p>This filter is always stable.
- *
- * @param taps The number of samples to average over. Higher = smoother but slower.
- * @throws IllegalArgumentException if number of taps is less than 1.
- */
- public static LinearFilter movingAverage(int taps) {
- if (taps <= 0) {
- throw new IllegalArgumentException("Number of taps was not at least 1");
- }
-
- double[] ffGains = new double[taps];
- for (int i = 0; i < ffGains.length; i++) {
- ffGains[i] = 1.0 / taps;
- }
-
- double[] fbGains = new double[0];
-
- return new LinearFilter(ffGains, fbGains);
- }
-
- /**
- * Reset the filter state.
- */
- public void reset() {
- m_inputs.clear();
- m_outputs.clear();
- }
-
- /**
- * Calculates the next value of the filter.
- *
- * @param input Current input value.
- *
- * @return The filtered value at this step
- */
- public double calculate(double input) {
- double retVal = 0.0;
-
- // Rotate the inputs
- m_inputs.addFirst(input);
-
- // Calculate the new value
- for (int i = 0; i < m_inputGains.length; i++) {
- retVal += m_inputs.get(i) * m_inputGains[i];
- }
- for (int i = 0; i < m_outputGains.length; i++) {
- retVal -= m_outputs.get(i) * m_outputGains[i];
- }
-
- // Rotate the outputs
- m_outputs.addFirst(retVal);
-
- return retVal;
- }
-}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/controller/ArmFeedforward.java b/wpimath/src/main/java/edu/wpi/first/wpilibj/controller/ArmFeedforward.java
deleted file mode 100644
index 59e927a..0000000
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/controller/ArmFeedforward.java
+++ /dev/null
@@ -1,144 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.controller;
-
-/**
- * A helper class that computes feedforward outputs for a simple arm (modeled as a motor
- * acting against the force of gravity on a beam suspended at an angle).
- */
-@SuppressWarnings("MemberName")
-public class ArmFeedforward {
- public final double ks;
- public final double kcos;
- public final double kv;
- public final double ka;
-
- /**
- * Creates a new ArmFeedforward with the specified gains. Units of the gain values
- * will dictate units of the computed feedforward.
- *
- * @param ks The static gain.
- * @param kcos The gravity gain.
- * @param kv The velocity gain.
- * @param ka The acceleration gain.
- */
- public ArmFeedforward(double ks, double kcos, double kv, double ka) {
- this.ks = ks;
- this.kcos = kcos;
- this.kv = kv;
- this.ka = ka;
- }
-
- /**
- * Creates a new ArmFeedforward with the specified gains. Acceleration gain is
- * defaulted to zero. Units of the gain values will dictate units of the computed feedforward.
- *
- * @param ks The static gain.
- * @param kcos The gravity gain.
- * @param kv The velocity gain.
- */
- public ArmFeedforward(double ks, double kcos, double kv) {
- this(ks, kcos, kv, 0);
- }
-
- /**
- * Calculates the feedforward from the gains and setpoints.
- *
- * @param positionRadians The position (angle) setpoint.
- * @param velocityRadPerSec The velocity setpoint.
- * @param accelRadPerSecSquared The acceleration setpoint.
- * @return The computed feedforward.
- */
- public double calculate(double positionRadians, double velocityRadPerSec,
- double accelRadPerSecSquared) {
- return ks * Math.signum(velocityRadPerSec) + kcos * Math.cos(positionRadians)
- + kv * velocityRadPerSec
- + ka * accelRadPerSecSquared;
- }
-
- /**
- * Calculates the feedforward from the gains and velocity setpoint (acceleration is assumed to
- * be zero).
- *
- * @param positionRadians The position (angle) setpoint.
- * @param velocity The velocity setpoint.
- * @return The computed feedforward.
- */
- public double calculate(double positionRadians, double velocity) {
- return calculate(positionRadians, velocity, 0);
- }
-
- // Rearranging the main equation from the calculate() method yields the
- // formulas for the methods below:
-
- /**
- * Calculates the maximum achievable velocity given a maximum voltage supply,
- * a position, and an acceleration. Useful for ensuring that velocity and
- * acceleration constraints for a trapezoidal profile are simultaneously
- * achievable - enter the acceleration constraint, and this will give you
- * a simultaneously-achievable velocity constraint.
- *
- * @param maxVoltage The maximum voltage that can be supplied to the arm.
- * @param angle The angle of the arm.
- * @param acceleration The acceleration of the arm.
- * @return The maximum possible velocity at the given acceleration and angle.
- */
- public double maxAchievableVelocity(double maxVoltage, double angle, double acceleration) {
- // Assume max velocity is positive
- return (maxVoltage - ks - Math.cos(angle) * kcos - acceleration * ka) / kv;
- }
-
- /**
- * Calculates the minimum achievable velocity given a maximum voltage supply,
- * a position, and an acceleration. Useful for ensuring that velocity and
- * acceleration constraints for a trapezoidal profile are simultaneously
- * achievable - enter the acceleration constraint, and this will give you
- * a simultaneously-achievable velocity constraint.
- *
- * @param maxVoltage The maximum voltage that can be supplied to the arm.
- * @param angle The angle of the arm.
- * @param acceleration The acceleration of the arm.
- * @return The minimum possible velocity at the given acceleration and angle.
- */
- public double minAchievableVelocity(double maxVoltage, double angle, double acceleration) {
- // Assume min velocity is negative, ks flips sign
- return (-maxVoltage + ks - Math.cos(angle) * kcos - acceleration * ka) / kv;
- }
-
- /**
- * Calculates the maximum achievable acceleration given a maximum voltage
- * supply, a position, and a velocity. Useful for ensuring that velocity and
- * acceleration constraints for a trapezoidal profile are simultaneously
- * achievable - enter the velocity constraint, and this will give you
- * a simultaneously-achievable acceleration constraint.
- *
- * @param maxVoltage The maximum voltage that can be supplied to the arm.
- * @param angle The angle of the arm.
- * @param velocity The velocity of the arm.
- * @return The maximum possible acceleration at the given velocity.
- */
- public double maxAchievableAcceleration(double maxVoltage, double angle, double velocity) {
- return (maxVoltage - ks * Math.signum(velocity) - Math.cos(angle) * kcos - velocity * kv) / ka;
- }
-
- /**
- * Calculates the minimum achievable acceleration given a maximum voltage
- * supply, a position, and a velocity. Useful for ensuring that velocity and
- * acceleration constraints for a trapezoidal profile are simultaneously
- * achievable - enter the velocity constraint, and this will give you
- * a simultaneously-achievable acceleration constraint.
- *
- * @param maxVoltage The maximum voltage that can be supplied to the arm.
- * @param angle The angle of the arm.
- * @param velocity The velocity of the arm.
- * @return The minimum possible acceleration at the given velocity.
- */
- public double minAchievableAcceleration(double maxVoltage, double angle, double velocity) {
- return maxAchievableAcceleration(-maxVoltage, angle, velocity);
- }
-}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/controller/ControlAffinePlantInversionFeedforward.java b/wpimath/src/main/java/edu/wpi/first/wpilibj/controller/ControlAffinePlantInversionFeedforward.java
deleted file mode 100644
index 79a88cd..0000000
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/controller/ControlAffinePlantInversionFeedforward.java
+++ /dev/null
@@ -1,215 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.controller;
-
-import java.util.function.BiFunction;
-import java.util.function.Function;
-
-import edu.wpi.first.wpilibj.system.NumericalJacobian;
-import edu.wpi.first.wpiutil.math.Matrix;
-import edu.wpi.first.wpiutil.math.Nat;
-import edu.wpi.first.wpiutil.math.Num;
-import edu.wpi.first.wpiutil.math.numbers.N1;
-
-/**
- * Constructs a control-affine plant inversion model-based feedforward from
- * given model dynamics.
- *
- * <p>If given the vector valued function as f(x, u) where x is the state
- * vector and u is the input vector, the B matrix(continuous input matrix)
- * is calculated through a {@link edu.wpi.first.wpilibj.system.NumericalJacobian}.
- * In this case f has to be control-affine (of the form f(x) + Bu).
- *
- * <p>The feedforward is calculated as
- * <strong> u_ff = B<sup>+</sup> (rDot - f(x))</strong>, where
- * <strong> B<sup>+</sup> </strong> is the pseudoinverse of B.
- *
- * <p>This feedforward does not account for a dynamic B matrix, B is either
- * determined or supplied when the feedforward is created and remains constant.
- *
- * <p>For more on the underlying math, read
- * https://file.tavsys.net/control/controls-engineering-in-frc.pdf.
- */
-@SuppressWarnings({"ParameterName", "LocalVariableName", "MemberName", "ClassTypeParameterName"})
-public class ControlAffinePlantInversionFeedforward<States extends Num, Inputs extends Num> {
- /**
- * The current reference state.
- */
- @SuppressWarnings("MemberName")
- private Matrix<States, N1> m_r;
-
- /**
- * The computed feedforward.
- */
- private Matrix<Inputs, N1> m_uff;
-
- @SuppressWarnings("MemberName")
- private final Matrix<States, Inputs> m_B;
-
- private final Nat<Inputs> m_inputs;
-
- private final double m_dt;
-
- /**
- * The model dynamics.
- */
- private final BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> m_f;
-
- /**
- * Constructs a feedforward with given model dynamics as a function
- * of state and input.
- *
- * @param states A {@link Nat} representing the number of states.
- * @param inputs A {@link Nat} representing the number of inputs.
- * @param f A vector-valued function of x, the state, and
- * u, the input, that returns the derivative of
- * the state vector. HAS to be control-affine
- * (of the form f(x) + Bu).
- * @param dtSeconds The timestep between calls of calculate().
- */
- public ControlAffinePlantInversionFeedforward(
- Nat<States> states,
- Nat<Inputs> inputs,
- BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
- double dtSeconds) {
- this.m_dt = dtSeconds;
- this.m_f = f;
- this.m_inputs = inputs;
-
- this.m_B = NumericalJacobian.numericalJacobianU(states, inputs,
- m_f, new Matrix<>(states, Nat.N1()), new Matrix<>(inputs, Nat.N1()));
-
- m_r = new Matrix<>(states, Nat.N1());
- m_uff = new Matrix<>(inputs, Nat.N1());
-
- reset(m_r);
- }
-
- /**
- * Constructs a feedforward with given model dynamics as a function of state,
- * and the plant's B(continuous input matrix) matrix.
- *
- * @param states A {@link Nat} representing the number of states.
- * @param inputs A {@link Nat} representing the number of inputs.
- * @param f A vector-valued function of x, the state,
- * that returns the derivative of the state vector.
- * @param B Continuous input matrix of the plant being controlled.
- * @param dtSeconds The timestep between calls of calculate().
- */
- public ControlAffinePlantInversionFeedforward(
- Nat<States> states,
- Nat<Inputs> inputs,
- Function<Matrix<States, N1>, Matrix<States, N1>> f,
- Matrix<States, Inputs> B,
- double dtSeconds) {
- this.m_dt = dtSeconds;
- this.m_inputs = inputs;
-
- this.m_f = (x, u) -> f.apply(x);
- this.m_B = B;
-
- m_r = new Matrix<>(states, Nat.N1());
- m_uff = new Matrix<>(inputs, Nat.N1());
-
- reset(m_r);
- }
-
-
- /**
- * Returns the previously calculated feedforward as an input vector.
- *
- * @return The calculated feedforward.
- */
- public Matrix<Inputs, N1> getUff() {
- return m_uff;
- }
-
- /**
- * Returns an element of the previously calculated feedforward.
- *
- * @param row Row of uff.
- *
- * @return The row of the calculated feedforward.
- */
- public double getUff(int row) {
- return m_uff.get(row, 0);
- }
-
- /**
- * Returns the current reference vector r.
- *
- * @return The current reference vector.
- */
- public Matrix<States, N1> getR() {
- return m_r;
- }
-
- /**
- * Returns an element of the current reference vector r.
- *
- * @param row Row of r.
- *
- * @return The row of the current reference vector.
- */
- public double getR(int row) {
- return m_r.get(row, 0);
- }
-
- /**
- * Resets the feedforward with a specified initial state vector.
- *
- * @param initialState The initial state vector.
- */
- public void reset(Matrix<States, N1> initialState) {
- m_r = initialState;
- m_uff.fill(0.0);
- }
-
- /**
- * Resets the feedforward with a zero initial state vector.
- */
- public void reset() {
- m_r.fill(0.0);
- m_uff.fill(0.0);
- }
-
- /**
- * Calculate the feedforward with only the desired
- * future reference. This uses the internally stored "current"
- * reference.
- *
- * <p>If this method is used the initial state of the system is the one
- * set using {@link LinearPlantInversionFeedforward#reset(Matrix)}.
- * If the initial state is not set it defaults to a zero vector.
- *
- * @param nextR The reference state of the future timestep (k + dt).
- *
- * @return The calculated feedforward.
- */
- public Matrix<Inputs, N1> calculate(Matrix<States, N1> nextR) {
- return calculate(m_r, nextR);
- }
-
- /**
- * Calculate the feedforward with current and future reference vectors.
- *
- * @param r The reference state of the current timestep (k).
- * @param nextR The reference state of the future timestep (k + dt).
- *
- * @return The calculated feedforward.
- */
- @SuppressWarnings({"ParameterName", "LocalVariableName"})
- public Matrix<Inputs, N1> calculate(Matrix<States, N1> r, Matrix<States, N1> nextR) {
- var rDot = (nextR.minus(r)).div(m_dt);
-
- m_uff = m_B.solve(rDot.minus(m_f.apply(r, new Matrix<>(m_inputs, Nat.N1()))));
-
- m_r = nextR;
- return m_uff;
- }
-}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/controller/LinearQuadraticRegulator.java b/wpimath/src/main/java/edu/wpi/first/wpilibj/controller/LinearQuadraticRegulator.java
deleted file mode 100644
index 195ba4f..0000000
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/controller/LinearQuadraticRegulator.java
+++ /dev/null
@@ -1,214 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.controller;
-
-import org.ejml.simple.SimpleMatrix;
-
-import edu.wpi.first.math.Drake;
-import edu.wpi.first.wpilibj.math.Discretization;
-import edu.wpi.first.wpilibj.math.StateSpaceUtil;
-import edu.wpi.first.wpilibj.system.LinearSystem;
-import edu.wpi.first.wpiutil.math.Matrix;
-import edu.wpi.first.wpiutil.math.Nat;
-import edu.wpi.first.wpiutil.math.Num;
-import edu.wpi.first.wpiutil.math.Vector;
-import edu.wpi.first.wpiutil.math.numbers.N1;
-
-
-/**
- * Contains the controller coefficients and logic for a linear-quadratic
- * regulator (LQR).
- * LQRs use the control law u = K(r - x).
- *
- * <p>For more on the underlying math, read
- * https://file.tavsys.net/control/controls-engineering-in-frc.pdf.
- */
-@SuppressWarnings("ClassTypeParameterName")
-public class LinearQuadraticRegulator<States extends Num, Inputs extends Num,
- Outputs extends Num> {
- /**
- * The current reference state.
- */
- @SuppressWarnings("MemberName")
- private Matrix<States, N1> m_r;
-
- /**
- * The computed and capped controller output.
- */
- @SuppressWarnings("MemberName")
- private Matrix<Inputs, N1> m_u;
-
- // Controller gain.
- @SuppressWarnings("MemberName")
- private Matrix<Inputs, States> m_K;
-
- /**
- * Constructs a controller with the given coefficients and plant. Rho is defaulted to 1.
- *
- * @param plant The plant being controlled.
- * @param qelms The maximum desired error tolerance for each state.
- * @param relms The maximum desired control effort for each input.
- * @param dtSeconds Discretization timestep.
- */
- public LinearQuadraticRegulator(
- LinearSystem<States, Inputs, Outputs> plant,
- Vector<States> qelms,
- Vector<Inputs> relms,
- double dtSeconds
- ) {
- this(plant.getA(), plant.getB(), StateSpaceUtil.makeCostMatrix(qelms),
- StateSpaceUtil.makeCostMatrix(relms), dtSeconds);
- }
-
- /**
- * Constructs a controller with the given coefficients and plant.
- *
- * @param A Continuous system matrix of the plant being controlled.
- * @param B Continuous input matrix of the plant being controlled.
- * @param qelms The maximum desired error tolerance for each state.
- * @param relms The maximum desired control effort for each input.
- * @param dtSeconds Discretization timestep.
- */
- @SuppressWarnings({"ParameterName", "LocalVariableName"})
- public LinearQuadraticRegulator(Matrix<States, States> A, Matrix<States, Inputs> B,
- Vector<States> qelms, Vector<Inputs> relms,
- double dtSeconds
- ) {
- this(A, B, StateSpaceUtil.makeCostMatrix(qelms), StateSpaceUtil.makeCostMatrix(relms),
- dtSeconds);
- }
-
- /**
- * Constructs a controller with the given coefficients and plant.
- * @param A Continuous system matrix of the plant being controlled.
- * @param B Continuous input matrix of the plant being controlled.
- * @param Q The state cost matrix.
- * @param R The input cost matrix.
- * @param dtSeconds Discretization timestep.
- */
- @SuppressWarnings({"ParameterName", "LocalVariableName"})
- public LinearQuadraticRegulator(Matrix<States, States> A, Matrix<States, Inputs> B,
- Matrix<States, States> Q, Matrix<Inputs, Inputs> R,
- double dtSeconds
- ) {
- var discABPair = Discretization.discretizeAB(A, B, dtSeconds);
- var discA = discABPair.getFirst();
- var discB = discABPair.getSecond();
-
- var S = Drake.discreteAlgebraicRiccatiEquation(discA, discB, Q, R);
-
- var temp = discB.transpose().times(S).times(discB).plus(R);
-
- m_K = temp.solve(discB.transpose().times(S).times(discA));
-
- m_r = new Matrix<>(new SimpleMatrix(B.getNumRows(), 1));
- m_u = new Matrix<>(new SimpleMatrix(B.getNumCols(), 1));
-
- reset();
- }
-
- /**
- * Constructs a controller with the given coefficients and plant.
- *
- * @param states The number of states.
- * @param inputs The number of inputs.
- * @param k The gain matrix.
- */
- @SuppressWarnings("ParameterName")
- public LinearQuadraticRegulator(
- Nat<States> states, Nat<Inputs> inputs,
- Matrix<Inputs, States> k
- ) {
- m_K = k;
-
- m_r = new Matrix<>(states, Nat.N1());
- m_u = new Matrix<>(inputs, Nat.N1());
-
- reset();
- }
-
- /**
- * Returns the control input vector u.
- *
- * @return The control input.
- */
- public Matrix<Inputs, N1> getU() {
- return m_u;
- }
-
- /**
- * Returns an element of the control input vector u.
- *
- * @param row Row of u.
- *
- * @return The row of the control input vector.
- */
- public double getU(int row) {
- return m_u.get(row, 0);
- }
-
- /**
- * Returns the reference vector r.
- *
- * @return The reference vector.
- */
- public Matrix<States, N1> getR() {
- return m_r;
- }
-
- /**
- * Returns an element of the reference vector r.
- *
- * @param row Row of r.
- *
- * @return The row of the reference vector.
- */
- public double getR(int row) {
- return m_r.get(row, 0);
- }
-
- /**
- * Returns the controller matrix K.
- *
- * @return the controller matrix K.
- */
- public Matrix<Inputs, States> getK() {
- return m_K;
- }
-
- /**
- * Resets the controller.
- */
- public void reset() {
- m_r.fill(0.0);
- m_u.fill(0.0);
- }
-
- /**
- * Returns the next output of the controller.
- *
- * @param x The current state x.
- */
- @SuppressWarnings("ParameterName")
- public Matrix<Inputs, N1> calculate(Matrix<States, N1> x) {
- m_u = m_K.times(m_r.minus(x));
- return m_u;
- }
-
- /**
- * Returns the next output of the controller.
- *
- * @param x The current state x.
- * @param nextR the next reference vector r.
- */
- @SuppressWarnings("ParameterName")
- public Matrix<Inputs, N1> calculate(Matrix<States, N1> x, Matrix<States, N1> nextR) {
- m_r = nextR;
- return calculate(x);
- }
-}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/controller/SimpleMotorFeedforward.java b/wpimath/src/main/java/edu/wpi/first/wpilibj/controller/SimpleMotorFeedforward.java
deleted file mode 100644
index ec53d46..0000000
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/controller/SimpleMotorFeedforward.java
+++ /dev/null
@@ -1,130 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.controller;
-
-/**
- * A helper class that computes feedforward outputs for a simple permanent-magnet DC motor.
- */
-@SuppressWarnings("MemberName")
-public class SimpleMotorFeedforward {
- public final double ks;
- public final double kv;
- public final double ka;
-
- /**
- * Creates a new SimpleMotorFeedforward with the specified gains. Units of the gain values
- * will dictate units of the computed feedforward.
- *
- * @param ks The static gain.
- * @param kv The velocity gain.
- * @param ka The acceleration gain.
- */
- public SimpleMotorFeedforward(double ks, double kv, double ka) {
- this.ks = ks;
- this.kv = kv;
- this.ka = ka;
- }
-
- /**
- * Creates a new SimpleMotorFeedforward with the specified gains. Acceleration gain is
- * defaulted to zero. Units of the gain values will dictate units of the computed feedforward.
- *
- * @param ks The static gain.
- * @param kv The velocity gain.
- */
- public SimpleMotorFeedforward(double ks, double kv) {
- this(ks, kv, 0);
- }
-
- /**
- * Calculates the feedforward from the gains and setpoints.
- *
- * @param velocity The velocity setpoint.
- * @param acceleration The acceleration setpoint.
- * @return The computed feedforward.
- */
- public double calculate(double velocity, double acceleration) {
- return ks * Math.signum(velocity) + kv * velocity + ka * acceleration;
- }
-
- // Rearranging the main equation from the calculate() method yields the
- // formulas for the methods below:
-
- /**
- * Calculates the feedforward from the gains and velocity setpoint (acceleration is assumed to
- * be zero).
- *
- * @param velocity The velocity setpoint.
- * @return The computed feedforward.
- */
- public double calculate(double velocity) {
- return calculate(velocity, 0);
- }
-
- /**
- * Calculates the maximum achievable velocity given a maximum voltage supply
- * and an acceleration. Useful for ensuring that velocity and
- * acceleration constraints for a trapezoidal profile are simultaneously
- * achievable - enter the acceleration constraint, and this will give you
- * a simultaneously-achievable velocity constraint.
- *
- * @param maxVoltage The maximum voltage that can be supplied to the motor.
- * @param acceleration The acceleration of the motor.
- * @return The maximum possible velocity at the given acceleration.
- */
- public double maxAchievableVelocity(double maxVoltage, double acceleration) {
- // Assume max velocity is positive
- return (maxVoltage - ks - acceleration * ka) / kv;
- }
-
- /**
- * Calculates the minimum achievable velocity given a maximum voltage supply
- * and an acceleration. Useful for ensuring that velocity and
- * acceleration constraints for a trapezoidal profile are simultaneously
- * achievable - enter the acceleration constraint, and this will give you
- * a simultaneously-achievable velocity constraint.
- *
- * @param maxVoltage The maximum voltage that can be supplied to the motor.
- * @param acceleration The acceleration of the motor.
- * @return The minimum possible velocity at the given acceleration.
- */
- public double minAchievableVelocity(double maxVoltage, double acceleration) {
- // Assume min velocity is negative, ks flips sign
- return (-maxVoltage + ks - acceleration * ka) / kv;
- }
-
- /**
- * Calculates the maximum achievable acceleration given a maximum voltage
- * supply and a velocity. Useful for ensuring that velocity and
- * acceleration constraints for a trapezoidal profile are simultaneously
- * achievable - enter the velocity constraint, and this will give you
- * a simultaneously-achievable acceleration constraint.
- *
- * @param maxVoltage The maximum voltage that can be supplied to the motor.
- * @param velocity The velocity of the motor.
- * @return The maximum possible acceleration at the given velocity.
- */
- public double maxAchievableAcceleration(double maxVoltage, double velocity) {
- return (maxVoltage - ks * Math.signum(velocity) - velocity * kv) / ka;
- }
-
- /**
- * Calculates the maximum achievable acceleration given a maximum voltage
- * supply and a velocity. Useful for ensuring that velocity and
- * acceleration constraints for a trapezoidal profile are simultaneously
- * achievable - enter the velocity constraint, and this will give you
- * a simultaneously-achievable acceleration constraint.
- *
- * @param maxVoltage The maximum voltage that can be supplied to the motor.
- * @param velocity The velocity of the motor.
- * @return The minimum possible acceleration at the given velocity.
- */
- public double minAchievableAcceleration(double maxVoltage, double velocity) {
- return maxAchievableAcceleration(-maxVoltage, velocity);
- }
-}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/estimator/ExtendedKalmanFilter.java b/wpimath/src/main/java/edu/wpi/first/wpilibj/estimator/ExtendedKalmanFilter.java
deleted file mode 100644
index 7474b02..0000000
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/estimator/ExtendedKalmanFilter.java
+++ /dev/null
@@ -1,288 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.estimator;
-
-import java.util.function.BiFunction;
-
-import edu.wpi.first.math.Drake;
-import edu.wpi.first.wpilibj.math.Discretization;
-import edu.wpi.first.wpilibj.math.StateSpaceUtil;
-import edu.wpi.first.wpilibj.system.NumericalJacobian;
-import edu.wpi.first.wpilibj.system.RungeKutta;
-import edu.wpi.first.wpiutil.math.Matrix;
-import edu.wpi.first.wpiutil.math.Nat;
-import edu.wpi.first.wpiutil.math.Num;
-import edu.wpi.first.wpiutil.math.numbers.N1;
-
-/**
- * Kalman filters combine predictions from a model and measurements to give an estimate of the true
- * system state. This is useful because many states cannot be measured directly as a result of
- * sensor noise, or because the state is "hidden".
- *
- * <p>The Extended Kalman filter is just like the {@link KalmanFilter Kalman filter}, but we make a
- * linear approximation of nonlinear dynamics and/or nonlinear measurement models. This means that
- * the EKF works with nonlinear systems.
- */
-@SuppressWarnings("ClassTypeParameterName")
-public class ExtendedKalmanFilter<States extends Num, Inputs extends Num, Outputs extends Num>
- implements KalmanTypeFilter<States, Inputs, Outputs> {
- private final Nat<States> m_states;
- private final Nat<Outputs> m_outputs;
-
- @SuppressWarnings("MemberName")
- private final BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> m_f;
- @SuppressWarnings("MemberName")
- private final BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> m_h;
- private final Matrix<States, States> m_contQ;
- private final Matrix<States, States> m_initP;
- private final Matrix<Outputs, Outputs> m_contR;
- @SuppressWarnings("MemberName")
- private Matrix<States, N1> m_xHat;
- @SuppressWarnings("MemberName")
- private Matrix<States, States> m_P;
- private double m_dtSeconds;
-
- /**
- * Constructs an extended Kalman filter.
- *
- * @param states a Nat representing the number of states.
- * @param inputs a Nat representing the number of inputs.
- * @param outputs a Nat representing the number of outputs.
- * @param f A vector-valued function of x and u that returns
- * the derivative of the state vector.
- * @param h A vector-valued function of x and u that returns
- * the measurement vector.
- * @param stateStdDevs Standard deviations of model states.
- * @param measurementStdDevs Standard deviations of measurements.
- * @param dtSeconds Nominal discretization timestep.
- */
- @SuppressWarnings("ParameterName")
- public ExtendedKalmanFilter(
- Nat<States> states,
- Nat<Inputs> inputs,
- Nat<Outputs> outputs,
- BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
- BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> h,
- Matrix<States, N1> stateStdDevs,
- Matrix<Outputs, N1> measurementStdDevs,
- double dtSeconds
- ) {
- m_states = states;
- m_outputs = outputs;
-
- m_f = f;
- m_h = h;
-
- m_contQ = StateSpaceUtil.makeCovarianceMatrix(states, stateStdDevs);
- this.m_contR = StateSpaceUtil.makeCovarianceMatrix(outputs, measurementStdDevs);
- m_dtSeconds = dtSeconds;
-
- reset();
-
- final var contA = NumericalJacobian
- .numericalJacobianX(states, states, f, m_xHat, new Matrix<>(inputs, Nat.N1()));
- final var C = NumericalJacobian
- .numericalJacobianX(outputs, states, h, m_xHat, new Matrix<>(inputs, Nat.N1()));
-
- final var discPair = Discretization.discretizeAQTaylor(contA, m_contQ, dtSeconds);
- final var discA = discPair.getFirst();
- final var discQ = discPair.getSecond();
-
- final var discR = Discretization.discretizeR(m_contR, dtSeconds);
-
- // IsStabilizable(A^T, C^T) will tell us if the system is observable.
- boolean isObservable = StateSpaceUtil.isStabilizable(discA.transpose(), C.transpose());
- if (isObservable && outputs.getNum() <= states.getNum()) {
- m_initP = Drake.discreteAlgebraicRiccatiEquation(
- discA.transpose(), C.transpose(), discQ, discR) ;
- } else {
- m_initP = new Matrix<>(states, states);
- }
-
- m_P = m_initP;
- }
-
- /**
- * Returns the error covariance matrix P.
- *
- * @return the error covariance matrix P.
- */
- @Override
- public Matrix<States, States> getP() {
- return m_P;
- }
-
- /**
- * Returns an element of the error covariance matrix P.
- *
- * @param row Row of P.
- * @param col Column of P.
- * @return the value of the error covariance matrix P at (i, j).
- */
- @Override
- public double getP(int row, int col) {
- return m_P.get(row, col);
- }
-
- /**
- * Sets the entire error covariance matrix P.
- *
- * @param newP The new value of P to use.
- */
- @Override
- public void setP(Matrix<States, States> newP) {
- m_P = newP;
- }
-
- /**
- * Returns the state estimate x-hat.
- *
- * @return the state estimate x-hat.
- */
- @Override
- public Matrix<States, N1> getXhat() {
- return m_xHat;
- }
-
- /**
- * Returns an element of the state estimate x-hat.
- *
- * @param row Row of x-hat.
- * @return the value of the state estimate x-hat at i.
- */
- @Override
- public double getXhat(int row) {
- return m_xHat.get(row, 0);
- }
-
- /**
- * Set initial state estimate x-hat.
- *
- * @param xHat The state estimate x-hat.
- */
- @SuppressWarnings("ParameterName")
- @Override
- public void setXhat(Matrix<States, N1> xHat) {
- m_xHat = xHat;
- }
-
-
- /**
- * Set an element of the initial state estimate x-hat.
- *
- * @param row Row of x-hat.
- * @param value Value for element of x-hat.
- */
- @Override
- public void setXhat(int row, double value) {
- m_xHat.set(row, 0, value);
- }
-
- @Override
- public void reset() {
- m_xHat = new Matrix<>(m_states, Nat.N1());
- m_P = m_initP;
- }
-
- /**
- * Project the model into the future with a new control input u.
- *
- * @param u New control input from controller.
- * @param dtSeconds Timestep for prediction.
- */
- @SuppressWarnings("ParameterName")
- @Override
- public void predict(Matrix<Inputs, N1> u, double dtSeconds) {
- predict(u, m_f, dtSeconds);
- }
-
- /**
- * Project the model into the future with a new control input u.
- *
- * @param u New control input from controller.
- * @param f The function used to linearlize the model.
- * @param dtSeconds Timestep for prediction.
- */
- @SuppressWarnings("ParameterName")
- public void predict(
- Matrix<Inputs, N1> u, BiFunction<Matrix<States, N1>,
- Matrix<Inputs, N1>, Matrix<States, N1>> f,
- double dtSeconds
- ) {
- // Find continuous A
- final var contA = NumericalJacobian.numericalJacobianX(m_states, m_states, f, m_xHat, u);
-
- // Find discrete A and Q
- final var discPair = Discretization.discretizeAQTaylor(contA, m_contQ, dtSeconds);
- final var discA = discPair.getFirst();
- final var discQ = discPair.getSecond();
-
- m_xHat = RungeKutta.rungeKutta(f, m_xHat, u, dtSeconds);
- m_P = discA.times(m_P).times(discA.transpose()).plus(discQ);
- m_dtSeconds = dtSeconds;
- }
-
- /**
- * Correct the state estimate x-hat using the measurements in y.
- *
- * @param u Same control input used in the predict step.
- * @param y Measurement vector.
- */
- @SuppressWarnings("ParameterName")
- @Override
- public void correct(Matrix<Inputs, N1> u, Matrix<Outputs, N1> y) {
- correct(m_outputs, u, y, m_h, m_contR);
- }
-
- /**
- * Correct the state estimate x-hat using the measurements in y.
- *
- * <p>This is useful for when the measurements available during a timestep's
- * Correct() call vary. The h(x, u) passed to the constructor is used if one is
- * not provided (the two-argument version of this function).
- *
- * @param <Rows> Number of rows in the result of f(x, u).
- * @param rows Number of rows in the result of f(x, u).
- * @param u Same control input used in the predict step.
- * @param y Measurement vector.
- * @param h A vector-valued function of x and u that returns the measurement
- * vector.
- * @param R Discrete measurement noise covariance matrix.
- */
- @SuppressWarnings({"ParameterName", "MethodTypeParameterName"})
- public <Rows extends Num> void correct(
- Nat<Rows> rows, Matrix<Inputs, N1> u,
- Matrix<Rows, N1> y,
- BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Rows, N1>> h,
- Matrix<Rows, Rows> R
- ) {
- final var C = NumericalJacobian.numericalJacobianX(rows, m_states, h, m_xHat, u);
- final var discR = Discretization.discretizeR(R, m_dtSeconds);
-
- final var S = C.times(m_P).times(C.transpose()).plus(discR);
-
- // We want to put K = PC^T S^-1 into Ax = b form so we can solve it more
- // efficiently.
- //
- // K = PC^T S^-1
- // KS = PC^T
- // (KS)^T = (PC^T)^T
- // S^T K^T = CP^T
- //
- // The solution of Ax = b can be found via x = A.solve(b).
- //
- // K^T = S^T.solve(CP^T)
- // K = (S^T.solve(CP^T))^T
- //
- // Now we have the Kalman gain
- final Matrix<States, Rows> K = S.transpose().solve(C.times(m_P.transpose())).transpose();
-
- m_xHat = m_xHat.plus(K.times(y.minus(h.apply(m_xHat, u))));
- m_P = Matrix.eye(m_states).minus(K.times(C)).times(m_P);
- }
-}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/estimator/KalmanTypeFilter.java b/wpimath/src/main/java/edu/wpi/first/wpilibj/estimator/KalmanTypeFilter.java
deleted file mode 100644
index aa93b29..0000000
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/estimator/KalmanTypeFilter.java
+++ /dev/null
@@ -1,35 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.estimator;
-
-import edu.wpi.first.wpiutil.math.Matrix;
-import edu.wpi.first.wpiutil.math.Num;
-import edu.wpi.first.wpiutil.math.numbers.N1;
-
-@SuppressWarnings({"ParameterName", "InterfaceTypeParameterName"})
-interface KalmanTypeFilter<States extends Num, Inputs extends Num, Outputs extends Num> {
- Matrix<States, States> getP();
-
- double getP(int i, int j);
-
- void setP(Matrix<States, States> newP);
-
- Matrix<States, N1> getXhat();
-
- double getXhat(int i);
-
- void setXhat(Matrix<States, N1> xHat);
-
- void setXhat(int i, double value);
-
- void reset();
-
- void predict(Matrix<Inputs, N1> u, double dtSeconds);
-
- void correct(Matrix<Inputs, N1> u, Matrix<Outputs, N1> y);
-}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/estimator/UnscentedKalmanFilter.java b/wpimath/src/main/java/edu/wpi/first/wpilibj/estimator/UnscentedKalmanFilter.java
deleted file mode 100644
index ca99153..0000000
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/estimator/UnscentedKalmanFilter.java
+++ /dev/null
@@ -1,316 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.estimator;
-
-import java.util.function.BiFunction;
-
-import org.ejml.simple.SimpleMatrix;
-
-import edu.wpi.first.wpilibj.math.Discretization;
-import edu.wpi.first.wpilibj.math.StateSpaceUtil;
-import edu.wpi.first.wpilibj.system.NumericalJacobian;
-import edu.wpi.first.wpilibj.system.RungeKutta;
-import edu.wpi.first.wpiutil.math.Matrix;
-import edu.wpi.first.wpiutil.math.Nat;
-import edu.wpi.first.wpiutil.math.Num;
-import edu.wpi.first.wpiutil.math.Pair;
-import edu.wpi.first.wpiutil.math.numbers.N1;
-
-/**
- * A Kalman filter combines predictions from a model and measurements to give an estimate of the
- * true ystem state. This is useful because many states cannot be measured directly as a result of
- * sensor noise, or because the state is "hidden".
- *
- * <p>The Unscented Kalman filter is similar to the {@link KalmanFilter Kalman filter}, except that
- * it propagates carefully chosen points called sigma points through the non-linear model to obtain
- * an estimate of the true covariance (as opposed to a linearized version of it). This means that
- * the UKF works with nonlinear systems.
- */
-@SuppressWarnings({"MemberName", "ClassTypeParameterName"})
-public class UnscentedKalmanFilter<States extends Num, Inputs extends Num,
- Outputs extends Num> implements KalmanTypeFilter<States, Inputs, Outputs> {
-
- private final Nat<States> m_states;
- private final Nat<Outputs> m_outputs;
-
- private final BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> m_f;
- private final BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> m_h;
-
- private Matrix<States, N1> m_xHat;
- private Matrix<States, States> m_P;
- private final Matrix<States, States> m_contQ;
- private final Matrix<Outputs, Outputs> m_contR;
- private Matrix<States, ?> m_sigmasF;
- private double m_dtSeconds;
-
- private final MerweScaledSigmaPoints<States> m_pts;
-
- /**
- * Constructs an Unscented Kalman Filter.
- *
- * @param states A Nat representing the number of states.
- * @param outputs A Nat representing the number of outputs.
- * @param f A vector-valued function of x and u that returns
- * the derivative of the state vector.
- * @param h A vector-valued function of x and u that returns
- * the measurement vector.
- * @param stateStdDevs Standard deviations of model states.
- * @param measurementStdDevs Standard deviations of measurements.
- * @param dtSeconds Nominal discretization timestep.
- */
- @SuppressWarnings("ParameterName")
- public UnscentedKalmanFilter(Nat<States> states, Nat<Outputs> outputs,
- BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>,
- Matrix<States, N1>> f,
- BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>,
- Matrix<Outputs, N1>> h,
- Matrix<States, N1> stateStdDevs,
- Matrix<Outputs, N1> measurementStdDevs,
- double dtSeconds) {
- this.m_states = states;
- this.m_outputs = outputs;
-
- m_f = f;
- m_h = h;
-
- m_contQ = StateSpaceUtil.makeCovarianceMatrix(states, stateStdDevs);
- m_contR = StateSpaceUtil.makeCovarianceMatrix(outputs, measurementStdDevs);
-
- m_dtSeconds = dtSeconds;
-
- m_pts = new MerweScaledSigmaPoints<>(states);
-
- reset();
- }
-
- @SuppressWarnings({"ParameterName", "LocalVariableName", "PMD.CyclomaticComplexity"})
- static <S extends Num, C extends Num>
- Pair<Matrix<C, N1>, Matrix<C, C>> unscentedTransform(
- Nat<S> s, Nat<C> dim, Matrix<C, ?> sigmas, Matrix<?, N1> Wm, Matrix<?, N1> Wc
- ) {
- if (sigmas.getNumRows() != dim.getNum() || sigmas.getNumCols() != 2 * s.getNum() + 1) {
- throw new IllegalArgumentException("Sigmas must be covDim by 2 * states + 1! Got "
- + sigmas.getNumRows() + " by " + sigmas.getNumCols());
- }
-
- if (Wm.getNumRows() != 2 * s.getNum() + 1 || Wm.getNumCols() != 1 ) {
- throw new IllegalArgumentException("Wm must be 2 * states + 1 by 1! Got "
- + Wm.getNumRows() + " by " + Wm.getNumCols());
- }
-
- if (Wc.getNumRows() != 2 * s.getNum() + 1 || Wc.getNumCols() != 1) {
- throw new IllegalArgumentException("Wc must be 2 * states + 1 by 1! Got "
- + Wc.getNumRows() + " by " + Wc.getNumCols());
- }
-
- // New mean is just the sum of the sigmas * weight
- // dot = \Sigma^n_1 (W[k]*Xi[k])
- Matrix<C, N1> x = sigmas.times(Matrix.changeBoundsUnchecked(Wm));
-
- // New covariance is the sum of the outer product of the residuals times the
- // weights
- Matrix<C, ?> y = new Matrix<>(new SimpleMatrix(dim.getNum(), 2 * s.getNum() + 1));
- for (int i = 0; i < 2 * s.getNum() + 1; i++) {
- y.setColumn(i, sigmas.extractColumnVector(i).minus(x));
- }
- Matrix<C, C> P = y.times(Matrix.changeBoundsUnchecked(Wc.diag()))
- .times(Matrix.changeBoundsUnchecked(y.transpose()));
-
- return new Pair<>(x, P);
- }
-
- /**
- * Returns the error covariance matrix P.
- *
- * @return the error covariance matrix P.
- */
- @Override
- public Matrix<States, States> getP() {
- return m_P;
- }
-
- /**
- * Returns an element of the error covariance matrix P.
- *
- * @param row Row of P.
- * @param col Column of P.
- * @return the value of the error covariance matrix P at (i, j).
- */
- @Override
- public double getP(int row, int col) {
- return m_P.get(row, col);
- }
-
- /**
- * Sets the entire error covariance matrix P.
- *
- * @param newP The new value of P to use.
- */
- @Override
- public void setP(Matrix<States, States> newP) {
- m_P = newP;
- }
-
- /**
- * Returns the state estimate x-hat.
- *
- * @return the state estimate x-hat.
- */
- @Override
- public Matrix<States, N1> getXhat() {
- return m_xHat;
- }
-
- /**
- * Returns an element of the state estimate x-hat.
- *
- * @param row Row of x-hat.
- * @return the value of the state estimate x-hat at i.
- */
- @Override
- public double getXhat(int row) {
- return m_xHat.get(row, 0);
- }
-
-
- /**
- * Set initial state estimate x-hat.
- *
- * @param xHat The state estimate x-hat.
- */
- @SuppressWarnings("ParameterName")
- @Override
- public void setXhat(Matrix<States, N1> xHat) {
- m_xHat = xHat;
- }
-
- /**
- * Set an element of the initial state estimate x-hat.
- *
- * @param row Row of x-hat.
- * @param value Value for element of x-hat.
- */
- @Override
- public void setXhat(int row, double value) {
- m_xHat.set(row, 0, value);
- }
-
- /**
- * Resets the observer.
- */
- @Override
- public void reset() {
- m_xHat = new Matrix<>(m_states, Nat.N1());
- m_P = new Matrix<>(m_states, m_states);
- m_sigmasF = new Matrix<>(new SimpleMatrix(m_states.getNum(), 2 * m_states.getNum() + 1));
- }
-
- /**
- * Project the model into the future with a new control input u.
- *
- * @param u New control input from controller.
- * @param dtSeconds Timestep for prediction.
- */
- @SuppressWarnings({"LocalVariableName", "ParameterName"})
- @Override
- public void predict(Matrix<Inputs, N1> u, double dtSeconds) {
- // Discretize Q before projecting mean and covariance forward
- Matrix<States, States> contA =
- NumericalJacobian.numericalJacobianX(m_states, m_states, m_f, m_xHat, u);
- var discQ =
- Discretization.discretizeAQTaylor(contA, m_contQ, dtSeconds).getSecond();
-
- var sigmas = m_pts.sigmaPoints(m_xHat, m_P);
-
- for (int i = 0; i < m_pts.getNumSigmas(); ++i) {
- Matrix<States, N1> x = sigmas.extractColumnVector(i);
-
- m_sigmasF.setColumn(i, RungeKutta.rungeKutta(m_f, x, u, dtSeconds));
- }
-
- var ret = unscentedTransform(m_states, m_states,
- m_sigmasF, m_pts.getWm(), m_pts.getWc());
-
- m_xHat = ret.getFirst();
- m_P = ret.getSecond().plus(discQ);
- m_dtSeconds = dtSeconds;
- }
-
- /**
- * Correct the state estimate x-hat using the measurements in y.
- *
- * @param u Same control input used in the predict step.
- * @param y Measurement vector.
- */
- @SuppressWarnings("ParameterName")
- @Override
- public void correct(Matrix<Inputs, N1> u, Matrix<Outputs, N1> y) {
- correct(m_outputs, u, y, m_h, m_contR);
- }
-
- /**
- * Correct the state estimate x-hat using the measurements in y.
- *
- * <p>This is useful for when the measurements available during a timestep's
- * Correct() call vary. The h(x, u) passed to the constructor is used if one
- * is not provided (the two-argument version of this function).
- *
- * @param u Same control input used in the predict step.
- * @param y Measurement vector.
- * @param h A vector-valued function of x and u that returns
- * the measurement vector.
- * @param R Measurement noise covariance matrix.
- */
- @SuppressWarnings({"ParameterName", "LocalVariableName"})
- public <R extends Num> void correct(
- Nat<R> rows, Matrix<Inputs, N1> u,
- Matrix<R, N1> y,
- BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<R, N1>> h,
- Matrix<R, R> R) {
- final var discR = Discretization.discretizeR(R, m_dtSeconds);
-
- // Transform sigma points into measurement space
- Matrix<R, ?> sigmasH = new Matrix<>(new SimpleMatrix(
- rows.getNum(), 2 * m_states.getNum() + 1));
- var sigmas = m_pts.sigmaPoints(m_xHat, m_P);
- for (int i = 0; i < m_pts.getNumSigmas(); i++) {
- Matrix<R, N1> hRet = h.apply(
- sigmas.extractColumnVector(i),
- u
- );
- sigmasH.setColumn(i, hRet);
- }
-
- // Mean and covariance of prediction passed through unscented transform
- var transRet = unscentedTransform(m_states, rows, sigmasH, m_pts.getWm(), m_pts.getWc());
- var yHat = transRet.getFirst();
- var Py = transRet.getSecond().plus(discR);
-
- // Compute cross covariance of the state and the measurements
- Matrix<States, R> Pxy = new Matrix<>(m_states, rows);
- for (int i = 0; i < m_pts.getNumSigmas(); i++) {
- var temp =
- m_sigmasF.extractColumnVector(i).minus(m_xHat)
- .times(sigmasH.extractColumnVector(i).minus(yHat).transpose());
-
- Pxy = Pxy.plus(temp.times(m_pts.getWc(i)));
- }
-
- // K = P_{xy} Py^-1
- // K^T = P_y^T^-1 P_{xy}^T
- // P_y^T K^T = P_{xy}^T
- // K^T = P_y^T.solve(P_{xy}^T)
- // K = (P_y^T.solve(P_{xy}^T)^T
- Matrix<States, R> K = new Matrix<>(
- Py.transpose().solve(Pxy.transpose()).transpose()
- );
-
- m_xHat = m_xHat.plus(K.times(y.minus(yHat)));
- m_P = m_P.minus(K.times(Py).times(K.transpose()));
- }
-}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/kinematics/ChassisSpeeds.java b/wpimath/src/main/java/edu/wpi/first/wpilibj/kinematics/ChassisSpeeds.java
deleted file mode 100644
index a6878b3..0000000
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/kinematics/ChassisSpeeds.java
+++ /dev/null
@@ -1,91 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.kinematics;
-
-
-import edu.wpi.first.wpilibj.geometry.Rotation2d;
-
-/**
- * Represents the speed of a robot chassis. Although this struct contains
- * similar members compared to a Twist2d, they do NOT represent the same thing.
- * Whereas a Twist2d represents a change in pose w.r.t to the robot frame of reference,
- * this ChassisSpeeds struct represents a velocity w.r.t to the robot frame of
- * reference.
- *
- * <p>A strictly non-holonomic drivetrain, such as a differential drive, should
- * never have a dy component because it can never move sideways. Holonomic
- * drivetrains such as swerve and mecanum will often have all three components.
- */
-@SuppressWarnings("MemberName")
-public class ChassisSpeeds {
- /**
- * Represents forward velocity w.r.t the robot frame of reference. (Fwd is +)
- */
- public double vxMetersPerSecond;
-
- /**
- * Represents sideways velocity w.r.t the robot frame of reference. (Left is +)
- */
- public double vyMetersPerSecond;
-
- /**
- * Represents the angular velocity of the robot frame. (CCW is +)
- */
- public double omegaRadiansPerSecond;
-
- /**
- * Constructs a ChassisSpeeds with zeros for dx, dy, and theta.
- */
- public ChassisSpeeds() {
- }
-
- /**
- * Constructs a ChassisSpeeds object.
- *
- * @param vxMetersPerSecond Forward velocity.
- * @param vyMetersPerSecond Sideways velocity.
- * @param omegaRadiansPerSecond Angular velocity.
- */
- public ChassisSpeeds(double vxMetersPerSecond, double vyMetersPerSecond,
- double omegaRadiansPerSecond) {
- this.vxMetersPerSecond = vxMetersPerSecond;
- this.vyMetersPerSecond = vyMetersPerSecond;
- this.omegaRadiansPerSecond = omegaRadiansPerSecond;
- }
-
- /**
- * Converts a user provided field-relative set of speeds into a robot-relative
- * ChassisSpeeds object.
- *
- * @param vxMetersPerSecond The component of speed in the x direction relative to the field.
- * Positive x is away from your alliance wall.
- * @param vyMetersPerSecond The component of speed in the y direction relative to the field.
- * Positive y is to your left when standing behind the alliance wall.
- * @param omegaRadiansPerSecond The angular rate of the robot.
- * @param robotAngle The angle of the robot as measured by a gyroscope. The robot's
- * angle is considered to be zero when it is facing directly away
- * from your alliance station wall. Remember that this should
- * be CCW positive.
- * @return ChassisSpeeds object representing the speeds in the robot's frame of reference.
- */
- public static ChassisSpeeds fromFieldRelativeSpeeds(
- double vxMetersPerSecond, double vyMetersPerSecond,
- double omegaRadiansPerSecond, Rotation2d robotAngle) {
- return new ChassisSpeeds(
- vxMetersPerSecond * robotAngle.getCos() + vyMetersPerSecond * robotAngle.getSin(),
- -vxMetersPerSecond * robotAngle.getSin() + vyMetersPerSecond * robotAngle.getCos(),
- omegaRadiansPerSecond
- );
- }
-
- @Override
- public String toString() {
- return String.format("ChassisSpeeds(Vx: %.2f m/s, Vy: %.2f m/s, Omega: %.2f rad/s)",
- vxMetersPerSecond, vyMetersPerSecond, omegaRadiansPerSecond);
- }
-}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/kinematics/DifferentialDriveKinematics.java b/wpimath/src/main/java/edu/wpi/first/wpilibj/kinematics/DifferentialDriveKinematics.java
deleted file mode 100644
index 309f531..0000000
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/kinematics/DifferentialDriveKinematics.java
+++ /dev/null
@@ -1,69 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.kinematics;
-
-import edu.wpi.first.math.MathSharedStore;
-import edu.wpi.first.math.MathUsageId;
-
-/**
- * Helper class that converts a chassis velocity (dx and dtheta components) to
- * left and right wheel velocities for a differential drive.
- *
- * <p>Inverse kinematics converts a desired chassis speed into left and right
- * velocity components whereas forward kinematics converts left and right
- * component velocities into a linear and angular chassis speed.
- */
-@SuppressWarnings("MemberName")
-public class DifferentialDriveKinematics {
- public final double trackWidthMeters;
-
- /**
- * Constructs a differential drive kinematics object.
- *
- * @param trackWidthMeters The track width of the drivetrain. Theoretically, this is
- * the distance between the left wheels and right wheels.
- * However, the empirical value may be larger than the physical
- * measured value due to scrubbing effects.
- */
- public DifferentialDriveKinematics(double trackWidthMeters) {
- this.trackWidthMeters = trackWidthMeters;
- MathSharedStore.reportUsage(MathUsageId.kKinematics_DifferentialDrive, 1);
- }
-
- /**
- * Returns a chassis speed from left and right component velocities using
- * forward kinematics.
- *
- * @param wheelSpeeds The left and right velocities.
- * @return The chassis speed.
- */
- public ChassisSpeeds toChassisSpeeds(DifferentialDriveWheelSpeeds wheelSpeeds) {
- return new ChassisSpeeds(
- (wheelSpeeds.leftMetersPerSecond + wheelSpeeds.rightMetersPerSecond) / 2, 0,
- (wheelSpeeds.rightMetersPerSecond - wheelSpeeds.leftMetersPerSecond)
- / trackWidthMeters
- );
- }
-
- /**
- * Returns left and right component velocities from a chassis speed using
- * inverse kinematics.
- *
- * @param chassisSpeeds The linear and angular (dx and dtheta) components that
- * represent the chassis' speed.
- * @return The left and right velocities.
- */
- public DifferentialDriveWheelSpeeds toWheelSpeeds(ChassisSpeeds chassisSpeeds) {
- return new DifferentialDriveWheelSpeeds(
- chassisSpeeds.vxMetersPerSecond - trackWidthMeters / 2
- * chassisSpeeds.omegaRadiansPerSecond,
- chassisSpeeds.vxMetersPerSecond + trackWidthMeters / 2
- * chassisSpeeds.omegaRadiansPerSecond
- );
- }
-}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/kinematics/DifferentialDriveOdometry.java b/wpimath/src/main/java/edu/wpi/first/wpilibj/kinematics/DifferentialDriveOdometry.java
deleted file mode 100644
index 86470f8..0000000
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/kinematics/DifferentialDriveOdometry.java
+++ /dev/null
@@ -1,119 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.kinematics;
-
-import edu.wpi.first.math.MathSharedStore;
-import edu.wpi.first.math.MathUsageId;
-import edu.wpi.first.wpilibj.geometry.Pose2d;
-import edu.wpi.first.wpilibj.geometry.Rotation2d;
-import edu.wpi.first.wpilibj.geometry.Twist2d;
-
-/**
- * Class for differential drive odometry. Odometry allows you to track the
- * robot's position on the field over the course of a match using readings from
- * 2 encoders and a gyroscope.
- *
- * <p>Teams can use odometry during the autonomous period for complex tasks like
- * path following. Furthermore, odometry can be used for latency compensation
- * when using computer-vision systems.
- *
- * <p>It is important that you reset your encoders to zero before using this class.
- * Any subsequent pose resets also require the encoders to be reset to zero.
- */
-public class DifferentialDriveOdometry {
- private Pose2d m_poseMeters;
-
- private Rotation2d m_gyroOffset;
- private Rotation2d m_previousAngle;
-
- private double m_prevLeftDistance;
- private double m_prevRightDistance;
-
- /**
- * Constructs a DifferentialDriveOdometry object.
- *
- * @param gyroAngle The angle reported by the gyroscope.
- * @param initialPoseMeters The starting position of the robot on the field.
- */
- public DifferentialDriveOdometry(Rotation2d gyroAngle,
- Pose2d initialPoseMeters) {
- m_poseMeters = initialPoseMeters;
- m_gyroOffset = m_poseMeters.getRotation().minus(gyroAngle);
- m_previousAngle = initialPoseMeters.getRotation();
- MathSharedStore.reportUsage(MathUsageId.kOdometry_DifferentialDrive, 1);
- }
-
- /**
- * Constructs a DifferentialDriveOdometry object with the default pose at the origin.
- *
- * @param gyroAngle The angle reported by the gyroscope.
- */
- public DifferentialDriveOdometry(Rotation2d gyroAngle) {
- this(gyroAngle, new Pose2d());
- }
-
- /**
- * Resets the robot's position on the field.
- *
- * <p>You NEED to reset your encoders (to zero) when calling this method.
- *
- * <p>The gyroscope angle does not need to be reset here on the user's robot code.
- * The library automatically takes care of offsetting the gyro angle.
- *
- * @param poseMeters The position on the field that your robot is at.
- * @param gyroAngle The angle reported by the gyroscope.
- */
- public void resetPosition(Pose2d poseMeters, Rotation2d gyroAngle) {
- m_poseMeters = poseMeters;
- m_previousAngle = poseMeters.getRotation();
- m_gyroOffset = m_poseMeters.getRotation().minus(gyroAngle);
-
- m_prevLeftDistance = 0.0;
- m_prevRightDistance = 0.0;
- }
-
- /**
- * Returns the position of the robot on the field.
- *
- * @return The pose of the robot (x and y are in meters).
- */
- public Pose2d getPoseMeters() {
- return m_poseMeters;
- }
-
-
- /**
- * Updates the robot position on the field using distance measurements from encoders. This
- * method is more numerically accurate than using velocities to integrate the pose and
- * is also advantageous for teams that are using lower CPR encoders.
- *
- * @param gyroAngle The angle reported by the gyroscope.
- * @param leftDistanceMeters The distance traveled by the left encoder.
- * @param rightDistanceMeters The distance traveled by the right encoder.
- * @return The new pose of the robot.
- */
- public Pose2d update(Rotation2d gyroAngle, double leftDistanceMeters,
- double rightDistanceMeters) {
- double deltaLeftDistance = leftDistanceMeters - m_prevLeftDistance;
- double deltaRightDistance = rightDistanceMeters - m_prevRightDistance;
-
- m_prevLeftDistance = leftDistanceMeters;
- m_prevRightDistance = rightDistanceMeters;
-
- double averageDeltaDistance = (deltaLeftDistance + deltaRightDistance) / 2.0;
- var angle = gyroAngle.plus(m_gyroOffset);
-
- var newPose = m_poseMeters.exp(
- new Twist2d(averageDeltaDistance, 0.0, angle.minus(m_previousAngle).getRadians()));
-
- m_previousAngle = angle;
-
- m_poseMeters = new Pose2d(newPose.getTranslation(), angle);
- return m_poseMeters;
- }
-}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/kinematics/DifferentialDriveWheelSpeeds.java b/wpimath/src/main/java/edu/wpi/first/wpilibj/kinematics/DifferentialDriveWheelSpeeds.java
deleted file mode 100644
index 1716430..0000000
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/kinematics/DifferentialDriveWheelSpeeds.java
+++ /dev/null
@@ -1,68 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.kinematics;
-
-/**
- * Represents the wheel speeds for a differential drive drivetrain.
- */
-@SuppressWarnings("MemberName")
-public class DifferentialDriveWheelSpeeds {
- /**
- * Speed of the left side of the robot.
- */
- public double leftMetersPerSecond;
-
- /**
- * Speed of the right side of the robot.
- */
- public double rightMetersPerSecond;
-
- /**
- * Constructs a DifferentialDriveWheelSpeeds with zeros for left and right speeds.
- */
- public DifferentialDriveWheelSpeeds() {
- }
-
- /**
- * Constructs a DifferentialDriveWheelSpeeds.
- *
- * @param leftMetersPerSecond The left speed.
- * @param rightMetersPerSecond The right speed.
- */
- public DifferentialDriveWheelSpeeds(double leftMetersPerSecond, double rightMetersPerSecond) {
- this.leftMetersPerSecond = leftMetersPerSecond;
- this.rightMetersPerSecond = rightMetersPerSecond;
- }
-
- /**
- * Normalizes the wheel speeds using some max attainable speed. Sometimes,
- * after inverse kinematics, the requested speed from a/several modules may be
- * above the max attainable speed for the driving motor on that module. To fix
- * this issue, one can "normalize" all the wheel speeds to make sure that all
- * requested module speeds are below the absolute threshold, while maintaining
- * the ratio of speeds between modules.
- *
- * @param attainableMaxSpeedMetersPerSecond The absolute max speed that a wheel can reach.
- */
- public void normalize(double attainableMaxSpeedMetersPerSecond) {
- double realMaxSpeed = Math.max(Math.abs(leftMetersPerSecond), Math.abs(rightMetersPerSecond));
-
- if (realMaxSpeed > attainableMaxSpeedMetersPerSecond) {
- leftMetersPerSecond = leftMetersPerSecond / realMaxSpeed
- * attainableMaxSpeedMetersPerSecond;
- rightMetersPerSecond = rightMetersPerSecond / realMaxSpeed
- * attainableMaxSpeedMetersPerSecond;
- }
- }
-
- @Override
- public String toString() {
- return String.format("DifferentialDriveWheelSpeeds(Left: %.2f m/s, Right: %.2f m/s)",
- leftMetersPerSecond, rightMetersPerSecond);
- }
-}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/kinematics/MecanumDriveKinematics.java b/wpimath/src/main/java/edu/wpi/first/wpilibj/kinematics/MecanumDriveKinematics.java
deleted file mode 100644
index 8c1d5d3..0000000
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/kinematics/MecanumDriveKinematics.java
+++ /dev/null
@@ -1,172 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.kinematics;
-
-import org.ejml.simple.SimpleMatrix;
-
-import edu.wpi.first.math.MathSharedStore;
-import edu.wpi.first.math.MathUsageId;
-import edu.wpi.first.wpilibj.geometry.Translation2d;
-
-/**
- * Helper class that converts a chassis velocity (dx, dy, and dtheta components)
- * into individual wheel speeds.
- *
- * <p>The inverse kinematics (converting from a desired chassis velocity to
- * individual wheel speeds) uses the relative locations of the wheels with
- * respect to the center of rotation. The center of rotation for inverse
- * kinematics is also variable. This means that you can set your set your center
- * of rotation in a corner of the robot to perform special evasion maneuvers.
- *
- * <p>Forward kinematics (converting an array of wheel speeds into the overall
- * chassis motion) is performs the exact opposite of what inverse kinematics
- * does. Since this is an overdetermined system (more equations than variables),
- * we use a least-squares approximation.
- *
- * <p>The inverse kinematics: [wheelSpeeds] = [wheelLocations] * [chassisSpeeds]
- * We take the Moore-Penrose pseudoinverse of [wheelLocations] and then
- * multiply by [wheelSpeeds] to get our chassis speeds.
- *
- * <p>Forward kinematics is also used for odometry -- determining the position of
- * the robot on the field using encoders and a gyro.
- */
-public class MecanumDriveKinematics {
- private SimpleMatrix m_inverseKinematics;
- private final SimpleMatrix m_forwardKinematics;
-
- private final Translation2d m_frontLeftWheelMeters;
- private final Translation2d m_frontRightWheelMeters;
- private final Translation2d m_rearLeftWheelMeters;
- private final Translation2d m_rearRightWheelMeters;
-
- private Translation2d m_prevCoR = new Translation2d();
-
- /**
- * Constructs a mecanum drive kinematics object.
- *
- * @param frontLeftWheelMeters The location of the front-left wheel relative to the
- * physical center of the robot.
- * @param frontRightWheelMeters The location of the front-right wheel relative to
- * the physical center of the robot.
- * @param rearLeftWheelMeters The location of the rear-left wheel relative to the
- * physical center of the robot.
- * @param rearRightWheelMeters The location of the rear-right wheel relative to the
- * physical center of the robot.
- */
- public MecanumDriveKinematics(Translation2d frontLeftWheelMeters,
- Translation2d frontRightWheelMeters,
- Translation2d rearLeftWheelMeters,
- Translation2d rearRightWheelMeters) {
- m_frontLeftWheelMeters = frontLeftWheelMeters;
- m_frontRightWheelMeters = frontRightWheelMeters;
- m_rearLeftWheelMeters = rearLeftWheelMeters;
- m_rearRightWheelMeters = rearRightWheelMeters;
-
- m_inverseKinematics = new SimpleMatrix(4, 3);
-
- setInverseKinematics(frontLeftWheelMeters, frontRightWheelMeters,
- rearLeftWheelMeters, rearRightWheelMeters);
- m_forwardKinematics = m_inverseKinematics.pseudoInverse();
-
- MathSharedStore.reportUsage(MathUsageId.kKinematics_MecanumDrive, 1);
- }
-
- /**
- * Performs inverse kinematics to return the wheel speeds from a desired chassis velocity. This
- * method is often used to convert joystick values into wheel speeds.
- *
- * <p>This function also supports variable centers of rotation. During normal
- * operations, the center of rotation is usually the same as the physical
- * center of the robot; therefore, the argument is defaulted to that use case.
- * However, if you wish to change the center of rotation for evasive
- * maneuvers, vision alignment, or for any other use case, you can do so.
- *
- * @param chassisSpeeds The desired chassis speed.
- * @param centerOfRotationMeters The center of rotation. For example, if you set the
- * center of rotation at one corner of the robot and provide
- * a chassis speed that only has a dtheta component, the robot
- * will rotate around that corner.
- * @return The wheel speeds. Use caution because they are not normalized. Sometimes, a user
- * input may cause one of the wheel speeds to go above the attainable max velocity. Use
- * the {@link MecanumDriveWheelSpeeds#normalize(double)} function to rectify this issue.
- */
- public MecanumDriveWheelSpeeds toWheelSpeeds(ChassisSpeeds chassisSpeeds,
- Translation2d centerOfRotationMeters) {
- // We have a new center of rotation. We need to compute the matrix again.
- if (!centerOfRotationMeters.equals(m_prevCoR)) {
- var fl = m_frontLeftWheelMeters.minus(centerOfRotationMeters);
- var fr = m_frontRightWheelMeters.minus(centerOfRotationMeters);
- var rl = m_rearLeftWheelMeters.minus(centerOfRotationMeters);
- var rr = m_rearRightWheelMeters.minus(centerOfRotationMeters);
-
- setInverseKinematics(fl, fr, rl, rr);
- m_prevCoR = centerOfRotationMeters;
- }
-
- var chassisSpeedsVector = new SimpleMatrix(3, 1);
- chassisSpeedsVector.setColumn(0, 0,
- chassisSpeeds.vxMetersPerSecond, chassisSpeeds.vyMetersPerSecond,
- chassisSpeeds.omegaRadiansPerSecond);
-
- var wheelsMatrix = m_inverseKinematics.mult(chassisSpeedsVector);
- return new MecanumDriveWheelSpeeds(
- wheelsMatrix.get(0, 0),
- wheelsMatrix.get(1, 0),
- wheelsMatrix.get(2, 0),
- wheelsMatrix.get(3, 0)
- );
- }
-
- /**
- * Performs inverse kinematics. See {@link #toWheelSpeeds(ChassisSpeeds, Translation2d)} for more
- * information.
- *
- * @param chassisSpeeds The desired chassis speed.
- * @return The wheel speeds.
- */
- public MecanumDriveWheelSpeeds toWheelSpeeds(ChassisSpeeds chassisSpeeds) {
- return toWheelSpeeds(chassisSpeeds, new Translation2d());
- }
-
- /**
- * Performs forward kinematics to return the resulting chassis state from the given wheel speeds.
- * This method is often used for odometry -- determining the robot's position on the field using
- * data from the real-world speed of each wheel on the robot.
- *
- * @param wheelSpeeds The current mecanum drive wheel speeds.
- * @return The resulting chassis speed.
- */
- public ChassisSpeeds toChassisSpeeds(MecanumDriveWheelSpeeds wheelSpeeds) {
- var wheelSpeedsMatrix = new SimpleMatrix(4, 1);
- wheelSpeedsMatrix.setColumn(0, 0,
- wheelSpeeds.frontLeftMetersPerSecond, wheelSpeeds.frontRightMetersPerSecond,
- wheelSpeeds.rearLeftMetersPerSecond, wheelSpeeds.rearRightMetersPerSecond
- );
- var chassisSpeedsVector = m_forwardKinematics.mult(wheelSpeedsMatrix);
-
- return new ChassisSpeeds(chassisSpeedsVector.get(0, 0), chassisSpeedsVector.get(1, 0),
- chassisSpeedsVector.get(2, 0));
- }
-
- /**
- * Construct inverse kinematics matrix from wheel locations.
- *
- * @param fl The location of the front-left wheel relative to the physical center of the robot.
- * @param fr The location of the front-right wheel relative to the physical center of the robot.
- * @param rl The location of the rear-left wheel relative to the physical center of the robot.
- * @param rr The location of the rear-right wheel relative to the physical center of the robot.
- */
- private void setInverseKinematics(Translation2d fl, Translation2d fr,
- Translation2d rl, Translation2d rr) {
- m_inverseKinematics.setRow(0, 0, 1, -1, -(fl.getX() + fl.getY()));
- m_inverseKinematics.setRow(1, 0, 1, 1, fr.getX() - fr.getY());
- m_inverseKinematics.setRow(2, 0, 1, 1, rl.getX() - rl.getY());
- m_inverseKinematics.setRow(3, 0, 1, -1, -(rr.getX() + rr.getY()));
- m_inverseKinematics = m_inverseKinematics.scale(1.0 / Math.sqrt(2));
- }
-}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/kinematics/MecanumDriveMotorVoltages.java b/wpimath/src/main/java/edu/wpi/first/wpilibj/kinematics/MecanumDriveMotorVoltages.java
deleted file mode 100644
index 756bb60..0000000
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/kinematics/MecanumDriveMotorVoltages.java
+++ /dev/null
@@ -1,65 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.kinematics;
-
-/**
- * Represents the motor voltages for a mecanum drive drivetrain.
- */
-@SuppressWarnings("MemberName")
-public class MecanumDriveMotorVoltages {
- /**
- * Voltage of the front left motor.
- */
- public double frontLeftVoltage;
-
- /**
- * Voltage of the front right motor.
- */
- public double frontRightVoltage;
-
- /**
- * Voltage of the rear left motor.
- */
- public double rearLeftVoltage;
-
- /**
- * Voltage of the rear right motor.
- */
- public double rearRightVoltage;
-
- /**
- * Constructs a MecanumDriveMotorVoltages with zeros for all member fields.
- */
- public MecanumDriveMotorVoltages() {
- }
-
- /**
- * Constructs a MecanumDriveMotorVoltages.
- *
- * @param frontLeftVoltage Voltage of the front left motor.
- * @param frontRightVoltage Voltage of the front right motor.
- * @param rearLeftVoltage Voltage of the rear left motor.
- * @param rearRightVoltage Voltage of the rear right motor.
- */
- public MecanumDriveMotorVoltages(double frontLeftVoltage,
- double frontRightVoltage,
- double rearLeftVoltage,
- double rearRightVoltage) {
- this.frontLeftVoltage = frontLeftVoltage;
- this.frontRightVoltage = frontRightVoltage;
- this.rearLeftVoltage = rearLeftVoltage;
- this.rearRightVoltage = rearRightVoltage;
- }
-
- @Override
- public String toString() {
- return String.format("MecanumDriveMotorVoltages(Front Left: %.2f V, Front Right: %.2f V, "
- + "Rear Left: %.2f V, Rear Right: %.2f V)", frontLeftVoltage, frontRightVoltage,
- rearLeftVoltage, rearRightVoltage);
- }
-}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/kinematics/MecanumDriveOdometry.java b/wpimath/src/main/java/edu/wpi/first/wpilibj/kinematics/MecanumDriveOdometry.java
deleted file mode 100644
index cd84bdf..0000000
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/kinematics/MecanumDriveOdometry.java
+++ /dev/null
@@ -1,132 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.kinematics;
-
-import edu.wpi.first.math.MathSharedStore;
-import edu.wpi.first.math.MathUsageId;
-import edu.wpi.first.wpilibj.geometry.Pose2d;
-import edu.wpi.first.wpilibj.geometry.Rotation2d;
-import edu.wpi.first.wpilibj.geometry.Twist2d;
-import edu.wpi.first.wpiutil.WPIUtilJNI;
-
-/**
- * Class for mecanum drive odometry. Odometry allows you to track the robot's
- * position on the field over a course of a match using readings from your
- * mecanum wheel encoders.
- *
- * <p>Teams can use odometry during the autonomous period for complex tasks like
- * path following. Furthermore, odometry can be used for latency compensation
- * when using computer-vision systems.
- */
-public class MecanumDriveOdometry {
- private final MecanumDriveKinematics m_kinematics;
- private Pose2d m_poseMeters;
- private double m_prevTimeSeconds = -1;
-
- private Rotation2d m_gyroOffset;
- private Rotation2d m_previousAngle;
-
- /**
- * Constructs a MecanumDriveOdometry object.
- *
- * @param kinematics The mecanum drive kinematics for your drivetrain.
- * @param gyroAngle The angle reported by the gyroscope.
- * @param initialPoseMeters The starting position of the robot on the field.
- */
- public MecanumDriveOdometry(MecanumDriveKinematics kinematics, Rotation2d gyroAngle,
- Pose2d initialPoseMeters) {
- m_kinematics = kinematics;
- m_poseMeters = initialPoseMeters;
- m_gyroOffset = m_poseMeters.getRotation().minus(gyroAngle);
- m_previousAngle = initialPoseMeters.getRotation();
- MathSharedStore.reportUsage(MathUsageId.kOdometry_MecanumDrive, 1);
- }
-
- /**
- * Constructs a MecanumDriveOdometry object with the default pose at the origin.
- *
- * @param kinematics The mecanum drive kinematics for your drivetrain.
- * @param gyroAngle The angle reported by the gyroscope.
- */
- public MecanumDriveOdometry(MecanumDriveKinematics kinematics, Rotation2d gyroAngle) {
- this(kinematics, gyroAngle, new Pose2d());
- }
-
- /**
- * Resets the robot's position on the field.
- *
- * <p>The gyroscope angle does not need to be reset here on the user's robot code.
- * The library automatically takes care of offsetting the gyro angle.
- *
- * @param poseMeters The position on the field that your robot is at.
- * @param gyroAngle The angle reported by the gyroscope.
- */
- public void resetPosition(Pose2d poseMeters, Rotation2d gyroAngle) {
- m_poseMeters = poseMeters;
- m_previousAngle = poseMeters.getRotation();
- m_gyroOffset = m_poseMeters.getRotation().minus(gyroAngle);
- }
-
- /**
- * Returns the position of the robot on the field.
- *
- * @return The pose of the robot (x and y are in meters).
- */
- public Pose2d getPoseMeters() {
- return m_poseMeters;
- }
-
- /**
- * Updates the robot's position on the field using forward kinematics and
- * integration of the pose over time. This method takes in the current time as
- * a parameter to calculate period (difference between two timestamps). The
- * period is used to calculate the change in distance from a velocity. This
- * also takes in an angle parameter which is used instead of the
- * angular rate that is calculated from forward kinematics.
- *
- * @param currentTimeSeconds The current time in seconds.
- * @param gyroAngle The angle reported by the gyroscope.
- * @param wheelSpeeds The current wheel speeds.
- * @return The new pose of the robot.
- */
- public Pose2d updateWithTime(double currentTimeSeconds, Rotation2d gyroAngle,
- MecanumDriveWheelSpeeds wheelSpeeds) {
- double period = m_prevTimeSeconds >= 0 ? currentTimeSeconds - m_prevTimeSeconds : 0.0;
- m_prevTimeSeconds = currentTimeSeconds;
-
- var angle = gyroAngle.plus(m_gyroOffset);
-
- var chassisState = m_kinematics.toChassisSpeeds(wheelSpeeds);
- var newPose = m_poseMeters.exp(
- new Twist2d(chassisState.vxMetersPerSecond * period,
- chassisState.vyMetersPerSecond * period,
- angle.minus(m_previousAngle).getRadians()));
-
- m_previousAngle = angle;
- m_poseMeters = new Pose2d(newPose.getTranslation(), angle);
- return m_poseMeters;
- }
-
- /**
- * Updates the robot's position on the field using forward kinematics and
- * integration of the pose over time. This method automatically calculates the
- * current time to calculate period (difference between two timestamps). The
- * period is used to calculate the change in distance from a velocity. This
- * also takes in an angle parameter which is used instead of the
- * angular rate that is calculated from forward kinematics.
- *
- * @param gyroAngle The angle reported by the gyroscope.
- * @param wheelSpeeds The current wheel speeds.
- * @return The new pose of the robot.
- */
- public Pose2d update(Rotation2d gyroAngle,
- MecanumDriveWheelSpeeds wheelSpeeds) {
- return updateWithTime(WPIUtilJNI.now() * 1.0e-6, gyroAngle,
- wheelSpeeds);
- }
-}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/kinematics/MecanumDriveWheelSpeeds.java b/wpimath/src/main/java/edu/wpi/first/wpilibj/kinematics/MecanumDriveWheelSpeeds.java
deleted file mode 100644
index f00e409..0000000
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/kinematics/MecanumDriveWheelSpeeds.java
+++ /dev/null
@@ -1,91 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.kinematics;
-
-import java.util.stream.DoubleStream;
-
-@SuppressWarnings("MemberName")
-public class MecanumDriveWheelSpeeds {
- /**
- * Speed of the front left wheel.
- */
- public double frontLeftMetersPerSecond;
-
- /**
- * Speed of the front right wheel.
- */
- public double frontRightMetersPerSecond;
-
- /**
- * Speed of the rear left wheel.
- */
- public double rearLeftMetersPerSecond;
-
- /**
- * Speed of the rear right wheel.
- */
- public double rearRightMetersPerSecond;
-
- /**
- * Constructs a MecanumDriveWheelSpeeds with zeros for all member fields.
- */
- public MecanumDriveWheelSpeeds() {
- }
-
- /**
- * Constructs a MecanumDriveWheelSpeeds.
- *
- * @param frontLeftMetersPerSecond Speed of the front left wheel.
- * @param frontRightMetersPerSecond Speed of the front right wheel.
- * @param rearLeftMetersPerSecond Speed of the rear left wheel.
- * @param rearRightMetersPerSecond Speed of the rear right wheel.
- */
- public MecanumDriveWheelSpeeds(double frontLeftMetersPerSecond,
- double frontRightMetersPerSecond,
- double rearLeftMetersPerSecond,
- double rearRightMetersPerSecond) {
- this.frontLeftMetersPerSecond = frontLeftMetersPerSecond;
- this.frontRightMetersPerSecond = frontRightMetersPerSecond;
- this.rearLeftMetersPerSecond = rearLeftMetersPerSecond;
- this.rearRightMetersPerSecond = rearRightMetersPerSecond;
- }
-
- /**
- * Normalizes the wheel speeds using some max attainable speed. Sometimes,
- * after inverse kinematics, the requested speed from a/several modules may be
- * above the max attainable speed for the driving motor on that module. To fix
- * this issue, one can "normalize" all the wheel speeds to make sure that all
- * requested module speeds are below the absolute threshold, while maintaining
- * the ratio of speeds between modules.
- *
- * @param attainableMaxSpeedMetersPerSecond The absolute max speed that a wheel can reach.
- */
- public void normalize(double attainableMaxSpeedMetersPerSecond) {
- double realMaxSpeed = DoubleStream.of(frontLeftMetersPerSecond,
- frontRightMetersPerSecond, rearLeftMetersPerSecond, rearRightMetersPerSecond)
- .max().getAsDouble();
-
- if (realMaxSpeed > attainableMaxSpeedMetersPerSecond) {
- frontLeftMetersPerSecond = frontLeftMetersPerSecond / realMaxSpeed
- * attainableMaxSpeedMetersPerSecond;
- frontRightMetersPerSecond = frontRightMetersPerSecond / realMaxSpeed
- * attainableMaxSpeedMetersPerSecond;
- rearLeftMetersPerSecond = rearLeftMetersPerSecond / realMaxSpeed
- * attainableMaxSpeedMetersPerSecond;
- rearRightMetersPerSecond = rearRightMetersPerSecond / realMaxSpeed
- * attainableMaxSpeedMetersPerSecond;
- }
- }
-
- @Override
- public String toString() {
- return String.format("MecanumDriveWheelSpeeds(Front Left: %.2f m/s, Front Right: %.2f m/s, "
- + "Rear Left: %.2f m/s, Rear Right: %.2f m/s)", frontLeftMetersPerSecond,
- frontRightMetersPerSecond, rearLeftMetersPerSecond, rearRightMetersPerSecond);
- }
-}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/kinematics/SwerveDriveKinematics.java b/wpimath/src/main/java/edu/wpi/first/wpilibj/kinematics/SwerveDriveKinematics.java
deleted file mode 100644
index a1dba43..0000000
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/kinematics/SwerveDriveKinematics.java
+++ /dev/null
@@ -1,199 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.kinematics;
-
-import java.util.Arrays;
-import java.util.Collections;
-
-import org.ejml.simple.SimpleMatrix;
-
-import edu.wpi.first.math.MathSharedStore;
-import edu.wpi.first.math.MathUsageId;
-import edu.wpi.first.wpilibj.geometry.Rotation2d;
-import edu.wpi.first.wpilibj.geometry.Translation2d;
-
-/**
- * Helper class that converts a chassis velocity (dx, dy, and dtheta components)
- * into individual module states (speed and angle).
- *
- * <p>The inverse kinematics (converting from a desired chassis velocity to
- * individual module states) uses the relative locations of the modules with
- * respect to the center of rotation. The center of rotation for inverse
- * kinematics is also variable. This means that you can set your set your center
- * of rotation in a corner of the robot to perform special evasion maneuvers.
- *
- * <p>Forward kinematics (converting an array of module states into the overall
- * chassis motion) is performs the exact opposite of what inverse kinematics
- * does. Since this is an overdetermined system (more equations than variables),
- * we use a least-squares approximation.
- *
- * <p>The inverse kinematics: [moduleStates] = [moduleLocations] * [chassisSpeeds]
- * We take the Moore-Penrose pseudoinverse of [moduleLocations] and then
- * multiply by [moduleStates] to get our chassis speeds.
- *
- * <p>Forward kinematics is also used for odometry -- determining the position of
- * the robot on the field using encoders and a gyro.
- */
-public class SwerveDriveKinematics {
- private final SimpleMatrix m_inverseKinematics;
- private final SimpleMatrix m_forwardKinematics;
-
- private final int m_numModules;
- private final Translation2d[] m_modules;
- private Translation2d m_prevCoR = new Translation2d();
-
- /**
- * Constructs a swerve drive kinematics object. This takes in a variable
- * number of wheel locations as Translation2ds. The order in which you pass in
- * the wheel locations is the same order that you will receive the module
- * states when performing inverse kinematics. It is also expected that you
- * pass in the module states in the same order when calling the forward
- * kinematics methods.
- *
- * @param wheelsMeters The locations of the wheels relative to the physical center
- * of the robot.
- */
- public SwerveDriveKinematics(Translation2d... wheelsMeters) {
- if (wheelsMeters.length < 2) {
- throw new IllegalArgumentException("A swerve drive requires at least two modules");
- }
- m_numModules = wheelsMeters.length;
- m_modules = Arrays.copyOf(wheelsMeters, m_numModules);
- m_inverseKinematics = new SimpleMatrix(m_numModules * 2, 3);
-
- for (int i = 0; i < m_numModules; i++) {
- m_inverseKinematics.setRow(i * 2 + 0, 0, /* Start Data */ 1, 0, -m_modules[i].getY());
- m_inverseKinematics.setRow(i * 2 + 1, 0, /* Start Data */ 0, 1, +m_modules[i].getX());
- }
- m_forwardKinematics = m_inverseKinematics.pseudoInverse();
-
- MathSharedStore.reportUsage(MathUsageId.kKinematics_SwerveDrive, 1);
- }
-
- /**
- * Performs inverse kinematics to return the module states from a desired
- * chassis velocity. This method is often used to convert joystick values into
- * module speeds and angles.
- *
- * <p>This function also supports variable centers of rotation. During normal
- * operations, the center of rotation is usually the same as the physical
- * center of the robot; therefore, the argument is defaulted to that use case.
- * However, if you wish to change the center of rotation for evasive
- * maneuvers, vision alignment, or for any other use case, you can do so.
- *
- * @param chassisSpeeds The desired chassis speed.
- * @param centerOfRotationMeters The center of rotation. For example, if you set the
- * center of rotation at one corner of the robot and provide
- * a chassis speed that only has a dtheta component, the robot
- * will rotate around that corner.
- * @return An array containing the module states. Use caution because these
- * module states are not normalized. Sometimes, a user input may cause one of
- * the module speeds to go above the attainable max velocity. Use the
- * {@link #normalizeWheelSpeeds(SwerveModuleState[], double) normalizeWheelSpeeds}
- * function to rectify this issue.
- */
- @SuppressWarnings({"LocalVariableName", "PMD.AvoidInstantiatingObjectsInLoops"})
- public SwerveModuleState[] toSwerveModuleStates(ChassisSpeeds chassisSpeeds,
- Translation2d centerOfRotationMeters) {
- if (!centerOfRotationMeters.equals(m_prevCoR)) {
- for (int i = 0; i < m_numModules; i++) {
- m_inverseKinematics.setRow(i * 2 + 0, 0, /* Start Data */ 1, 0,
- -m_modules[i].getY() + centerOfRotationMeters.getY());
- m_inverseKinematics.setRow(i * 2 + 1, 0, /* Start Data */ 0, 1,
- +m_modules[i].getX() - centerOfRotationMeters.getX());
- }
- m_prevCoR = centerOfRotationMeters;
- }
-
- var chassisSpeedsVector = new SimpleMatrix(3, 1);
- chassisSpeedsVector.setColumn(0, 0,
- chassisSpeeds.vxMetersPerSecond, chassisSpeeds.vyMetersPerSecond,
- chassisSpeeds.omegaRadiansPerSecond);
-
- var moduleStatesMatrix = m_inverseKinematics.mult(chassisSpeedsVector);
- SwerveModuleState[] moduleStates = new SwerveModuleState[m_numModules];
-
- for (int i = 0; i < m_numModules; i++) {
- double x = moduleStatesMatrix.get(i * 2, 0);
- double y = moduleStatesMatrix.get(i * 2 + 1, 0);
-
- double speed = Math.hypot(x, y);
- Rotation2d angle = new Rotation2d(x, y);
-
- moduleStates[i] = new SwerveModuleState(speed, angle);
- }
-
- return moduleStates;
- }
-
- /**
- * Performs inverse kinematics. See {@link #toSwerveModuleStates(ChassisSpeeds, Translation2d)}
- * toSwerveModuleStates for more information.
- *
- * @param chassisSpeeds The desired chassis speed.
- * @return An array containing the module states.
- */
- public SwerveModuleState[] toSwerveModuleStates(ChassisSpeeds chassisSpeeds) {
- return toSwerveModuleStates(chassisSpeeds, new Translation2d());
- }
-
- /**
- * Performs forward kinematics to return the resulting chassis state from the
- * given module states. This method is often used for odometry -- determining
- * the robot's position on the field using data from the real-world speed and
- * angle of each module on the robot.
- *
- * @param wheelStates The state of the modules (as a SwerveModuleState type)
- * as measured from respective encoders and gyros. The order of the swerve
- * module states should be same as passed into the constructor of this class.
- * @return The resulting chassis speed.
- */
- public ChassisSpeeds toChassisSpeeds(SwerveModuleState... wheelStates) {
- if (wheelStates.length != m_numModules) {
- throw new IllegalArgumentException(
- "Number of modules is not consistent with number of wheel locations provided in "
- + "constructor"
- );
- }
- var moduleStatesMatrix = new SimpleMatrix(m_numModules * 2, 1);
-
- for (int i = 0; i < m_numModules; i++) {
- var module = wheelStates[i];
- moduleStatesMatrix.set(i * 2, 0, module.speedMetersPerSecond * module.angle.getCos());
- moduleStatesMatrix.set(i * 2 + 1, module.speedMetersPerSecond * module.angle.getSin());
- }
-
- var chassisSpeedsVector = m_forwardKinematics.mult(moduleStatesMatrix);
- return new ChassisSpeeds(chassisSpeedsVector.get(0, 0), chassisSpeedsVector.get(1, 0),
- chassisSpeedsVector.get(2, 0));
-
- }
-
- /**
- * Normalizes the wheel speeds using some max attainable speed. Sometimes,
- * after inverse kinematics, the requested speed from a/several modules may be
- * above the max attainable speed for the driving motor on that module. To fix
- * this issue, one can "normalize" all the wheel speeds to make sure that all
- * requested module speeds are below the absolute threshold, while maintaining
- * the ratio of speeds between modules.
- *
- * @param moduleStates Reference to array of module states. The array will be
- * mutated with the normalized speeds!
- * @param attainableMaxSpeedMetersPerSecond The absolute max speed that a module can reach.
- */
- public static void normalizeWheelSpeeds(SwerveModuleState[] moduleStates,
- double attainableMaxSpeedMetersPerSecond) {
- double realMaxSpeed = Collections.max(Arrays.asList(moduleStates)).speedMetersPerSecond;
- if (realMaxSpeed > attainableMaxSpeedMetersPerSecond) {
- for (SwerveModuleState moduleState : moduleStates) {
- moduleState.speedMetersPerSecond = moduleState.speedMetersPerSecond / realMaxSpeed
- * attainableMaxSpeedMetersPerSecond;
- }
- }
- }
-}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/kinematics/SwerveDriveOdometry.java b/wpimath/src/main/java/edu/wpi/first/wpilibj/kinematics/SwerveDriveOdometry.java
deleted file mode 100644
index 5b1f975..0000000
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/kinematics/SwerveDriveOdometry.java
+++ /dev/null
@@ -1,135 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.kinematics;
-
-import edu.wpi.first.math.MathSharedStore;
-import edu.wpi.first.math.MathUsageId;
-import edu.wpi.first.wpilibj.geometry.Pose2d;
-import edu.wpi.first.wpilibj.geometry.Rotation2d;
-import edu.wpi.first.wpilibj.geometry.Twist2d;
-import edu.wpi.first.wpiutil.WPIUtilJNI;
-
-/**
- * Class for swerve drive odometry. Odometry allows you to track the robot's
- * position on the field over a course of a match using readings from your
- * swerve drive encoders and swerve azimuth encoders.
- *
- * <p>Teams can use odometry during the autonomous period for complex tasks like
- * path following. Furthermore, odometry can be used for latency compensation
- * when using computer-vision systems.
- */
-public class SwerveDriveOdometry {
- private final SwerveDriveKinematics m_kinematics;
- private Pose2d m_poseMeters;
- private double m_prevTimeSeconds = -1;
-
- private Rotation2d m_gyroOffset;
- private Rotation2d m_previousAngle;
-
- /**
- * Constructs a SwerveDriveOdometry object.
- *
- * @param kinematics The swerve drive kinematics for your drivetrain.
- * @param gyroAngle The angle reported by the gyroscope.
- * @param initialPose The starting position of the robot on the field.
- */
- public SwerveDriveOdometry(SwerveDriveKinematics kinematics, Rotation2d gyroAngle,
- Pose2d initialPose) {
- m_kinematics = kinematics;
- m_poseMeters = initialPose;
- m_gyroOffset = m_poseMeters.getRotation().minus(gyroAngle);
- m_previousAngle = initialPose.getRotation();
- MathSharedStore.reportUsage(MathUsageId.kOdometry_SwerveDrive, 1);
- }
-
- /**
- * Constructs a SwerveDriveOdometry object with the default pose at the origin.
- *
- * @param kinematics The swerve drive kinematics for your drivetrain.
- * @param gyroAngle The angle reported by the gyroscope.
- */
- public SwerveDriveOdometry(SwerveDriveKinematics kinematics, Rotation2d gyroAngle) {
- this(kinematics, gyroAngle, new Pose2d());
- }
-
- /**
- * Resets the robot's position on the field.
- *
- * <p>The gyroscope angle does not need to be reset here on the user's robot code.
- * The library automatically takes care of offsetting the gyro angle.
- *
- * @param pose The position on the field that your robot is at.
- * @param gyroAngle The angle reported by the gyroscope.
- */
- public void resetPosition(Pose2d pose, Rotation2d gyroAngle) {
- m_poseMeters = pose;
- m_previousAngle = pose.getRotation();
- m_gyroOffset = m_poseMeters.getRotation().minus(gyroAngle);
- }
-
- /**
- * Returns the position of the robot on the field.
- *
- * @return The pose of the robot (x and y are in meters).
- */
- public Pose2d getPoseMeters() {
- return m_poseMeters;
- }
-
- /**
- * Updates the robot's position on the field using forward kinematics and
- * integration of the pose over time. This method takes in the current time as
- * a parameter to calculate period (difference between two timestamps). The
- * period is used to calculate the change in distance from a velocity. This
- * also takes in an angle parameter which is used instead of the
- * angular rate that is calculated from forward kinematics.
- *
- * @param currentTimeSeconds The current time in seconds.
- * @param gyroAngle The angle reported by the gyroscope.
- * @param moduleStates The current state of all swerve modules. Please provide
- * the states in the same order in which you instantiated your
- * SwerveDriveKinematics.
- * @return The new pose of the robot.
- */
- public Pose2d updateWithTime(double currentTimeSeconds, Rotation2d gyroAngle,
- SwerveModuleState... moduleStates) {
- double period = m_prevTimeSeconds >= 0 ? currentTimeSeconds - m_prevTimeSeconds : 0.0;
- m_prevTimeSeconds = currentTimeSeconds;
-
- var angle = gyroAngle.plus(m_gyroOffset);
-
- var chassisState = m_kinematics.toChassisSpeeds(moduleStates);
- var newPose = m_poseMeters.exp(
- new Twist2d(chassisState.vxMetersPerSecond * period,
- chassisState.vyMetersPerSecond * period,
- angle.minus(m_previousAngle).getRadians()));
-
- m_previousAngle = angle;
- m_poseMeters = new Pose2d(newPose.getTranslation(), angle);
-
- return m_poseMeters;
- }
-
- /**
- * Updates the robot's position on the field using forward kinematics and
- * integration of the pose over time. This method automatically calculates the
- * current time to calculate period (difference between two timestamps). The
- * period is used to calculate the change in distance from a velocity. This
- * also takes in an angle parameter which is used instead of the angular
- * rate that is calculated from forward kinematics.
- *
- * @param gyroAngle The angle reported by the gyroscope.
- * @param moduleStates The current state of all swerve modules. Please provide
- * the states in the same order in which you instantiated your
- * SwerveDriveKinematics.
- * @return The new pose of the robot.
- */
- public Pose2d update(Rotation2d gyroAngle, SwerveModuleState... moduleStates) {
- return updateWithTime(WPIUtilJNI.now() * 1.0e-6, gyroAngle, moduleStates);
- }
-}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/kinematics/SwerveModuleState.java b/wpimath/src/main/java/edu/wpi/first/wpilibj/kinematics/SwerveModuleState.java
deleted file mode 100644
index f9570eb..0000000
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/kinematics/SwerveModuleState.java
+++ /dev/null
@@ -1,63 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.kinematics;
-
-import edu.wpi.first.wpilibj.geometry.Rotation2d;
-
-/**
- * Represents the state of one swerve module.
- */
-@SuppressWarnings("MemberName")
-public class SwerveModuleState implements Comparable<SwerveModuleState> {
-
- /**
- * Speed of the wheel of the module.
- */
- public double speedMetersPerSecond;
-
- /**
- * Angle of the module.
- */
- public Rotation2d angle = Rotation2d.fromDegrees(0);
-
- /**
- * Constructs a SwerveModuleState with zeros for speed and angle.
- */
- public SwerveModuleState() {
- }
-
- /**
- * Constructs a SwerveModuleState.
- *
- * @param speedMetersPerSecond The speed of the wheel of the module.
- * @param angle The angle of the module.
- */
- public SwerveModuleState(double speedMetersPerSecond, Rotation2d angle) {
- this.speedMetersPerSecond = speedMetersPerSecond;
- this.angle = angle;
- }
-
- /**
- * Compares two swerve module states. One swerve module is "greater" than the other if its speed
- * is higher than the other.
- *
- * @param o The other swerve module.
- * @return 1 if this is greater, 0 if both are equal, -1 if other is greater.
- */
- @Override
- @SuppressWarnings("ParameterName")
- public int compareTo(SwerveModuleState o) {
- return Double.compare(this.speedMetersPerSecond, o.speedMetersPerSecond);
- }
-
- @Override
- public String toString() {
- return String.format("SwerveModuleState(Speed: %.2f m/s, Angle: %s)", speedMetersPerSecond,
- angle);
- }
-}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/math/Discretization.java b/wpimath/src/main/java/edu/wpi/first/wpilibj/math/Discretization.java
deleted file mode 100644
index ad9bf27..0000000
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/math/Discretization.java
+++ /dev/null
@@ -1,179 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.math;
-
-import org.ejml.simple.SimpleMatrix;
-
-import edu.wpi.first.wpiutil.math.Matrix;
-import edu.wpi.first.wpiutil.math.Nat;
-import edu.wpi.first.wpiutil.math.Num;
-import edu.wpi.first.wpiutil.math.Pair;
-
-@SuppressWarnings({"ParameterName", "MethodTypeParameterName"})
-public final class Discretization {
- private Discretization() {
- // Utility class
- }
-
- /**
- * Discretizes the given continuous A matrix.
- *
- * @param <States> Num representing the number of states.
- * @param contA Continuous system matrix.
- * @param dtSeconds Discretization timestep.
- * @return the discrete matrix system.
- */
- public static <States extends Num> Matrix<States, States> discretizeA(
- Matrix<States, States> contA, double dtSeconds) {
- return contA.times(dtSeconds).exp();
- }
-
- /**
- * Discretizes the given continuous A and B matrices.
- *
- * <p>Rather than solving a (States + Inputs) x (States + Inputs) matrix
- * exponential like in DiscretizeAB(), we take advantage of the structure of the
- * block matrix of A and B.
- *
- * <p>1) The exponential of A*t, which is only N x N, is relatively cheap.
- * 2) The upper-right quarter of the (States + Inputs) x (States + Inputs)
- * matrix, which we can approximate using a taylor series to several terms
- * and still be substantially cheaper than taking the big exponential.
- *
- * @param states Nat representing the states of the system.
- * @param contA Continuous system matrix.
- * @param contB Continuous input matrix.
- * @param dtseconds Discretization timestep.
- */
- public static <States extends Num, Inputs extends Num> Pair<Matrix<States, States>,
- Matrix<States, Inputs>>
- discretizeABTaylor(Nat<States> states,
- Matrix<States, States> contA,
- Matrix<States, Inputs> contB,
- double dtseconds) {
- Matrix<States, States> lastTerm = Matrix.eye(states);
- double lastCoeff = dtseconds;
-
- var phi12 = lastTerm.times(lastCoeff);
-
- // i = 6 i.e. 5th order should be enough precision
- for (int i = 2; i < 6; ++i) {
- lastTerm = contA.times(lastTerm);
- lastCoeff *= dtseconds / ((double) i);
-
- phi12 = phi12.plus(lastTerm.times(lastCoeff));
- }
-
- var discB = phi12.times(contB);
-
- var discA = discretizeA(contA, dtseconds);
-
- return Pair.of(discA, discB);
- }
-
- /**
- * Discretizes the given continuous A and Q matrices.
- *
- * <p>Rather than solving a 2N x 2N matrix exponential like in DiscretizeQ() (which
- * is expensive), we take advantage of the structure of the block matrix of A
- * and Q.
- *
- * <p>The exponential of A*t, which is only N x N, is relatively cheap.
- * 2) The upper-right quarter of the 2N x 2N matrix, which we can approximate
- * using a taylor series to several terms and still be substantially cheaper
- * than taking the big exponential.
- *
- * @param <States> Nat representing the number of states.
- * @param contA Continuous system matrix.
- * @param contQ Continuous process noise covariance matrix.
- * @param dtSeconds Discretization timestep.
- * @return a pair representing the discrete system matrix and process noise covariance matrix.
- */
- @SuppressWarnings("LocalVariableName")
- public static <States extends Num> Pair<Matrix<States, States>,
- Matrix<States, States>> discretizeAQTaylor(Matrix<States, States> contA,
- Matrix<States, States> contQ,
- double dtSeconds) {
- Matrix<States, States> Q = (contQ.plus(contQ.transpose())).div(2.0);
-
-
- Matrix<States, States> lastTerm = Q.copy();
- double lastCoeff = dtSeconds;
-
- // A^T^n
- Matrix<States, States> Atn = contA.transpose();
- Matrix<States, States> phi12 = lastTerm.times(lastCoeff);
-
- // i = 6 i.e. 6th order should be enough precision
- for (int i = 2; i < 6; ++i) {
- lastTerm = contA.times(-1).times(lastTerm).plus(Q.times(Atn));
- lastCoeff *= dtSeconds / ((double) i);
-
- phi12 = phi12.plus(lastTerm.times(lastCoeff));
-
- Atn = Atn.times(contA.transpose());
- }
-
- var discA = discretizeA(contA, dtSeconds);
- Q = discA.times(phi12);
-
- // Make Q symmetric if it isn't already
- var discQ = Q.plus(Q.transpose()).div(2.0);
-
- return new Pair<>(discA, discQ);
- }
-
- /**
- * Returns a discretized version of the provided continuous measurement noise
- * covariance matrix. Note that dt=0.0 divides R by zero.
- *
- * @param <O> Nat representing the number of outputs.
- * @param R Continuous measurement noise covariance matrix.
- * @param dtSeconds Discretization timestep.
- * @return Discretized version of the provided continuous measurement noise covariance matrix.
- */
- public static <O extends Num> Matrix<O, O> discretizeR(Matrix<O, O> R, double dtSeconds) {
- return R.div(dtSeconds);
- }
-
- /**
- * Discretizes the given continuous A and B matrices.
- *
- * @param <States> Nat representing the states of the system.
- * @param <Inputs> Nat representing the inputs to the system.
- * @param contA Continuous system matrix.
- * @param contB Continuous input matrix.
- * @param dtSeconds Discretization timestep.
- * @return a Pair representing discA and diskB.
- */
- @SuppressWarnings("LocalVariableName")
- public static <States extends Num, Inputs extends Num> Pair<Matrix<States, States>,
- Matrix<States, Inputs>> discretizeAB(
- Matrix<States, States> contA,
- Matrix<States, Inputs> contB,
- double dtSeconds) {
- var scaledA = contA.times(dtSeconds);
- var scaledB = contB.times(dtSeconds);
-
- var contSize = contB.getNumRows() + contB.getNumCols();
- var Mcont = new Matrix<>(new SimpleMatrix(contSize, contSize));
- Mcont.assignBlock(0, 0, scaledA);
- Mcont.assignBlock(0, scaledA.getNumCols(), scaledB);
- var Mdisc = Mcont.exp();
-
- var discA = new Matrix<States, States>(new SimpleMatrix(contB.getNumRows(),
- contB.getNumRows()));
- var discB = new Matrix<States, Inputs>(new SimpleMatrix(contB.getNumRows(),
- contB.getNumCols()));
-
- discA.extractFrom(0, 0, Mdisc);
- discB.extractFrom(0, contB.getNumRows(), Mdisc);
-
- return new Pair<>(discA, discB);
- }
-}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/math/StateSpaceUtil.java b/wpimath/src/main/java/edu/wpi/first/wpilibj/math/StateSpaceUtil.java
deleted file mode 100644
index 02ca7c1..0000000
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/math/StateSpaceUtil.java
+++ /dev/null
@@ -1,180 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.math;
-
-import java.util.Random;
-
-import org.ejml.simple.SimpleMatrix;
-
-import edu.wpi.first.math.WPIMathJNI;
-import edu.wpi.first.wpilibj.geometry.Pose2d;
-import edu.wpi.first.wpiutil.math.MathUtil;
-import edu.wpi.first.wpiutil.math.Matrix;
-import edu.wpi.first.wpiutil.math.Nat;
-import edu.wpi.first.wpiutil.math.Num;
-import edu.wpi.first.wpiutil.math.VecBuilder;
-import edu.wpi.first.wpiutil.math.numbers.N1;
-import edu.wpi.first.wpiutil.math.numbers.N3;
-
-@SuppressWarnings("ParameterName")
-public final class StateSpaceUtil {
- private StateSpaceUtil() {
- // Utility class
- }
-
- /**
- * Creates a covariance matrix from the given vector for use with Kalman
- * filters.
- *
- * <p>Each element is squared and placed on the covariance matrix diagonal.
- *
- * @param <States> Num representing the states of the system.
- * @param states A Nat representing the states of the system.
- * @param stdDevs For a Q matrix, its elements are the standard deviations of
- * each state from how the model behaves. For an R matrix, its
- * elements are the standard deviations for each output
- * measurement.
- * @return Process noise or measurement noise covariance matrix.
- */
- @SuppressWarnings("MethodTypeParameterName")
- public static <States extends Num> Matrix<States, States> makeCovarianceMatrix(
- Nat<States> states, Matrix<States, N1> stdDevs
- ) {
- var result = new Matrix<>(states, states);
- for (int i = 0; i < states.getNum(); i++) {
- result.set(i, i, Math.pow(stdDevs.get(i, 0), 2));
- }
- return result;
- }
-
- /**
- * Creates a vector of normally distributed white noise with the given noise
- * intensities for each element.
- *
- * @param <N> Num representing the dimensionality of the noise vector to create.
- * @param stdDevs A matrix whose elements are the standard deviations of each
- * element of the noise vector.
- * @return White noise vector.
- */
- public static <N extends Num> Matrix<N, N1> makeWhiteNoiseVector(
- Matrix<N, N1> stdDevs
- ) {
- var rand = new Random();
-
- Matrix<N, N1> result = new Matrix<>(new SimpleMatrix(stdDevs.getNumRows(), 1));
- for (int i = 0; i < stdDevs.getNumRows(); i++) {
- result.set(i, 0, rand.nextGaussian() * stdDevs.get(i, 0));
- }
- return result;
- }
-
- /**
- * Creates a cost matrix from the given vector for use with LQR.
- *
- * <p>The cost matrix is constructed using Bryson's rule. The inverse square of
- * each element in the input is taken and placed on the cost matrix diagonal.
- *
- * @param <States> Nat representing the states of the system.
- * @param costs An array. For a Q matrix, its elements are the maximum allowed
- * excursions of the states from the reference. For an R matrix,
- * its elements are the maximum allowed excursions of the control
- * inputs from no actuation.
- * @return State excursion or control effort cost matrix.
- */
- @SuppressWarnings("MethodTypeParameterName")
- public static <States extends Num> Matrix<States, States>
- makeCostMatrix(Matrix<States, N1> costs) {
- Matrix<States, States> result =
- new Matrix<>(new SimpleMatrix(costs.getNumRows(), costs.getNumRows()));
- result.fill(0.0);
-
- for (int i = 0; i < costs.getNumRows(); i++) {
- result.set(i, i, 1.0 / (Math.pow(costs.get(i, 0), 2)));
- }
-
- return result;
- }
-
- /**
- * Returns true if (A, B) is a stabilizable pair.
- *
- * <p>(A,B) is stabilizable if and only if the uncontrollable eigenvalues of A, if
- * any, have absolute values less than one, where an eigenvalue is
- * uncontrollable if rank(lambda * I - A, B) %3C n where n is number of states.
- *
- * @param <States> Num representing the size of A.
- * @param <Inputs> Num representing the columns of B.
- * @param A System matrix.
- * @param B Input matrix.
- * @return If the system is stabilizable.
- */
- @SuppressWarnings("MethodTypeParameterName")
- public static <States extends Num, Inputs extends Num> boolean isStabilizable(
- Matrix<States, States> A, Matrix<States, Inputs> B) {
- return WPIMathJNI.isStabilizable(A.getNumRows(), B.getNumCols(),
- A.getData(), B.getData());
- }
-
- /**
- * Convert a {@link Pose2d} to a vector of [x, y, theta], where theta is in radians.
- *
- * @param pose A pose to convert to a vector.
- * @return The given pose in vector form, with the third element, theta, in radians.
- */
- public static Matrix<N3, N1> poseToVector(Pose2d pose) {
- return VecBuilder.fill(
- pose.getX(),
- pose.getY(),
- pose.getRotation().getRadians()
- );
- }
-
- /**
- * Clamp the input u to the min and max.
- *
- * @param u The input to clamp.
- * @param umin The minimum input magnitude.
- * @param umax The maximum input magnitude.
- * @param <I> The number of inputs.
- * @return The clamped input.
- */
- @SuppressWarnings({"ParameterName", "LocalVariableName"})
- public static <I extends Num> Matrix<I, N1> clampInputMaxMagnitude(Matrix<I, N1> u,
- Matrix<I, N1> umin,
- Matrix<I, N1> umax) {
- var result = new Matrix<I, N1>(new SimpleMatrix(u.getNumRows(), 1));
- for (int i = 0; i < u.getNumRows(); i++) {
- result.set(i, 0, MathUtil.clamp(
- u.get(i, 0),
- umin.get(i, 0),
- umax.get(i, 0)));
- }
- return result;
- }
-
- /**
- * Normalize all inputs if any excedes the maximum magnitude. Useful for systems such as
- * differential drivetrains.
- *
- * @param u The input vector.
- * @param maxMagnitude The maximum magnitude any input can have.
- * @param <I> The number of inputs.
- * @return The normalizedInput
- */
- public static <I extends Num> Matrix<I, N1> normalizeInputVector(Matrix<I, N1> u,
- double maxMagnitude) {
- double maxValue = u.maxAbs();
- boolean isCapped = maxValue > maxMagnitude;
-
- if (isCapped) {
- return u.times(maxMagnitude / maxValue);
- }
- return u;
- }
-
-}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/spline/CubicHermiteSpline.java b/wpimath/src/main/java/edu/wpi/first/wpilibj/spline/CubicHermiteSpline.java
deleted file mode 100644
index f387fc0..0000000
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/spline/CubicHermiteSpline.java
+++ /dev/null
@@ -1,115 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.spline;
-
-import org.ejml.simple.SimpleMatrix;
-
-public class CubicHermiteSpline extends Spline {
- private static SimpleMatrix hermiteBasis;
- private final SimpleMatrix m_coefficients;
-
- /**
- * Constructs a cubic hermite spline with the specified control vectors. Each
- * control vector contains info about the location of the point and its first
- * derivative.
- *
- * @param xInitialControlVector The control vector for the initial point in
- * the x dimension.
- * @param xFinalControlVector The control vector for the final point in
- * the x dimension.
- * @param yInitialControlVector The control vector for the initial point in
- * the y dimension.
- * @param yFinalControlVector The control vector for the final point in
- * the y dimension.
- */
- @SuppressWarnings("ParameterName")
- public CubicHermiteSpline(double[] xInitialControlVector, double[] xFinalControlVector,
- double[] yInitialControlVector, double[] yFinalControlVector) {
- super(3);
-
- // Populate the coefficients for the actual spline equations.
- // Row 0 is x coefficients
- // Row 1 is y coefficients
- final var hermite = makeHermiteBasis();
- final var x = getControlVectorFromArrays(xInitialControlVector, xFinalControlVector);
- final var y = getControlVectorFromArrays(yInitialControlVector, yFinalControlVector);
-
- final var xCoeffs = (hermite.mult(x)).transpose();
- final var yCoeffs = (hermite.mult(y)).transpose();
-
- m_coefficients = new SimpleMatrix(6, 4);
-
- for (int i = 0; i < 4; i++) {
- m_coefficients.set(0, i, xCoeffs.get(0, i));
- m_coefficients.set(1, i, yCoeffs.get(0, i));
-
- // Populate Row 2 and Row 3 with the derivatives of the equations above.
- // Then populate row 4 and 5 with the second derivatives.
- // Here, we are multiplying by (3 - i) to manually take the derivative. The
- // power of the term in index 0 is 3, index 1 is 2 and so on. To find the
- // coefficient of the derivative, we can use the power rule and multiply
- // the existing coefficient by its power.
- m_coefficients.set(2, i, m_coefficients.get(0, i) * (3 - i));
- m_coefficients.set(3, i, m_coefficients.get(1, i) * (3 - i));
- }
-
- for (int i = 0; i < 3; i++) {
- // Here, we are multiplying by (2 - i) to manually take the derivative. The
- // power of the term in index 0 is 2, index 1 is 1 and so on. To find the
- // coefficient of the derivative, we can use the power rule and multiply
- // the existing coefficient by its power.
- m_coefficients.set(4, i, m_coefficients.get(2, i) * (2 - i));
- m_coefficients.set(5, i, m_coefficients.get(3, i) * (2 - i));
- }
-
- }
-
- /**
- * Returns the coefficients matrix.
- *
- * @return The coefficients matrix.
- */
- @Override
- protected SimpleMatrix getCoefficients() {
- return m_coefficients;
- }
-
- /**
- * Returns the hermite basis matrix for cubic hermite spline interpolation.
- *
- * @return The hermite basis matrix for cubic hermite spline interpolation.
- */
- private SimpleMatrix makeHermiteBasis() {
- if (hermiteBasis == null) {
- hermiteBasis = new SimpleMatrix(4, 4, true, new double[]{
- +2.0, +1.0, -2.0, +1.0,
- -3.0, -2.0, +3.0, -1.0,
- +0.0, +1.0, +0.0, +0.0,
- +1.0, +0.0, +0.0, +0.0
- });
- }
- return hermiteBasis;
- }
-
- /**
- * Returns the control vector for each dimension as a matrix from the
- * user-provided arrays in the constructor.
- *
- * @param initialVector The control vector for the initial point.
- * @param finalVector The control vector for the final point.
- * @return The control vector matrix for a dimension.
- */
- private SimpleMatrix getControlVectorFromArrays(double[] initialVector, double[] finalVector) {
- if (initialVector.length != 2 || finalVector.length != 2) {
- throw new IllegalArgumentException("Size of vectors must be 2");
- }
- return new SimpleMatrix(4, 1, true, new double[]{
- initialVector[0], initialVector[1],
- finalVector[0], finalVector[1]});
- }
-}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/spline/PoseWithCurvature.java b/wpimath/src/main/java/edu/wpi/first/wpilibj/spline/PoseWithCurvature.java
deleted file mode 100644
index ed8562d..0000000
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/spline/PoseWithCurvature.java
+++ /dev/null
@@ -1,40 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.spline;
-
-import edu.wpi.first.wpilibj.geometry.Pose2d;
-
-/**
- * Represents a pair of a pose and a curvature.
- */
-@SuppressWarnings("MemberName")
-public class PoseWithCurvature {
- // Represents the pose.
- public Pose2d poseMeters;
-
- // Represents the curvature.
- public double curvatureRadPerMeter;
-
- /**
- * Constructs a PoseWithCurvature.
- *
- * @param poseMeters The pose.
- * @param curvatureRadPerMeter The curvature.
- */
- public PoseWithCurvature(Pose2d poseMeters, double curvatureRadPerMeter) {
- this.poseMeters = poseMeters;
- this.curvatureRadPerMeter = curvatureRadPerMeter;
- }
-
- /**
- * Constructs a PoseWithCurvature with default values.
- */
- public PoseWithCurvature() {
- poseMeters = new Pose2d();
- }
-}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/spline/QuinticHermiteSpline.java b/wpimath/src/main/java/edu/wpi/first/wpilibj/spline/QuinticHermiteSpline.java
deleted file mode 100644
index 6073f62..0000000
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/spline/QuinticHermiteSpline.java
+++ /dev/null
@@ -1,116 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.spline;
-
-import org.ejml.simple.SimpleMatrix;
-
-public class QuinticHermiteSpline extends Spline {
- private static SimpleMatrix hermiteBasis;
- private final SimpleMatrix m_coefficients;
-
- /**
- * Constructs a quintic hermite spline with the specified control vectors.
- * Each control vector contains into about the location of the point, its
- * first derivative, and its second derivative.
- *
- * @param xInitialControlVector The control vector for the initial point in
- * the x dimension.
- * @param xFinalControlVector The control vector for the final point in
- * the x dimension.
- * @param yInitialControlVector The control vector for the initial point in
- * the y dimension.
- * @param yFinalControlVector The control vector for the final point in
- * the y dimension.
- */
- @SuppressWarnings("ParameterName")
- public QuinticHermiteSpline(double[] xInitialControlVector, double[] xFinalControlVector,
- double[] yInitialControlVector, double[] yFinalControlVector) {
- super(5);
-
- // Populate the coefficients for the actual spline equations.
- // Row 0 is x coefficients
- // Row 1 is y coefficients
- final var hermite = makeHermiteBasis();
- final var x = getControlVectorFromArrays(xInitialControlVector, xFinalControlVector);
- final var y = getControlVectorFromArrays(yInitialControlVector, yFinalControlVector);
-
- final var xCoeffs = (hermite.mult(x)).transpose();
- final var yCoeffs = (hermite.mult(y)).transpose();
-
- m_coefficients = new SimpleMatrix(6, 6);
-
- for (int i = 0; i < 6; i++) {
- m_coefficients.set(0, i, xCoeffs.get(0, i));
- m_coefficients.set(1, i, yCoeffs.get(0, i));
- }
- for (int i = 0; i < 6; i++) {
- // Populate Row 2 and Row 3 with the derivatives of the equations above.
- // Here, we are multiplying by (5 - i) to manually take the derivative. The
- // power of the term in index 0 is 5, index 1 is 4 and so on. To find the
- // coefficient of the derivative, we can use the power rule and multiply
- // the existing coefficient by its power.
- m_coefficients.set(2, i, m_coefficients.get(0, i) * (5 - i));
- m_coefficients.set(3, i, m_coefficients.get(1, i) * (5 - i));
- }
- for (int i = 0; i < 5; i++) {
- // Then populate row 4 and 5 with the second derivatives.
- // Here, we are multiplying by (4 - i) to manually take the derivative. The
- // power of the term in index 0 is 4, index 1 is 3 and so on. To find the
- // coefficient of the derivative, we can use the power rule and multiply
- // the existing coefficient by its power.
- m_coefficients.set(4, i, m_coefficients.get(2, i) * (4 - i));
- m_coefficients.set(5, i, m_coefficients.get(3, i) * (4 - i));
- }
- }
-
- /**
- * Returns the coefficients matrix.
- *
- * @return The coefficients matrix.
- */
- @Override
- protected SimpleMatrix getCoefficients() {
- return m_coefficients;
- }
-
- /**
- * Returns the hermite basis matrix for quintic hermite spline interpolation.
- *
- * @return The hermite basis matrix for quintic hermite spline interpolation.
- */
- private SimpleMatrix makeHermiteBasis() {
- if (hermiteBasis == null) {
- hermiteBasis = new SimpleMatrix(6, 6, true, new double[]{
- -06.0, -03.0, -00.5, +06.0, -03.0, +00.5,
- +15.0, +08.0, +01.5, -15.0, +07.0, +01.0,
- -10.0, -06.0, -01.5, +10.0, -04.0, +00.5,
- +00.0, +00.0, +00.5, +00.0, +00.0, +00.0,
- +00.0, +01.0, +00.0, +00.0, +00.0, +00.0,
- +01.0, +00.0, +00.0, +00.0, +00.0, +00.0
- });
- }
- return hermiteBasis;
- }
-
- /**
- * Returns the control vector for each dimension as a matrix from the
- * user-provided arrays in the constructor.
- *
- * @param initialVector The control vector for the initial point.
- * @param finalVector The control vector for the final point.
- * @return The control vector matrix for a dimension.
- */
- private SimpleMatrix getControlVectorFromArrays(double[] initialVector, double[] finalVector) {
- if (initialVector.length != 3 || finalVector.length != 3) {
- throw new IllegalArgumentException("Size of vectors must be 3");
- }
- return new SimpleMatrix(6, 1, true, new double[]{
- initialVector[0], initialVector[1], initialVector[2],
- finalVector[0], finalVector[1], finalVector[2]});
- }
-}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/system/LinearSystem.java b/wpimath/src/main/java/edu/wpi/first/wpilibj/system/LinearSystem.java
deleted file mode 100644
index 4a90caa..0000000
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/system/LinearSystem.java
+++ /dev/null
@@ -1,182 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.system;
-
-import edu.wpi.first.wpilibj.math.Discretization;
-import edu.wpi.first.wpiutil.math.Matrix;
-import edu.wpi.first.wpiutil.math.Num;
-import edu.wpi.first.wpiutil.math.numbers.N1;
-
-@SuppressWarnings("ClassTypeParameterName")
-public class LinearSystem<States extends Num, Inputs extends Num,
- Outputs extends Num> {
-
- /**
- * Continuous system matrix.
- */
- @SuppressWarnings("MemberName")
- private final Matrix<States, States> m_A;
-
- /**
- * Continuous input matrix.
- */
- @SuppressWarnings("MemberName")
- private final Matrix<States, Inputs> m_B;
-
- /**
- * Output matrix.
- */
- @SuppressWarnings("MemberName")
- private final Matrix<Outputs, States> m_C;
-
- /**
- * Feedthrough matrix.
- */
- @SuppressWarnings("MemberName")
- private final Matrix<Outputs, Inputs> m_D;
-
- /**
- * Construct a new LinearSystem from the four system matrices.
- *
- * @param a The system matrix A.
- * @param b The input matrix B.
- * @param c The output matrix C.
- * @param d The feedthrough matrix D.
- */
- @SuppressWarnings("ParameterName")
- public LinearSystem(Matrix<States, States> a, Matrix<States, Inputs> b,
- Matrix<Outputs, States> c, Matrix<Outputs, Inputs> d) {
-
- this.m_A = a;
- this.m_B = b;
- this.m_C = c;
- this.m_D = d;
- }
-
- /**
- * Returns the system matrix A.
- *
- * @return the system matrix A.
- */
- public Matrix<States, States> getA() {
- return m_A;
- }
-
- /**
- * Returns an element of the system matrix A.
- *
- * @param row Row of A.
- * @param col Column of A.
- * @return the system matrix A at (i, j).
- */
- public double getA(int row, int col) {
- return m_A.get(row, col);
- }
-
- /**
- * Returns the input matrix B.
- *
- * @return the input matrix B.
- */
- public Matrix<States, Inputs> getB() {
- return m_B;
- }
-
- /**
- * Returns an element of the input matrix B.
- *
- * @param row Row of B.
- * @param col Column of B.
- * @return The value of the input matrix B at (i, j).
- */
- public double getB(int row, int col) {
- return m_B.get(row, col);
- }
-
- /**
- * Returns the output matrix C.
- *
- * @return Output matrix C.
- */
- public Matrix<Outputs, States> getC() {
- return m_C;
- }
-
- /**
- * Returns an element of the output matrix C.
- *
- * @param row Row of C.
- * @param col Column of C.
- * @return the double value of C at the given position.
- */
- public double getC(int row, int col) {
- return m_C.get(row, col);
- }
-
- /**
- * Returns the feedthrough matrix D.
- *
- * @return the feedthrough matrix D.
- */
- public Matrix<Outputs, Inputs> getD() {
- return m_D;
- }
-
- /**
- * Returns an element of the feedthrough matrix D.
- *
- * @param row Row of D.
- * @param col Column of D.
- * @return The feedthrough matrix D at (i, j).
- */
- public double getD(int row, int col) {
- return m_D.get(row, col);
- }
-
- /**
- * Computes the new x given the old x and the control input.
- *
- * <p>This is used by state observers directly to run updates based on state
- * estimate.
- *
- * @param x The current state.
- * @param clampedU The control input.
- * @param dtSeconds Timestep for model update.
- * @return the updated x.
- */
- @SuppressWarnings("ParameterName")
- public Matrix<States, N1> calculateX(Matrix<States, N1> x, Matrix<Inputs, N1> clampedU,
- double dtSeconds) {
- var discABpair = Discretization.discretizeAB(m_A, m_B, dtSeconds);
-
- return (discABpair.getFirst().times(x)).plus(discABpair.getSecond().times(clampedU));
- }
-
- /**
- * Computes the new y given the control input.
- *
- * <p>This is used by state observers directly to run updates based on state
- * estimate.
- *
- * @param x The current state.
- * @param clampedU The control input.
- * @return the updated output matrix Y.
- */
- @SuppressWarnings("ParameterName")
- public Matrix<Outputs, N1> calculateY(
- Matrix<States, N1> x,
- Matrix<Inputs, N1> clampedU) {
- return m_C.times(x).plus(m_D.times(clampedU));
- }
-
- @Override
- public String toString() {
- return String.format("Linear System: A\n%s\n\nB:\n%s\n\nC:\n%s\n\nD:\n%s\n", m_A.toString(),
- m_B.toString(), m_C.toString(), m_D.toString());
- }
-}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/system/LinearSystemLoop.java b/wpimath/src/main/java/edu/wpi/first/wpilibj/system/LinearSystemLoop.java
deleted file mode 100644
index d44ca62..0000000
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/system/LinearSystemLoop.java
+++ /dev/null
@@ -1,358 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.system;
-
-import java.util.function.Function;
-
-import org.ejml.MatrixDimensionException;
-import org.ejml.simple.SimpleMatrix;
-
-import edu.wpi.first.wpilibj.controller.LinearPlantInversionFeedforward;
-import edu.wpi.first.wpilibj.controller.LinearQuadraticRegulator;
-import edu.wpi.first.wpilibj.estimator.KalmanFilter;
-import edu.wpi.first.wpilibj.math.StateSpaceUtil;
-import edu.wpi.first.wpiutil.math.Matrix;
-import edu.wpi.first.wpiutil.math.Num;
-import edu.wpi.first.wpiutil.math.numbers.N1;
-
-/**
- * Combines a plant, controller, and observer for controlling a mechanism with
- * full state feedback.
- *
- * <p>For everything in this file, "inputs" and "outputs" are defined from the
- * perspective of the plant. This means U is an input and Y is an output
- * (because you give the plant U (powers) and it gives you back a Y (sensor
- * values). This is the opposite of what they mean from the perspective of the
- * controller (U is an output because that's what goes to the motors and Y is an
- * input because that's what comes back from the sensors).
- *
- * <p>For more on the underlying math, read
- * https://file.tavsys.net/control/controls-engineering-in-frc.pdf.
- */
-@SuppressWarnings("ClassTypeParameterName")
-public class LinearSystemLoop<States extends Num, Inputs extends Num,
- Outputs extends Num> {
-
- private final LinearSystem<States, Inputs, Outputs> m_plant;
- private final LinearQuadraticRegulator<States, Inputs, Outputs> m_controller;
- private final LinearPlantInversionFeedforward<States, Inputs, Outputs> m_feedforward;
- private final KalmanFilter<States, Inputs, Outputs> m_observer;
- private Matrix<States, N1> m_nextR;
- private Function<Matrix<Inputs, N1>, Matrix<Inputs, N1>> m_clampFunction;
-
- /**
- * Constructs a state-space loop with the given plant, controller, and
- * observer. By default, the initial reference is all zeros. Users should
- * call reset with the initial system state before enabling the loop. This
- * constructor assumes that the input(s) to this system are voltage.
- *
- * @param plant State-space plant.
- * @param controller State-space controller.
- * @param observer State-space observer.
- * @param maxVoltageVolts The maximum voltage that can be applied. Commonly 12.
- * @param dtSeconds The nominal timestep.
- */
- public LinearSystemLoop(LinearSystem<States, Inputs, Outputs> plant,
- LinearQuadraticRegulator<States, Inputs, Outputs> controller,
- KalmanFilter<States, Inputs, Outputs> observer,
- double maxVoltageVolts,
- double dtSeconds) {
- this(plant, controller,
- new LinearPlantInversionFeedforward<>(plant, dtSeconds), observer,
- u -> StateSpaceUtil.normalizeInputVector(u, maxVoltageVolts));
- }
-
- /**
- * Constructs a state-space loop with the given plant, controller, and
- * observer. By default, the initial reference is all zeros. Users should
- * call reset with the initial system state before enabling the loop.
- *
- * @param plant State-space plant.
- * @param controller State-space controller.
- * @param observer State-space observer.
- * @param clampFunction The function used to clamp the input U.
- * @param dtSeconds The nominal timestep.
- */
- public LinearSystemLoop(LinearSystem<States, Inputs, Outputs> plant,
- LinearQuadraticRegulator<States, Inputs, Outputs> controller,
- KalmanFilter<States, Inputs, Outputs> observer,
- Function<Matrix<Inputs, N1>, Matrix<Inputs, N1>> clampFunction,
- double dtSeconds) {
- this(plant, controller, new LinearPlantInversionFeedforward<>(plant, dtSeconds),
- observer, clampFunction);
- }
-
- /**
- * Constructs a state-space loop with the given plant, controller, and
- * observer. By default, the initial reference is all zeros. Users should
- * call reset with the initial system state before enabling the loop.
- *
- * @param plant State-space plant.
- * @param controller State-space controller.
- * @param feedforward Plant inversion feedforward.
- * @param observer State-space observer.
- * @param maxVoltageVolts The maximum voltage that can be applied. Assumes that the
- * inputs are voltages.
- */
- public LinearSystemLoop(LinearSystem<States, Inputs, Outputs> plant,
- LinearQuadraticRegulator<States, Inputs, Outputs> controller,
- LinearPlantInversionFeedforward<States, Inputs, Outputs> feedforward,
- KalmanFilter<States, Inputs, Outputs> observer,
- double maxVoltageVolts
- ) {
- this(plant, controller, feedforward,
- observer, u -> StateSpaceUtil.normalizeInputVector(u, maxVoltageVolts));
- }
-
- /**
- * Constructs a state-space loop with the given plant, controller, and
- * observer. By default, the initial reference is all zeros. Users should
- * call reset with the initial system state before enabling the loop.
- *
- * @param plant State-space plant.
- * @param controller State-space controller.
- * @param feedforward Plant inversion feedforward.
- * @param observer State-space observer.
- * @param clampFunction The function used to clamp the input U.
- */
- public LinearSystemLoop(LinearSystem<States, Inputs, Outputs> plant,
- LinearQuadraticRegulator<States, Inputs, Outputs> controller,
- LinearPlantInversionFeedforward<States, Inputs, Outputs> feedforward,
- KalmanFilter<States, Inputs, Outputs> observer,
- Function<Matrix<Inputs, N1>, Matrix<Inputs, N1>> clampFunction) {
- this.m_plant = plant;
- this.m_controller = controller;
- this.m_feedforward = feedforward;
- this.m_observer = observer;
- this.m_clampFunction = clampFunction;
-
- m_nextR = new Matrix<>(new SimpleMatrix(controller.getK().getNumCols(), 1));
- reset(m_nextR);
- }
-
- /**
- * Returns the observer's state estimate x-hat.
- *
- * @return the observer's state estimate x-hat.
- */
- public Matrix<States, N1> getXHat() {
- return getObserver().getXhat();
- }
-
- /**
- * Returns an element of the observer's state estimate x-hat.
- *
- * @param row Row of x-hat.
- * @return the i-th element of the observer's state estimate x-hat.
- */
- public double getXHat(int row) {
- return getObserver().getXhat(row);
- }
-
- /**
- * Set the initial state estimate x-hat.
- *
- * @param xhat The initial state estimate x-hat.
- */
- public void setXHat(Matrix<States, N1> xhat) {
- getObserver().setXhat(xhat);
- }
-
- /**
- * Set an element of the initial state estimate x-hat.
- *
- * @param row Row of x-hat.
- * @param value Value for element of x-hat.
- */
- public void setXHat(int row, double value) {
- getObserver().setXhat(row, value);
- }
-
- /**
- * Returns an element of the controller's next reference r.
- *
- * @param row Row of r.
- * @return the element i of the controller's next reference r.
- */
- public double getNextR(int row) {
- return getNextR().get(row, 0);
- }
-
- /**
- * Returns the controller's next reference r.
- *
- * @return the controller's next reference r.
- */
- public Matrix<States, N1> getNextR() {
- return m_nextR;
- }
-
- /**
- * Set the next reference r.
- *
- * @param nextR Next reference.
- */
- public void setNextR(Matrix<States, N1> nextR) {
- m_nextR = nextR;
- }
-
- /**
- * Set the next reference r.
- *
- * @param nextR Next reference.
- */
- public void setNextR(double... nextR) {
- if (nextR.length != m_nextR.getNumRows()) {
- throw new MatrixDimensionException(String.format("The next reference does not have the "
- + "correct number of entries! Expected %s, but got %s.",
- m_nextR.getNumRows(),
- nextR.length));
- }
- m_nextR = new Matrix<>(new SimpleMatrix(m_nextR.getNumRows(), 1, true, nextR));
- }
-
- /**
- * Returns the controller's calculated control input u plus the calculated feedforward u_ff.
- *
- * @return the calculated control input u.
- */
- public Matrix<Inputs, N1> getU() {
- return clampInput(m_controller.getU().plus(m_feedforward.getUff()));
- }
-
- /**
- * Returns an element of the controller's calculated control input u.
- *
- * @param row Row of u.
- * @return the calculated control input u at the row i.
- */
- public double getU(int row) {
- return getU().get(row, 0);
- }
-
- /**
- * Return the plant used internally.
- *
- * @return the plant used internally.
- */
- public LinearSystem<States, Inputs, Outputs> getPlant() {
- return m_plant;
- }
-
- /**
- * Return the controller used internally.
- *
- * @return the controller used internally.
- */
- public LinearQuadraticRegulator<States, Inputs, Outputs> getController() {
- return m_controller;
- }
-
- /**
- * Return the feedforward used internally.
- *
- * @return the feedforward used internally.
- */
- public LinearPlantInversionFeedforward<States, Inputs, Outputs> getFeedforward() {
- return m_feedforward;
- }
-
- /**
- * Return the observer used internally.
- *
- * @return the observer used internally.
- */
- public KalmanFilter<States, Inputs, Outputs> getObserver() {
- return m_observer;
- }
-
- /**
- * Zeroes reference r and controller output u. The previous reference
- * of the PlantInversionFeedforward and the initial state estimate of
- * the KalmanFilter are set to the initial state provided.
- *
- * @param initialState The initial state.
- */
- public void reset(Matrix<States, N1> initialState) {
- m_nextR.fill(0.0);
- m_controller.reset();
- m_feedforward.reset(initialState);
- m_observer.setXhat(initialState);
- }
-
- /**
- * Returns difference between reference r and current state x-hat.
- *
- * @return The state error matrix.
- */
- public Matrix<States, N1> getError() {
- return getController().getR().minus(m_observer.getXhat());
- }
-
- /**
- * Returns difference between reference r and current state x-hat.
- *
- * @param index The index of the error matrix to return.
- * @return The error at that index.
- */
- public double getError(int index) {
- return (getController().getR().minus(m_observer.getXhat())).get(index, 0);
- }
-
- /**
- * Get the function used to clamp the input u.
- * @return The clamping function.
- */
- public Function<Matrix<Inputs, N1>, Matrix<Inputs, N1>> getClampFunction() {
- return m_clampFunction;
- }
-
- /**
- * Set the clamping function used to clamp inputs.
- */
- public void setClampFunction(Function<Matrix<Inputs, N1>, Matrix<Inputs, N1>> clampFunction) {
- this.m_clampFunction = clampFunction;
- }
-
- /**
- * Correct the state estimate x-hat using the measurements in y.
- *
- * @param y Measurement vector.
- */
- @SuppressWarnings("ParameterName")
- public void correct(Matrix<Outputs, N1> y) {
- getObserver().correct(getU(), y);
- }
-
- /**
- * Sets new controller output, projects model forward, and runs observer
- * prediction.
- *
- * <p>After calling this, the user should send the elements of u to the
- * actuators.
- *
- * @param dtSeconds Timestep for model update.
- */
- @SuppressWarnings("LocalVariableName")
- public void predict(double dtSeconds) {
- var u = clampInput(m_controller.calculate(getObserver().getXhat(), m_nextR)
- .plus(m_feedforward.calculate(m_nextR)));
- getObserver().predict(u, dtSeconds);
- }
-
- /**
- * Clamp the input u to the min and max.
- *
- * @param unclampedU The input to clamp.
- * @return The clamped input.
- */
- public Matrix<Inputs, N1> clampInput(Matrix<Inputs, N1> unclampedU) {
- return m_clampFunction.apply(unclampedU);
- }
-
-}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/system/NumericalJacobian.java b/wpimath/src/main/java/edu/wpi/first/wpilibj/system/NumericalJacobian.java
deleted file mode 100644
index a808dec..0000000
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/system/NumericalJacobian.java
+++ /dev/null
@@ -1,111 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.system;
-
-import java.util.function.BiFunction;
-import java.util.function.Function;
-
-import edu.wpi.first.wpiutil.math.Matrix;
-import edu.wpi.first.wpiutil.math.Nat;
-import edu.wpi.first.wpiutil.math.Num;
-import edu.wpi.first.wpiutil.math.numbers.N1;
-
-public final class NumericalJacobian {
- private NumericalJacobian() {
- // Utility Class.
- }
-
- private static final double kEpsilon = 1e-5;
-
- /**
- * Computes the numerical Jacobian with respect to x for f(x).
- *
- * @param <Rows> Number of rows in the result of f(x).
- * @param <States> Num representing the number of rows in the output of f.
- * @param <Cols> Number of columns in the result of f(x).
- * @param rows Number of rows in the result of f(x).
- * @param cols Number of columns in the result of f(x).
- * @param f Vector-valued function from which to compute the Jacobian.
- * @param x Vector argument.
- * @return The numerical Jacobian with respect to x for f(x, u, ...).
- */
- @SuppressWarnings({"ParameterName", "MethodTypeParameterName"})
- public static <Rows extends Num, Cols extends Num, States extends Num> Matrix<Rows, Cols>
- numericalJacobian(
- Nat<Rows> rows,
- Nat<Cols> cols,
- Function<Matrix<Cols, N1>, Matrix<States, N1>> f,
- Matrix<Cols, N1> x
- ) {
- var result = new Matrix<>(rows, cols);
-
- for (int i = 0; i < cols.getNum(); i++) {
- var dxPlus = x.copy();
- var dxMinus = x.copy();
- dxPlus.set(i, 0, dxPlus.get(i, 0) + kEpsilon);
- dxMinus.set(i, 0, dxMinus.get(i, 0) - kEpsilon);
- @SuppressWarnings("LocalVariableName")
- var dF = f.apply(dxPlus).minus(f.apply(dxMinus)).div(2 * kEpsilon);
-
- result.setColumn(i, Matrix.changeBoundsUnchecked(dF));
- }
-
- return result;
- }
-
- /**
- * Returns numerical Jacobian with respect to x for f(x, u, ...).
- *
- * @param <Rows> Number of rows in the result of f(x, u).
- * @param <States> Number of rows in x.
- * @param <Inputs> Number of rows in the second input to f.
- * @param <Outputs> Num representing the rows in the output of f.
- * @param rows Number of rows in the result of f(x, u).
- * @param states Number of rows in x.
- * @param f Vector-valued function from which to compute Jacobian.
- * @param x State vector.
- * @param u Input vector.
- * @return The numerical Jacobian with respect to x for f(x, u, ...).
- */
- @SuppressWarnings({"ParameterName", "MethodTypeParameterName"})
- public static <Rows extends Num, States extends Num, Inputs extends Num, Outputs extends Num>
- Matrix<Rows, States> numericalJacobianX(
- Nat<Rows> rows,
- Nat<States> states,
- BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> f,
- Matrix<States, N1> x,
- Matrix<Inputs, N1> u
- ) {
- return numericalJacobian(rows, states, _x -> f.apply(_x, u), x);
- }
-
- /**
- * Returns the numerical Jacobian with respect to u for f(x, u).
- *
- * @param <States> The states of the system.
- * @param <Inputs> The inputs to the system.
- * @param <Rows> Number of rows in the result of f(x, u).
- * @param rows Number of rows in the result of f(x, u).
- * @param inputs Number of rows in u.
- * @param f Vector-valued function from which to compute the Jacobian.
- * @param x State vector.
- * @param u Input vector.
- * @return the numerical Jacobian with respect to u for f(x, u).
- */
- @SuppressWarnings({"ParameterName", "MethodTypeParameterName"})
- public static <Rows extends Num, States extends Num, Inputs extends Num> Matrix<Rows, Inputs>
- numericalJacobianU(
- Nat<Rows> rows,
- Nat<Inputs> inputs,
- BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
- Matrix<States, N1> x,
- Matrix<Inputs, N1> u
- ) {
- return numericalJacobian(rows, inputs, _u -> f.apply(x, _u), u);
- }
-}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/system/RungeKutta.java b/wpimath/src/main/java/edu/wpi/first/wpilibj/system/RungeKutta.java
deleted file mode 100644
index fef5ddf..0000000
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/system/RungeKutta.java
+++ /dev/null
@@ -1,113 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.system;
-
-import java.util.function.BiFunction;
-import java.util.function.DoubleFunction;
-import java.util.function.Function;
-
-import edu.wpi.first.wpiutil.math.Matrix;
-import edu.wpi.first.wpiutil.math.Num;
-import edu.wpi.first.wpiutil.math.numbers.N1;
-
-public final class RungeKutta {
- private RungeKutta() {
- // utility Class
- }
-
- /**
- * Performs Runge Kutta integration (4th order).
- *
- * @param f The function to integrate, which takes one argument x.
- * @param x The initial value of x.
- * @param dtSeconds The time over which to integrate.
- * @return the integration of dx/dt = f(x) for dt.
- */
- @SuppressWarnings("ParameterName")
- public static double rungeKutta(
- DoubleFunction<Double> f,
- double x,
- double dtSeconds
- ) {
- final var halfDt = 0.5 * dtSeconds;
- final var k1 = f.apply(x);
- final var k2 = f.apply(x + k1 * halfDt);
- final var k3 = f.apply(x + k2 * halfDt);
- final var k4 = f.apply(x + k3 * dtSeconds);
- return x + dtSeconds / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4);
- }
-
- /**
- * Performs Runge Kutta integration (4th order).
- *
- * @param f The function to integrate. It must take two arguments x and u.
- * @param x The initial value of x.
- * @param u The value u held constant over the integration period.
- * @param dtSeconds The time over which to integrate.
- * @return The result of Runge Kutta integration (4th order).
- */
- @SuppressWarnings("ParameterName")
- public static double rungeKutta(
- BiFunction<Double, Double, Double> f,
- double x, Double u, double dtSeconds
- ) {
- final var halfDt = 0.5 * dtSeconds;
- final var k1 = f.apply(x, u);
- final var k2 = f.apply(x + k1 * halfDt, u);
- final var k3 = f.apply(x + k2 * halfDt, u);
- final var k4 = f.apply(x + k3 * dtSeconds, u);
- return x + dtSeconds / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4);
- }
-
- /**
- * Performs 4th order Runge-Kutta integration of dx/dt = f(x, u) for dt.
- *
- * @param <States> A Num representing the states of the system to integrate.
- * @param <Inputs> A Num representing the inputs of the system to integrate.
- * @param f The function to integrate. It must take two arguments x and u.
- * @param x The initial value of x.
- * @param u The value u held constant over the integration period.
- * @param dtSeconds The time over which to integrate.
- * @return the integration of dx/dt = f(x, u) for dt.
- */
- @SuppressWarnings({"ParameterName", "MethodTypeParameterName"})
- public static <States extends Num, Inputs extends Num> Matrix<States, N1> rungeKutta(
- BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
- Matrix<States, N1> x, Matrix<Inputs, N1> u, double dtSeconds) {
-
- final var halfDt = 0.5 * dtSeconds;
- Matrix<States, N1> k1 = f.apply(x, u);
- Matrix<States, N1> k2 = f.apply(x.plus(k1.times(halfDt)), u);
- Matrix<States, N1> k3 = f.apply(x.plus(k2.times(halfDt)), u);
- Matrix<States, N1> k4 = f.apply(x.plus(k3.times(dtSeconds)), u);
- return x.plus((k1.plus(k2.times(2.0)).plus(k3.times(2.0)).plus(k4)).times(dtSeconds).div(6.0));
- }
-
- /**
- * Performs 4th order Runge-Kutta integration of dx/dt = f(x) for dt.
- *
- * @param <States> A Num prepresenting the states of the system.
- * @param f The function to integrate. It must take one argument x.
- * @param x The initial value of x.
- * @param dtSeconds The time over which to integrate.
- * @return 4th order Runge-Kutta integration of dx/dt = f(x) for dt.
- */
- @SuppressWarnings({"ParameterName", "MethodTypeParameterName"})
- public static <States extends Num> Matrix<States, N1> rungeKutta(
- Function<Matrix<States, N1>, Matrix<States, N1>> f,
- Matrix<States, N1> x, double dtSeconds) {
-
- final var halfDt = 0.5 * dtSeconds;
- Matrix<States, N1> k1 = f.apply(x);
- Matrix<States, N1> k2 = f.apply(x.plus(k1.times(halfDt)));
- Matrix<States, N1> k3 = f.apply(x.plus(k2.times(halfDt)));
- Matrix<States, N1> k4 = f.apply(x.plus(k3.times(dtSeconds)));
- return x.plus((k1.plus(k2.times(2.0)).plus(k3.times(2.0)).plus(k4)).times(dtSeconds).div(6.0));
- }
-
-}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/system/plant/DCMotor.java b/wpimath/src/main/java/edu/wpi/first/wpilibj/system/plant/DCMotor.java
deleted file mode 100644
index 2e95e3f..0000000
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/system/plant/DCMotor.java
+++ /dev/null
@@ -1,182 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.system.plant;
-
-import edu.wpi.first.wpilibj.util.Units;
-
-/**
- * Holds the constants for a DC motor.
- */
-public class DCMotor {
- public final double m_nominalVoltageVolts;
- public final double m_stallTorqueNewtonMeters;
- public final double m_stallCurrentAmps;
- public final double m_freeCurrentAmps;
- public final double m_freeSpeedRadPerSec;
- @SuppressWarnings("MemberName")
- public final double m_rOhms;
- @SuppressWarnings("MemberName")
- public final double m_KvRadPerSecPerVolt;
- @SuppressWarnings("MemberName")
- public final double m_KtNMPerAmp;
-
-
- /**
- * Constructs a DC motor.
- *
- * @param nominalVoltageVolts Voltage at which the motor constants were measured.
- * @param stallTorqueNewtonMeters Current draw when stalled.
- * @param stallCurrentAmps Current draw when stalled.
- * @param freeCurrentAmps Current draw under no load.
- * @param freeSpeedRadPerSec Angular velocity under no load.
- */
- public DCMotor(double nominalVoltageVolts,
- double stallTorqueNewtonMeters,
- double stallCurrentAmps,
- double freeCurrentAmps,
- double freeSpeedRadPerSec) {
- this.m_nominalVoltageVolts = nominalVoltageVolts;
- this.m_stallTorqueNewtonMeters = stallTorqueNewtonMeters;
- this.m_stallCurrentAmps = stallCurrentAmps;
- this.m_freeCurrentAmps = freeCurrentAmps;
- this.m_freeSpeedRadPerSec = freeSpeedRadPerSec;
-
- this.m_rOhms = nominalVoltageVolts / stallCurrentAmps;
- this.m_KvRadPerSecPerVolt = freeSpeedRadPerSec / (nominalVoltageVolts - m_rOhms
- * freeCurrentAmps);
- this.m_KtNMPerAmp = stallTorqueNewtonMeters / stallCurrentAmps;
- }
-
- /**
- * Estimate the current being drawn by this motor.
- *
- * @param speedRadiansPerSec The speed of the rotor.
- * @param voltageInputVolts The input voltage.
- */
- public double getCurrent(double speedRadiansPerSec, double voltageInputVolts) {
- return -1.0 / m_KvRadPerSecPerVolt / m_rOhms * speedRadiansPerSec
- + 1.0 / m_rOhms * voltageInputVolts;
- }
-
- /**
- * Return a gearbox of CIM motors.
- *
- * @param numMotors Number of motors in the gearbox.
- */
- public static DCMotor getCIM(int numMotors) {
- return new DCMotor(12,
- 2.42 * numMotors, 133,
- 2.7, Units.rotationsPerMinuteToRadiansPerSecond(5310));
- }
-
- /**
- * Return a gearbox of 775Pro motors.
- *
- * @param numMotors Number of motors in the gearbox.
- */
- public static DCMotor getVex775Pro(int numMotors) {
- return gearbox(new DCMotor(12,
- 0.71, 134,
- 0.7, Units.rotationsPerMinuteToRadiansPerSecond(18730)), numMotors);
- }
-
- /**
- * Return a gearbox of NEO motors.
- *
- * @param numMotors Number of motors in the gearbox.
- */
- public static DCMotor getNEO(int numMotors) {
- return gearbox(new DCMotor(12, 2.6,
- 105, 1.8, Units.rotationsPerMinuteToRadiansPerSecond(5676)), numMotors);
- }
-
- /**
- * Return a gearbox of MiniCIM motors.
- *
- * @param numMotors Number of motors in the gearbox.
- */
- public static DCMotor getMiniCIM(int numMotors) {
- return gearbox(new DCMotor(12, 1.41, 89, 3,
- Units.rotationsPerMinuteToRadiansPerSecond(5840)), numMotors);
- }
-
- /**
- * Return a gearbox of Bag motors.
- *
- * @param numMotors Number of motors in the gearbox.
- */
- public static DCMotor getBag(int numMotors) {
- return gearbox(new DCMotor(12, 0.43, 53, 1.8,
- Units.rotationsPerMinuteToRadiansPerSecond(13180)), numMotors);
- }
-
- /**
- * Return a gearbox of Andymark RS775-125 motors.
- *
- * @param numMotors Number of motors in the gearbox.
- */
- public static DCMotor getAndymarkRs775_125(int numMotors) {
- return gearbox(new DCMotor(12, 0.28, 18, 1.6,
- Units.rotationsPerMinuteToRadiansPerSecond(5800.0)), numMotors);
- }
-
- /**
- * Return a gearbox of Banebots RS775 motors.
- *
- * @param numMotors Number of motors in the gearbox.
- */
- public static DCMotor getBanebotsRs775(int numMotors) {
- return gearbox(new DCMotor(12, 0.72, 97, 2.7,
- Units.rotationsPerMinuteToRadiansPerSecond(13050.0)), numMotors);
- }
-
- /**
- * Return a gearbox of Andymark 9015 motors.
- *
- * @param numMotors Number of motors in the gearbox.
- */
- public static DCMotor getAndymark9015(int numMotors) {
- return gearbox(new DCMotor(12, 0.36, 71, 3.7,
- Units.rotationsPerMinuteToRadiansPerSecond(14270.0)), numMotors);
- }
-
- /**
- * Return a gearbox of Banebots RS 550 motors.
- *
- * @param numMotors Number of motors in the gearbox.
- */
- public static DCMotor getBanebotsRs550(int numMotors) {
- return gearbox(new DCMotor(12, 0.38, 84, 0.4,
- Units.rotationsPerMinuteToRadiansPerSecond(19000.0)), numMotors);
- }
-
- /**
- * Return a gearbox of Neo 550 motors.
- *
- * @param numMotors Number of motors in the gearbox.
- */
- public static DCMotor getNeo550(int numMotors) {
- return gearbox(new DCMotor(12, 0.97, 100, 1.4,
- Units.rotationsPerMinuteToRadiansPerSecond(11000.0)), numMotors);
- }
-
- /**
- * Return a gearbox of Falcon 500 motors.
- *
- * @param numMotors Number of motors in the gearbox.
- */
- public static DCMotor getFalcon500(int numMotors) {
- return gearbox(new DCMotor(12, 4.69, 257, 1.5,
- Units.rotationsPerMinuteToRadiansPerSecond(6380.0)), numMotors);
- }
-
- private static DCMotor gearbox(DCMotor motor, double numMotors) {
- return new DCMotor(motor.m_nominalVoltageVolts, motor.m_stallTorqueNewtonMeters * numMotors,
- motor.m_stallCurrentAmps, motor.m_freeCurrentAmps, motor.m_freeSpeedRadPerSec);
- }
-}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/system/plant/LinearSystemId.java b/wpimath/src/main/java/edu/wpi/first/wpilibj/system/plant/LinearSystemId.java
deleted file mode 100644
index 25d1161..0000000
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/system/plant/LinearSystemId.java
+++ /dev/null
@@ -1,207 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.system.plant;
-
-import edu.wpi.first.wpilibj.system.LinearSystem;
-import edu.wpi.first.wpiutil.math.Matrix;
-import edu.wpi.first.wpiutil.math.Nat;
-import edu.wpi.first.wpiutil.math.VecBuilder;
-import edu.wpi.first.wpiutil.math.numbers.N1;
-import edu.wpi.first.wpiutil.math.numbers.N2;
-
-public final class LinearSystemId {
- private LinearSystemId() {
- // Utility class
- }
-
- /**
- * Create a state-space model of an elevator system.
- *
- * @param motor The motor (or gearbox) attached to the arm.
- * @param massKg The mass of the elevator carriage, in kilograms.
- * @param radiusMeters The radius of thd driving drum of the elevator, in meters.
- * @param G The reduction between motor and drum, as a ratio of output to input.
- * @return A LinearSystem representing the given characterized constants.
- */
- @SuppressWarnings("ParameterName")
- public static LinearSystem<N2, N1, N1> createElevatorSystem(DCMotor motor, double massKg,
- double radiusMeters, double G) {
- return new LinearSystem<>(
- Matrix.mat(Nat.N2(), Nat.N2()).fill(0, 1,
- 0, -Math.pow(G, 2) * motor.m_KtNMPerAmp
- / (motor.m_rOhms * radiusMeters * radiusMeters * massKg
- * motor.m_KvRadPerSecPerVolt)),
- VecBuilder.fill(
- 0, G * motor.m_KtNMPerAmp / (motor.m_rOhms * radiusMeters * massKg)),
- Matrix.mat(Nat.N1(), Nat.N2()).fill(1, 0),
- new Matrix<>(Nat.N1(), Nat.N1()));
- }
-
- /**
- * Create a state-space model of a flywheel system.
- *
- * @param motor The motor (or gearbox) attached to the arm.
- * @param jKgMetersSquared The moment of inertia J of the flywheel.
- * @param G The reduction between motor and drum, as a ratio of output to input.
- * @return A LinearSystem representing the given characterized constants.
- */
- @SuppressWarnings("ParameterName")
- public static LinearSystem<N1, N1, N1> createFlywheelSystem(DCMotor motor,
- double jKgMetersSquared,
- double G) {
- return new LinearSystem<>(
- VecBuilder.fill(
- -G * G * motor.m_KtNMPerAmp
- / (motor.m_KvRadPerSecPerVolt * motor.m_rOhms * jKgMetersSquared)),
- VecBuilder.fill(G * motor.m_KtNMPerAmp
- / (motor.m_rOhms * jKgMetersSquared)),
- Matrix.eye(Nat.N1()),
- new Matrix<>(Nat.N1(), Nat.N1()));
- }
-
- /**
- * Create a state-space model of a differential drive drivetrain. In this model, the
- * states are [v_left, v_right]^T, inputs are [V_left, V_right]^T and outputs are
- * [v_left, v_right]^T.
- *
- * @param motor the gearbox representing the motors driving the drivetrain.
- * @param massKg the mass of the robot.
- * @param rMeters the radius of the wheels in meters.
- * @param rbMeters the radius of the base (half the track width) in meters.
- * @param JKgMetersSquared the moment of inertia of the robot.
- * @param G the gearing reduction as output over input.
- * @return A LinearSystem representing a differential drivetrain.
- */
- @SuppressWarnings({"LocalVariableName", "ParameterName"})
- public static LinearSystem<N2, N2, N2> createDrivetrainVelocitySystem(DCMotor motor,
- double massKg,
- double rMeters,
- double rbMeters,
- double JKgMetersSquared,
- double G) {
- var C1 =
- -(G * G) * motor.m_KtNMPerAmp
- / (motor.m_KvRadPerSecPerVolt * motor.m_rOhms * rMeters * rMeters);
- var C2 = G * motor.m_KtNMPerAmp / (motor.m_rOhms * rMeters);
-
- final double C3 = 1 / massKg + rbMeters * rbMeters / JKgMetersSquared;
- final double C4 = 1 / massKg - rbMeters * rbMeters / JKgMetersSquared;
- var A = Matrix.mat(Nat.N2(), Nat.N2()).fill(
- C3 * C1,
- C4 * C1,
- C4 * C1,
- C3 * C1);
- var B = Matrix.mat(Nat.N2(), Nat.N2()).fill(
- C3 * C2,
- C4 * C2,
- C4 * C2,
- C3 * C2);
- var C = Matrix.mat(Nat.N2(), Nat.N2()).fill(1.0, 0.0, 0.0, 1.0);
- var D = Matrix.mat(Nat.N2(), Nat.N2()).fill(0.0, 0.0, 0.0, 0.0);
-
- return new LinearSystem<>(A, B, C, D);
- }
-
- /**
- * Create a state-space model of a single jointed arm system.
- *
- * @param motor The motor (or gearbox) attached to the arm.
- * @param jKgSquaredMeters The moment of inertia J of the arm.
- * @param G the gearing between the motor and arm, in output over input.
- * Most of the time this will be greater than 1.
- * @return A LinearSystem representing the given characterized constants.
- */
- @SuppressWarnings("ParameterName")
- public static LinearSystem<N2, N1, N1> createSingleJointedArmSystem(DCMotor motor,
- double jKgSquaredMeters,
- double G) {
- return new LinearSystem<>(
- Matrix.mat(Nat.N2(), Nat.N2()).fill(0, 1,
- 0, -Math.pow(G, 2) * motor.m_KtNMPerAmp
- / (motor.m_KvRadPerSecPerVolt * motor.m_rOhms * jKgSquaredMeters)),
- VecBuilder.fill(0, G * motor.m_KtNMPerAmp
- / (motor.m_rOhms * jKgSquaredMeters)),
- Matrix.mat(Nat.N1(), Nat.N2()).fill(1, 0),
- new Matrix<>(Nat.N1(), Nat.N1()));
- }
-
- /**
- * Identify a velocity system from it's kV (volts/(unit/sec)) and kA (volts/(unit/sec^2).
- * These constants cam be found using frc-characterization.
- *
- * <p>The distance unit you choose MUST be an SI unit (i.e. meters or radians). You can use
- * the {@link edu.wpi.first.wpilibj.util.Units} class for converting between unit types.
- *
- * @param kV The velocity gain, in volts per (units per second)
- * @param kA The acceleration gain, in volts per (units per second squared)
- * @return A LinearSystem representing the given characterized constants.
- * @see <a href="https://github.com/wpilibsuite/frc-characterization">
- * https://github.com/wpilibsuite/frc-characterization</a>
- */
- @SuppressWarnings("ParameterName")
- public static LinearSystem<N1, N1, N1> identifyVelocitySystem(double kV, double kA) {
- return new LinearSystem<>(
- VecBuilder.fill(-kV / kA),
- VecBuilder.fill(1.0 / kA),
- VecBuilder.fill(1.0),
- VecBuilder.fill(0.0));
- }
-
- /**
- * Identify a position system from it's kV (volts/(unit/sec)) and kA (volts/(unit/sec^2).
- * These constants cam be found using frc-characterization.
- *
- * <p>The distance unit you choose MUST be an SI unit (i.e. meters or radians). You can use
- * the {@link edu.wpi.first.wpilibj.util.Units} class for converting between unit types.
- *
- * @param kV The velocity gain, in volts per (units per second)
- * @param kA The acceleration gain, in volts per (units per second squared)
- * @return A LinearSystem representing the given characterized constants.
- * @see <a href="https://github.com/wpilibsuite/frc-characterization">
- * https://github.com/wpilibsuite/frc-characterization</a>
- */
- @SuppressWarnings("ParameterName")
- public static LinearSystem<N2, N1, N1> identifyPositionSystem(double kV, double kA) {
- return new LinearSystem<>(
- Matrix.mat(Nat.N2(), Nat.N2()).fill(0.0, 1.0, 0.0, -kV / kA),
- VecBuilder.fill(0.0, 1.0 / kA),
- Matrix.mat(Nat.N1(), Nat.N2()).fill(1.0, 0.0),
- VecBuilder.fill(0.0));
- }
-
- /**
- * Identify a standard differential drive drivetrain, given the drivetrain's
- * kV and kA in both linear (volts/(meter/sec) and volts/(meter/sec^2)) and
- * angular (volts/(radian/sec) and volts/(radian/sec^2)) cases. This can be
- * found using frc-characterization.
- *
- * @param kVLinear The linear velocity gain, volts per (meter per second).
- * @param kALinear The linear acceleration gain, volts per (meter per second squared).
- * @param kVAngular The angular velocity gain, volts per (radians per second).
- * @param kAAngular The angular acceleration gain, volts per (radians per second squared).
- * @return A LinearSystem representing the given characterized constants.
- * @see <a href="https://github.com/wpilibsuite/frc-characterization">
- * https://github.com/wpilibsuite/frc-characterization</a>
- */
- @SuppressWarnings("ParameterName")
- public static LinearSystem<N2, N2, N2> identifyDrivetrainSystem(
- double kVLinear, double kALinear, double kVAngular, double kAAngular) {
-
- final double c = 0.5 / (kALinear * kAAngular);
- final double A1 = c * (-kALinear * kVAngular - kVLinear * kAAngular);
- final double A2 = c * (kALinear * kVAngular - kVLinear * kAAngular);
- final double B1 = c * (kALinear + kAAngular);
- final double B2 = c * (kAAngular - kALinear);
-
- return new LinearSystem<>(
- Matrix.mat(Nat.N2(), Nat.N2()).fill(A1, A2, A2, A1),
- Matrix.mat(Nat.N2(), Nat.N2()).fill(B1, B2, B2, B1),
- Matrix.mat(Nat.N2(), Nat.N2()).fill(1, 0, 0, 1),
- Matrix.mat(Nat.N2(), Nat.N2()).fill(0, 0, 0, 0));
- }
-}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/trajectory/Trajectory.java b/wpimath/src/main/java/edu/wpi/first/wpilibj/trajectory/Trajectory.java
deleted file mode 100644
index 7de2d84..0000000
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/trajectory/Trajectory.java
+++ /dev/null
@@ -1,349 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.trajectory;
-
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Objects;
-import java.util.stream.Collectors;
-
-import com.fasterxml.jackson.annotation.JsonProperty;
-
-import edu.wpi.first.wpilibj.geometry.Pose2d;
-import edu.wpi.first.wpilibj.geometry.Transform2d;
-
-/**
- * Represents a time-parameterized trajectory. The trajectory contains of
- * various States that represent the pose, curvature, time elapsed, velocity,
- * and acceleration at that point.
- */
-public class Trajectory {
- private final double m_totalTimeSeconds;
- private final List<State> m_states;
-
- /**
- * Constructs an empty trajectory.
- */
- public Trajectory() {
- m_states = new ArrayList<>();
- m_totalTimeSeconds = 0.0;
- }
-
- /**
- * Constructs a trajectory from a vector of states.
- *
- * @param states A vector of states.
- */
- public Trajectory(final List<State> states) {
- m_states = states;
- m_totalTimeSeconds = m_states.get(m_states.size() - 1).timeSeconds;
- }
-
- /**
- * Linearly interpolates between two values.
- *
- * @param startValue The start value.
- * @param endValue The end value.
- * @param t The fraction for interpolation.
- * @return The interpolated value.
- */
- @SuppressWarnings("ParameterName")
- private static double lerp(double startValue, double endValue, double t) {
- return startValue + (endValue - startValue) * t;
- }
-
- /**
- * Linearly interpolates between two poses.
- *
- * @param startValue The start pose.
- * @param endValue The end pose.
- * @param t The fraction for interpolation.
- * @return The interpolated pose.
- */
- @SuppressWarnings("ParameterName")
- private static Pose2d lerp(Pose2d startValue, Pose2d endValue, double t) {
- return startValue.plus((endValue.minus(startValue)).times(t));
- }
-
- /**
- * Returns the initial pose of the trajectory.
- *
- * @return The initial pose of the trajectory.
- */
- public Pose2d getInitialPose() {
- return sample(0).poseMeters;
- }
-
- /**
- * Returns the overall duration of the trajectory.
- *
- * @return The duration of the trajectory.
- */
- public double getTotalTimeSeconds() {
- return m_totalTimeSeconds;
- }
-
- /**
- * Return the states of the trajectory.
- *
- * @return The states of the trajectory.
- */
- public List<State> getStates() {
- return m_states;
- }
-
- /**
- * Sample the trajectory at a point in time.
- *
- * @param timeSeconds The point in time since the beginning of the trajectory to sample.
- * @return The state at that point in time.
- */
- public State sample(double timeSeconds) {
- if (timeSeconds <= m_states.get(0).timeSeconds) {
- return m_states.get(0);
- }
- if (timeSeconds >= m_totalTimeSeconds) {
- return m_states.get(m_states.size() - 1);
- }
-
- // To get the element that we want, we will use a binary search algorithm
- // instead of iterating over a for-loop. A binary search is O(std::log(n))
- // whereas searching using a loop is O(n).
-
- // This starts at 1 because we use the previous state later on for
- // interpolation.
- int low = 1;
- int high = m_states.size() - 1;
-
- while (low != high) {
- int mid = (low + high) / 2;
- if (m_states.get(mid).timeSeconds < timeSeconds) {
- // This index and everything under it are less than the requested
- // timestamp. Therefore, we can discard them.
- low = mid + 1;
- } else {
- // t is at least as large as the element at this index. This means that
- // anything after it cannot be what we are looking for.
- high = mid;
- }
- }
-
- // High and Low should be the same.
-
- // The sample's timestamp is now greater than or equal to the requested
- // timestamp. If it is greater, we need to interpolate between the
- // previous state and the current state to get the exact state that we
- // want.
- final State sample = m_states.get(low);
- final State prevSample = m_states.get(low - 1);
-
- // If the difference in states is negligible, then we are spot on!
- if (Math.abs(sample.timeSeconds - prevSample.timeSeconds) < 1E-9) {
- return sample;
- }
- // Interpolate between the two states for the state that we want.
- return prevSample.interpolate(sample,
- (timeSeconds - prevSample.timeSeconds) / (sample.timeSeconds - prevSample.timeSeconds));
- }
-
- /**
- * Transforms all poses in the trajectory by the given transform. This is
- * useful for converting a robot-relative trajectory into a field-relative
- * trajectory. This works with respect to the first pose in the trajectory.
- *
- * @param transform The transform to transform the trajectory by.
- * @return The transformed trajectory.
- */
- @SuppressWarnings("PMD.AvoidInstantiatingObjectsInLoops")
- public Trajectory transformBy(Transform2d transform) {
- var firstState = m_states.get(0);
- var firstPose = firstState.poseMeters;
-
- // Calculate the transformed first pose.
- var newFirstPose = firstPose.plus(transform);
- List<State> newStates = new ArrayList<>();
-
- newStates.add(new State(
- firstState.timeSeconds, firstState.velocityMetersPerSecond,
- firstState.accelerationMetersPerSecondSq, newFirstPose, firstState.curvatureRadPerMeter
- ));
-
- for (int i = 1; i < m_states.size(); i++) {
- var state = m_states.get(i);
- // We are transforming relative to the coordinate frame of the new initial pose.
- newStates.add(new State(
- state.timeSeconds, state.velocityMetersPerSecond,
- state.accelerationMetersPerSecondSq, newFirstPose.plus(state.poseMeters.minus(firstPose)),
- state.curvatureRadPerMeter
- ));
- }
-
- return new Trajectory(newStates);
- }
-
- /**
- * Transforms all poses in the trajectory so that they are relative to the
- * given pose. This is useful for converting a field-relative trajectory
- * into a robot-relative trajectory.
- *
- * @param pose The pose that is the origin of the coordinate frame that
- * the current trajectory will be transformed into.
- * @return The transformed trajectory.
- */
- public Trajectory relativeTo(Pose2d pose) {
- return new Trajectory(m_states.stream().map(state -> new State(state.timeSeconds,
- state.velocityMetersPerSecond, state.accelerationMetersPerSecondSq,
- state.poseMeters.relativeTo(pose), state.curvatureRadPerMeter))
- .collect(Collectors.toList()));
- }
-
- /**
- * Represents a time-parameterized trajectory. The trajectory contains of
- * various States that represent the pose, curvature, time elapsed, velocity,
- * and acceleration at that point.
- */
- @SuppressWarnings("MemberName")
- public static class State {
- // The time elapsed since the beginning of the trajectory.
- @JsonProperty("time")
- public double timeSeconds;
-
- // The speed at that point of the trajectory.
- @JsonProperty("velocity")
- public double velocityMetersPerSecond;
-
- // The acceleration at that point of the trajectory.
- @JsonProperty("acceleration")
- public double accelerationMetersPerSecondSq;
-
- // The pose at that point of the trajectory.
- @JsonProperty("pose")
- public Pose2d poseMeters;
-
- // The curvature at that point of the trajectory.
- @JsonProperty("curvature")
- public double curvatureRadPerMeter;
-
- public State() {
- poseMeters = new Pose2d();
- }
-
- /**
- * Constructs a State with the specified parameters.
- *
- * @param timeSeconds The time elapsed since the beginning of the trajectory.
- * @param velocityMetersPerSecond The speed at that point of the trajectory.
- * @param accelerationMetersPerSecondSq The acceleration at that point of the trajectory.
- * @param poseMeters The pose at that point of the trajectory.
- * @param curvatureRadPerMeter The curvature at that point of the trajectory.
- */
- public State(double timeSeconds, double velocityMetersPerSecond,
- double accelerationMetersPerSecondSq, Pose2d poseMeters,
- double curvatureRadPerMeter) {
- this.timeSeconds = timeSeconds;
- this.velocityMetersPerSecond = velocityMetersPerSecond;
- this.accelerationMetersPerSecondSq = accelerationMetersPerSecondSq;
- this.poseMeters = poseMeters;
- this.curvatureRadPerMeter = curvatureRadPerMeter;
- }
-
- /**
- * Interpolates between two States.
- *
- * @param endValue The end value for the interpolation.
- * @param i The interpolant (fraction).
- * @return The interpolated state.
- */
- @SuppressWarnings("ParameterName")
- State interpolate(State endValue, double i) {
- // Find the new t value.
- final double newT = lerp(timeSeconds, endValue.timeSeconds, i);
-
- // Find the delta time between the current state and the interpolated state.
- final double deltaT = newT - timeSeconds;
-
- // If delta time is negative, flip the order of interpolation.
- if (deltaT < 0) {
- return endValue.interpolate(this, 1 - i);
- }
-
- // Check whether the robot is reversing at this stage.
- final boolean reversing = velocityMetersPerSecond < 0
- || Math.abs(velocityMetersPerSecond) < 1E-9 && accelerationMetersPerSecondSq < 0;
-
- // Calculate the new velocity
- // v_f = v_0 + at
- final double newV = velocityMetersPerSecond + (accelerationMetersPerSecondSq * deltaT);
-
- // Calculate the change in position.
- // delta_s = v_0 t + 0.5 at^2
- final double newS = (velocityMetersPerSecond * deltaT
- + 0.5 * accelerationMetersPerSecondSq * Math.pow(deltaT, 2)) * (reversing ? -1.0 : 1.0);
-
- // Return the new state. To find the new position for the new state, we need
- // to interpolate between the two endpoint poses. The fraction for
- // interpolation is the change in position (delta s) divided by the total
- // distance between the two endpoints.
- final double interpolationFrac = newS
- / endValue.poseMeters.getTranslation().getDistance(poseMeters.getTranslation());
-
- return new State(
- newT, newV, accelerationMetersPerSecondSq,
- lerp(poseMeters, endValue.poseMeters, interpolationFrac),
- lerp(curvatureRadPerMeter, endValue.curvatureRadPerMeter, interpolationFrac)
- );
- }
-
- @Override
- public String toString() {
- return String.format(
- "State(Sec: %.2f, Vel m/s: %.2f, Accel m/s/s: %.2f, Pose: %s, Curvature: %.2f)",
- timeSeconds, velocityMetersPerSecond, accelerationMetersPerSecondSq,
- poseMeters, curvatureRadPerMeter);
- }
-
- @Override
- public boolean equals(Object obj) {
- if (this == obj) {
- return true;
- }
- if (!(obj instanceof State)) {
- return false;
- }
- State state = (State) obj;
- return Double.compare(state.timeSeconds, timeSeconds) == 0
- && Double.compare(state.velocityMetersPerSecond, velocityMetersPerSecond) == 0
- && Double.compare(state.accelerationMetersPerSecondSq,
- accelerationMetersPerSecondSq) == 0
- && Double.compare(state.curvatureRadPerMeter, curvatureRadPerMeter) == 0
- && Objects.equals(poseMeters, state.poseMeters);
- }
-
- @Override
- public int hashCode() {
- return Objects.hash(timeSeconds, velocityMetersPerSecond,
- accelerationMetersPerSecondSq, poseMeters, curvatureRadPerMeter);
- }
- }
-
- @Override
- public String toString() {
- String stateList = m_states.stream().map(State::toString).collect(Collectors.joining(", \n"));
- return String.format("Trajectory - Seconds: %.2f, States:\n%s", m_totalTimeSeconds, stateList);
- }
-
- @Override
- public int hashCode() {
- return m_states.hashCode();
- }
-
- @Override
- public boolean equals(Object obj) {
- return obj instanceof Trajectory && m_states.equals(((Trajectory) obj).getStates());
- }
-}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/trajectory/TrajectoryParameterizer.java b/wpimath/src/main/java/edu/wpi/first/wpilibj/trajectory/TrajectoryParameterizer.java
deleted file mode 100644
index 3b1d2af..0000000
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/trajectory/TrajectoryParameterizer.java
+++ /dev/null
@@ -1,318 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-/*
- * MIT License
- *
- * Copyright (c) 2018 Team 254
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to deal
- * in the Software without restriction, including without limitation the rights
- * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- * copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-package edu.wpi.first.wpilibj.trajectory;
-
-import java.util.ArrayList;
-import java.util.List;
-
-import edu.wpi.first.wpilibj.spline.PoseWithCurvature;
-import edu.wpi.first.wpilibj.trajectory.constraint.TrajectoryConstraint;
-
-/**
- * Class used to parameterize a trajectory by time.
- */
-public final class TrajectoryParameterizer {
- /**
- * Private constructor because this is a utility class.
- */
- private TrajectoryParameterizer() {
- }
-
- /**
- * Parameterize the trajectory by time. This is where the velocity profile is
- * generated.
- *
- * <p>The derivation of the algorithm used can be found
- * <a href="http://www2.informatik.uni-freiburg.de/~lau/students/Sprunk2008.pdf">
- * here</a>.
- *
- * @param points Reference to the spline points.
- * @param constraints A vector of various velocity and acceleration.
- * constraints.
- * @param startVelocityMetersPerSecond The start velocity for the trajectory.
- * @param endVelocityMetersPerSecond The end velocity for the trajectory.
- * @param maxVelocityMetersPerSecond The max velocity for the trajectory.
- * @param maxAccelerationMetersPerSecondSq The max acceleration for the trajectory.
- * @param reversed Whether the robot should move backwards.
- * Note that the robot will still move from
- * a -> b -> ... -> z as defined in the
- * waypoints.
- * @return The trajectory.
- */
- @SuppressWarnings({"PMD.ExcessiveMethodLength", "PMD.CyclomaticComplexity",
- "PMD.NPathComplexity", "PMD.AvoidInstantiatingObjectsInLoops"})
- public static Trajectory timeParameterizeTrajectory(
- List<PoseWithCurvature> points,
- List<TrajectoryConstraint> constraints,
- double startVelocityMetersPerSecond,
- double endVelocityMetersPerSecond,
- double maxVelocityMetersPerSecond,
- double maxAccelerationMetersPerSecondSq,
- boolean reversed
- ) {
- var constrainedStates = new ArrayList<ConstrainedState>(points.size());
- var predecessor = new ConstrainedState(points.get(0), 0, startVelocityMetersPerSecond,
- -maxAccelerationMetersPerSecondSq, maxAccelerationMetersPerSecondSq);
-
- // Forward pass
- for (int i = 0; i < points.size(); i++) {
- constrainedStates.add(new ConstrainedState());
- var constrainedState = constrainedStates.get(i);
- constrainedState.pose = points.get(i);
-
- // Begin constraining based on predecessor.
- double ds = constrainedState.pose.poseMeters.getTranslation().getDistance(
- predecessor.pose.poseMeters.getTranslation());
- constrainedState.distanceMeters = predecessor.distanceMeters + ds;
-
- // We may need to iterate to find the maximum end velocity and common
- // acceleration, since acceleration limits may be a function of velocity.
- while (true) {
- // Enforce global max velocity and max reachable velocity by global
- // acceleration limit. vf = std::sqrt(vi^2 + 2*a*d).
- constrainedState.maxVelocityMetersPerSecond = Math.min(
- maxVelocityMetersPerSecond,
- Math.sqrt(predecessor.maxVelocityMetersPerSecond
- * predecessor.maxVelocityMetersPerSecond
- + predecessor.maxAccelerationMetersPerSecondSq * ds * 2.0)
- );
-
- constrainedState.minAccelerationMetersPerSecondSq = -maxAccelerationMetersPerSecondSq;
- constrainedState.maxAccelerationMetersPerSecondSq = maxAccelerationMetersPerSecondSq;
-
- // At this point, the constrained state is fully constructed apart from
- // all the custom-defined user constraints.
- for (final var constraint : constraints) {
- constrainedState.maxVelocityMetersPerSecond = Math.min(
- constrainedState.maxVelocityMetersPerSecond,
- constraint.getMaxVelocityMetersPerSecond(
- constrainedState.pose.poseMeters, constrainedState.pose.curvatureRadPerMeter,
- constrainedState.maxVelocityMetersPerSecond)
- );
- }
-
- // Now enforce all acceleration limits.
- enforceAccelerationLimits(reversed, constraints, constrainedState);
-
- if (ds < 1E-6) {
- break;
- }
-
- // If the actual acceleration for this state is higher than the max
- // acceleration that we applied, then we need to reduce the max
- // acceleration of the predecessor and try again.
- double actualAcceleration = (constrainedState.maxVelocityMetersPerSecond
- * constrainedState.maxVelocityMetersPerSecond
- - predecessor.maxVelocityMetersPerSecond * predecessor.maxVelocityMetersPerSecond)
- / (ds * 2.0);
-
- // If we violate the max acceleration constraint, let's modify the
- // predecessor.
- if (constrainedState.maxAccelerationMetersPerSecondSq < actualAcceleration - 1E-6) {
- predecessor.maxAccelerationMetersPerSecondSq
- = constrainedState.maxAccelerationMetersPerSecondSq;
- } else {
- // Constrain the predecessor's max acceleration to the current
- // acceleration.
- if (actualAcceleration > predecessor.minAccelerationMetersPerSecondSq) {
- predecessor.maxAccelerationMetersPerSecondSq = actualAcceleration;
- }
- // If the actual acceleration is less than the predecessor's min
- // acceleration, it will be repaired in the backward pass.
- break;
- }
- }
- predecessor = constrainedState;
- }
-
- var successor = new ConstrainedState(points.get(points.size() - 1),
- constrainedStates.get(constrainedStates.size() - 1).distanceMeters,
- endVelocityMetersPerSecond,
- -maxAccelerationMetersPerSecondSq, maxAccelerationMetersPerSecondSq);
-
- // Backward pass
- for (int i = points.size() - 1; i >= 0; i--) {
- var constrainedState = constrainedStates.get(i);
- double ds = constrainedState.distanceMeters - successor.distanceMeters; // negative
-
- while (true) {
- // Enforce max velocity limit (reverse)
- // vf = std::sqrt(vi^2 + 2*a*d), where vi = successor.
- double newMaxVelocity = Math.sqrt(
- successor.maxVelocityMetersPerSecond * successor.maxVelocityMetersPerSecond
- + successor.minAccelerationMetersPerSecondSq * ds * 2.0
- );
-
- // No more limits to impose! This state can be finalized.
- if (newMaxVelocity >= constrainedState.maxVelocityMetersPerSecond) {
- break;
- }
-
- constrainedState.maxVelocityMetersPerSecond = newMaxVelocity;
-
- // Check all acceleration constraints with the new max velocity.
- enforceAccelerationLimits(reversed, constraints, constrainedState);
-
- if (ds > -1E-6) {
- break;
- }
-
- // If the actual acceleration for this state is lower than the min
- // acceleration, then we need to lower the min acceleration of the
- // successor and try again.
- double actualAcceleration = (constrainedState.maxVelocityMetersPerSecond
- * constrainedState.maxVelocityMetersPerSecond
- - successor.maxVelocityMetersPerSecond * successor.maxVelocityMetersPerSecond)
- / (ds * 2.0);
-
- if (constrainedState.minAccelerationMetersPerSecondSq > actualAcceleration + 1E-6) {
- successor.minAccelerationMetersPerSecondSq
- = constrainedState.minAccelerationMetersPerSecondSq;
- } else {
- successor.minAccelerationMetersPerSecondSq = actualAcceleration;
- break;
- }
- }
- successor = constrainedState;
- }
-
- // Now we can integrate the constrained states forward in time to obtain our
- // trajectory states.
- var states = new ArrayList<Trajectory.State>(points.size());
- double timeSeconds = 0.0;
- double distanceMeters = 0.0;
- double velocityMetersPerSecond = 0.0;
-
- for (int i = 0; i < constrainedStates.size(); i++) {
- final var state = constrainedStates.get(i);
-
- // Calculate the change in position between the current state and the previous
- // state.
- double ds = state.distanceMeters - distanceMeters;
-
- // Calculate the acceleration between the current state and the previous
- // state.
- double accel = (state.maxVelocityMetersPerSecond * state.maxVelocityMetersPerSecond
- - velocityMetersPerSecond * velocityMetersPerSecond) / (ds * 2);
-
- // Calculate dt
- double dt = 0.0;
- if (i > 0) {
- states.get(i - 1).accelerationMetersPerSecondSq = reversed ? -accel : accel;
- if (Math.abs(accel) > 1E-6) {
- // v_f = v_0 + a * t
- dt = (state.maxVelocityMetersPerSecond - velocityMetersPerSecond) / accel;
- } else if (Math.abs(velocityMetersPerSecond) > 1E-6) {
- // delta_x = v * t
- dt = ds / velocityMetersPerSecond;
- } else {
- throw new TrajectoryGenerationException("Something went wrong at iteration " + i
- + " of time parameterization.");
- }
- }
-
- velocityMetersPerSecond = state.maxVelocityMetersPerSecond;
- distanceMeters = state.distanceMeters;
-
- timeSeconds += dt;
-
- states.add(new Trajectory.State(
- timeSeconds,
- reversed ? -velocityMetersPerSecond : velocityMetersPerSecond,
- reversed ? -accel : accel,
- state.pose.poseMeters, state.pose.curvatureRadPerMeter
- ));
- }
-
- return new Trajectory(states);
- }
-
- private static void enforceAccelerationLimits(boolean reverse,
- List<TrajectoryConstraint> constraints,
- ConstrainedState state) {
-
- for (final var constraint : constraints) {
- double factor = reverse ? -1.0 : 1.0;
- final var minMaxAccel = constraint.getMinMaxAccelerationMetersPerSecondSq(
- state.pose.poseMeters, state.pose.curvatureRadPerMeter,
- state.maxVelocityMetersPerSecond * factor);
-
- if (minMaxAccel.minAccelerationMetersPerSecondSq
- > minMaxAccel.maxAccelerationMetersPerSecondSq) {
- throw new TrajectoryGenerationException("The constraint's min acceleration "
- + "was greater than its max acceleration.\n Offending Constraint: "
- + constraint.getClass().getName()
- + "\n If the offending constraint was packaged with WPILib, please file a bug report.");
- }
-
- state.minAccelerationMetersPerSecondSq = Math.max(state.minAccelerationMetersPerSecondSq,
- reverse ? -minMaxAccel.maxAccelerationMetersPerSecondSq
- : minMaxAccel.minAccelerationMetersPerSecondSq);
-
- state.maxAccelerationMetersPerSecondSq = Math.min(state.maxAccelerationMetersPerSecondSq,
- reverse ? -minMaxAccel.minAccelerationMetersPerSecondSq
- : minMaxAccel.maxAccelerationMetersPerSecondSq);
- }
-
- }
-
- @SuppressWarnings("MemberName")
- private static class ConstrainedState {
- PoseWithCurvature pose;
- double distanceMeters;
- double maxVelocityMetersPerSecond;
- double minAccelerationMetersPerSecondSq;
- double maxAccelerationMetersPerSecondSq;
-
- ConstrainedState(PoseWithCurvature pose, double distanceMeters,
- double maxVelocityMetersPerSecond,
- double minAccelerationMetersPerSecondSq,
- double maxAccelerationMetersPerSecondSq) {
- this.pose = pose;
- this.distanceMeters = distanceMeters;
- this.maxVelocityMetersPerSecond = maxVelocityMetersPerSecond;
- this.minAccelerationMetersPerSecondSq = minAccelerationMetersPerSecondSq;
- this.maxAccelerationMetersPerSecondSq = maxAccelerationMetersPerSecondSq;
- }
-
- ConstrainedState() {
- pose = new PoseWithCurvature();
- }
- }
-
- @SuppressWarnings("serial")
- public static class TrajectoryGenerationException extends RuntimeException {
- public TrajectoryGenerationException(String message) {
- super(message);
- }
- }
-}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/trajectory/TrajectoryUtil.java b/wpimath/src/main/java/edu/wpi/first/wpilibj/trajectory/TrajectoryUtil.java
deleted file mode 100644
index 0cd0f49..0000000
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/trajectory/TrajectoryUtil.java
+++ /dev/null
@@ -1,76 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.trajectory;
-
-import java.io.BufferedReader;
-import java.io.BufferedWriter;
-import java.io.IOException;
-import java.nio.file.Files;
-import java.nio.file.Path;
-import java.util.Arrays;
-
-import com.fasterxml.jackson.core.JsonProcessingException;
-import com.fasterxml.jackson.databind.ObjectMapper;
-import com.fasterxml.jackson.databind.ObjectReader;
-import com.fasterxml.jackson.databind.ObjectWriter;
-
-public final class TrajectoryUtil {
- private static final ObjectReader READER = new ObjectMapper().readerFor(Trajectory.State[].class);
- private static final ObjectWriter WRITER = new ObjectMapper().writerFor(Trajectory.State[].class);
-
- private TrajectoryUtil() {
- throw new UnsupportedOperationException("This is a utility class!");
- }
-
- /**
- * Imports a Trajectory from a PathWeaver-style JSON file.
- * @param path the path of the json file to import from
- * @return The trajectory represented by the file.
- * @throws IOException if reading from the file fails
- */
- public static Trajectory fromPathweaverJson(Path path) throws IOException {
- try (BufferedReader reader = Files.newBufferedReader(path)) {
- Trajectory.State[] state = READER.readValue(reader);
- return new Trajectory(Arrays.asList(state));
- }
- }
-
- /**
- * Exports a Trajectory to a PathWeaver-style JSON file.
- * @param trajectory the trajectory to export
- * @param path the path of the file to export to
- * @throws IOException if writing to the file fails
- */
- public static void toPathweaverJson(Trajectory trajectory, Path path) throws IOException {
- Files.createDirectories(path.getParent());
- try (BufferedWriter writer = Files.newBufferedWriter(path)) {
- WRITER.writeValue(writer, trajectory.getStates().toArray(new Trajectory.State[0]));
- }
- }
-
- /**
- * Deserializes a Trajectory from PathWeaver-style JSON.
- * @param json the string containing the serialized JSON
- * @return the trajectory represented by the JSON
- * @throws JsonProcessingException if deserializing the JSON fails
- */
- public static Trajectory deserializeTrajectory(String json) throws JsonProcessingException {
- Trajectory.State[] state = READER.readValue(json);
- return new Trajectory(Arrays.asList(state));
- }
-
- /**
- * Serializes a Trajectory to PathWeaver-style JSON.
- * @param trajectory the trajectory to export
- * @return the string containing the serialized JSON
- * @throws JsonProcessingException if serializing the Trajectory fails
- */
- public static String serializeTrajectory(Trajectory trajectory) throws JsonProcessingException {
- return WRITER.writeValueAsString(trajectory.getStates().toArray(new Trajectory.State[0]));
- }
-}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/trajectory/constraint/CentripetalAccelerationConstraint.java b/wpimath/src/main/java/edu/wpi/first/wpilibj/trajectory/constraint/CentripetalAccelerationConstraint.java
deleted file mode 100644
index 0b87a64..0000000
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/trajectory/constraint/CentripetalAccelerationConstraint.java
+++ /dev/null
@@ -1,73 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.trajectory.constraint;
-
-import edu.wpi.first.wpilibj.geometry.Pose2d;
-
-/**
- * A constraint on the maximum absolute centripetal acceleration allowed when
- * traversing a trajectory. The centripetal acceleration of a robot is defined
- * as the velocity squared divided by the radius of curvature.
- *
- * <p>Effectively, limiting the maximum centripetal acceleration will cause the
- * robot to slow down around tight turns, making it easier to track trajectories
- * with sharp turns.
- */
-public class CentripetalAccelerationConstraint implements TrajectoryConstraint {
- private final double m_maxCentripetalAccelerationMetersPerSecondSq;
-
- /**
- * Constructs a centripetal acceleration constraint.
- *
- * @param maxCentripetalAccelerationMetersPerSecondSq The max centripetal acceleration.
- */
- public CentripetalAccelerationConstraint(double maxCentripetalAccelerationMetersPerSecondSq) {
- m_maxCentripetalAccelerationMetersPerSecondSq = maxCentripetalAccelerationMetersPerSecondSq;
- }
-
- /**
- * Returns the max velocity given the current pose and curvature.
- *
- * @param poseMeters The pose at the current point in the trajectory.
- * @param curvatureRadPerMeter The curvature at the current point in the trajectory.
- * @param velocityMetersPerSecond The velocity at the current point in the trajectory before
- * constraints are applied.
- * @return The absolute maximum velocity.
- */
- @Override
- public double getMaxVelocityMetersPerSecond(Pose2d poseMeters, double curvatureRadPerMeter,
- double velocityMetersPerSecond) {
- // ac = v^2 / r
- // k (curvature) = 1 / r
-
- // therefore, ac = v^2 * k
- // ac / k = v^2
- // v = std::sqrt(ac / k)
-
- return Math.sqrt(m_maxCentripetalAccelerationMetersPerSecondSq
- / Math.abs(curvatureRadPerMeter));
- }
-
- /**
- * Returns the minimum and maximum allowable acceleration for the trajectory
- * given pose, curvature, and speed.
- *
- * @param poseMeters The pose at the current point in the trajectory.
- * @param curvatureRadPerMeter The curvature at the current point in the trajectory.
- * @param velocityMetersPerSecond The speed at the current point in the trajectory.
- * @return The min and max acceleration bounds.
- */
- @Override
- public MinMax getMinMaxAccelerationMetersPerSecondSq(Pose2d poseMeters,
- double curvatureRadPerMeter,
- double velocityMetersPerSecond) {
- // The acceleration of the robot has no impact on the centripetal acceleration
- // of the robot.
- return new MinMax();
- }
-}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/trajectory/constraint/DifferentialDriveKinematicsConstraint.java b/wpimath/src/main/java/edu/wpi/first/wpilibj/trajectory/constraint/DifferentialDriveKinematicsConstraint.java
deleted file mode 100644
index 67cddcf..0000000
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/trajectory/constraint/DifferentialDriveKinematicsConstraint.java
+++ /dev/null
@@ -1,76 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.trajectory.constraint;
-
-import edu.wpi.first.wpilibj.geometry.Pose2d;
-import edu.wpi.first.wpilibj.kinematics.ChassisSpeeds;
-import edu.wpi.first.wpilibj.kinematics.DifferentialDriveKinematics;
-
-/**
- * A class that enforces constraints on the differential drive kinematics.
- * This can be used to ensure that the trajectory is constructed so that the
- * commanded velocities for both sides of the drivetrain stay below a certain
- * limit.
- */
-public class DifferentialDriveKinematicsConstraint implements TrajectoryConstraint {
- private final double m_maxSpeedMetersPerSecond;
- private final DifferentialDriveKinematics m_kinematics;
-
- /**
- * Constructs a differential drive dynamics constraint.
- *
- * @param kinematics A kinematics component describing the drive geometry.
- * @param maxSpeedMetersPerSecond The max speed that a side of the robot can travel at.
- */
- public DifferentialDriveKinematicsConstraint(final DifferentialDriveKinematics kinematics,
- double maxSpeedMetersPerSecond) {
- m_maxSpeedMetersPerSecond = maxSpeedMetersPerSecond;
- m_kinematics = kinematics;
- }
-
-
- /**
- * Returns the max velocity given the current pose and curvature.
- *
- * @param poseMeters The pose at the current point in the trajectory.
- * @param curvatureRadPerMeter The curvature at the current point in the trajectory.
- * @param velocityMetersPerSecond The velocity at the current point in the trajectory before
- * constraints are applied.
- * @return The absolute maximum velocity.
- */
- @Override
- public double getMaxVelocityMetersPerSecond(Pose2d poseMeters, double curvatureRadPerMeter,
- double velocityMetersPerSecond) {
- // Create an object to represent the current chassis speeds.
- var chassisSpeeds = new ChassisSpeeds(velocityMetersPerSecond,
- 0, velocityMetersPerSecond * curvatureRadPerMeter);
-
- // Get the wheel speeds and normalize them to within the max velocity.
- var wheelSpeeds = m_kinematics.toWheelSpeeds(chassisSpeeds);
- wheelSpeeds.normalize(m_maxSpeedMetersPerSecond);
-
- // Return the new linear chassis speed.
- return m_kinematics.toChassisSpeeds(wheelSpeeds).vxMetersPerSecond;
- }
-
- /**
- * Returns the minimum and maximum allowable acceleration for the trajectory
- * given pose, curvature, and speed.
- *
- * @param poseMeters The pose at the current point in the trajectory.
- * @param curvatureRadPerMeter The curvature at the current point in the trajectory.
- * @param velocityMetersPerSecond The speed at the current point in the trajectory.
- * @return The min and max acceleration bounds.
- */
- @Override
- public MinMax getMinMaxAccelerationMetersPerSecondSq(Pose2d poseMeters,
- double curvatureRadPerMeter,
- double velocityMetersPerSecond) {
- return new MinMax();
- }
-}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/trajectory/constraint/DifferentialDriveVoltageConstraint.java b/wpimath/src/main/java/edu/wpi/first/wpilibj/trajectory/constraint/DifferentialDriveVoltageConstraint.java
deleted file mode 100644
index 9e28b0c..0000000
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/trajectory/constraint/DifferentialDriveVoltageConstraint.java
+++ /dev/null
@@ -1,126 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.trajectory.constraint;
-
-import edu.wpi.first.wpilibj.controller.SimpleMotorFeedforward;
-import edu.wpi.first.wpilibj.geometry.Pose2d;
-import edu.wpi.first.wpilibj.kinematics.ChassisSpeeds;
-import edu.wpi.first.wpilibj.kinematics.DifferentialDriveKinematics;
-
-import static edu.wpi.first.wpiutil.ErrorMessages.requireNonNullParam;
-
-/**
- * A class that enforces constraints on differential drive voltage expenditure based on the motor
- * dynamics and the drive kinematics. Ensures that the acceleration of any wheel of the robot
- * while following the trajectory is never higher than what can be achieved with the given
- * maximum voltage.
- */
-public class DifferentialDriveVoltageConstraint implements TrajectoryConstraint {
- private final SimpleMotorFeedforward m_feedforward;
- private final DifferentialDriveKinematics m_kinematics;
- private final double m_maxVoltage;
-
- /**
- * Creates a new DifferentialDriveVoltageConstraint.
- *
- * @param feedforward A feedforward component describing the behavior of the drive.
- * @param kinematics A kinematics component describing the drive geometry.
- * @param maxVoltage The maximum voltage available to the motors while following the path.
- * Should be somewhat less than the nominal battery voltage (12V) to account
- * for "voltage sag" due to current draw.
- */
- public DifferentialDriveVoltageConstraint(SimpleMotorFeedforward feedforward,
- DifferentialDriveKinematics kinematics,
- double maxVoltage) {
- m_feedforward = requireNonNullParam(feedforward, "feedforward",
- "DifferentialDriveVoltageConstraint");
- m_kinematics = requireNonNullParam(kinematics, "kinematics",
- "DifferentialDriveVoltageConstraint");
- m_maxVoltage = maxVoltage;
- }
-
- @Override
- public double getMaxVelocityMetersPerSecond(Pose2d poseMeters, double curvatureRadPerMeter,
- double velocityMetersPerSecond) {
- return Double.POSITIVE_INFINITY;
- }
-
- @Override
- public MinMax getMinMaxAccelerationMetersPerSecondSq(Pose2d poseMeters,
- double curvatureRadPerMeter,
- double velocityMetersPerSecond) {
-
- var wheelSpeeds = m_kinematics.toWheelSpeeds(new ChassisSpeeds(velocityMetersPerSecond, 0,
- velocityMetersPerSecond
- * curvatureRadPerMeter));
-
- double maxWheelSpeed = Math.max(wheelSpeeds.leftMetersPerSecond,
- wheelSpeeds.rightMetersPerSecond);
- double minWheelSpeed = Math.min(wheelSpeeds.leftMetersPerSecond,
- wheelSpeeds.rightMetersPerSecond);
-
- // Calculate maximum/minimum possible accelerations from motor dynamics
- // and max/min wheel speeds
- double maxWheelAcceleration =
- m_feedforward.maxAchievableAcceleration(m_maxVoltage, maxWheelSpeed);
- double minWheelAcceleration =
- m_feedforward.minAchievableAcceleration(m_maxVoltage, minWheelSpeed);
-
- // Robot chassis turning on radius = 1/|curvature|. Outer wheel has radius
- // increased by half of the trackwidth T. Inner wheel has radius decreased
- // by half of the trackwidth. Achassis / radius = Aouter / (radius + T/2), so
- // Achassis = Aouter * radius / (radius + T/2) = Aouter / (1 + |curvature|T/2).
- // Inner wheel is similar.
-
- // sgn(speed) term added to correctly account for which wheel is on
- // outside of turn:
- // If moving forward, max acceleration constraint corresponds to wheel on outside of turn
- // If moving backward, max acceleration constraint corresponds to wheel on inside of turn
-
- // When velocity is zero, then wheel velocities are uniformly zero (robot cannot be
- // turning on its center) - we have to treat this as a special case, as it breaks
- // the signum function. Both max and min acceleration are *reduced in magnitude*
- // in this case.
-
- double maxChassisAcceleration;
- double minChassisAcceleration;
-
- if (velocityMetersPerSecond == 0) {
- maxChassisAcceleration =
- maxWheelAcceleration
- / (1 + m_kinematics.trackWidthMeters * Math.abs(curvatureRadPerMeter) / 2);
- minChassisAcceleration =
- minWheelAcceleration
- / (1 + m_kinematics.trackWidthMeters * Math.abs(curvatureRadPerMeter) / 2);
- } else {
- maxChassisAcceleration =
- maxWheelAcceleration
- / (1 + m_kinematics.trackWidthMeters * Math.abs(curvatureRadPerMeter)
- * Math.signum(velocityMetersPerSecond) / 2);
- minChassisAcceleration =
- minWheelAcceleration
- / (1 - m_kinematics.trackWidthMeters * Math.abs(curvatureRadPerMeter)
- * Math.signum(velocityMetersPerSecond) / 2);
- }
-
- // When turning about a point inside of the wheelbase (i.e. radius less than half
- // the trackwidth), the inner wheel's direction changes, but the magnitude remains
- // the same. The formula above changes sign for the inner wheel when this happens.
- // We can accurately account for this by simply negating the inner wheel.
-
- if ((m_kinematics.trackWidthMeters / 2) > (1 / Math.abs(curvatureRadPerMeter))) {
- if (velocityMetersPerSecond > 0) {
- minChassisAcceleration = -minChassisAcceleration;
- } else if (velocityMetersPerSecond < 0) {
- maxChassisAcceleration = -maxChassisAcceleration;
- }
- }
-
- return new MinMax(minChassisAcceleration, maxChassisAcceleration);
- }
-}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/trajectory/constraint/EllipticalRegionConstraint.java b/wpimath/src/main/java/edu/wpi/first/wpilibj/trajectory/constraint/EllipticalRegionConstraint.java
deleted file mode 100644
index eb9f7e7..0000000
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/trajectory/constraint/EllipticalRegionConstraint.java
+++ /dev/null
@@ -1,80 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.trajectory.constraint;
-
-import edu.wpi.first.wpilibj.geometry.Pose2d;
-import edu.wpi.first.wpilibj.geometry.Rotation2d;
-import edu.wpi.first.wpilibj.geometry.Translation2d;
-
-/**
- * Enforces a particular constraint only within an elliptical region.
- */
-public class EllipticalRegionConstraint implements TrajectoryConstraint {
- private final Translation2d m_center;
- private final Translation2d m_radii;
- private final TrajectoryConstraint m_constraint;
-
- /**
- * Constructs a new EllipticalRegionConstraint.
- *
- * @param center The center of the ellipse in which to enforce the constraint.
- * @param xWidth The width of the ellipse in which to enforce the constraint.
- * @param yWidth The height of the ellipse in which to enforce the constraint.
- * @param rotation The rotation to apply to all radii around the origin.
- * @param constraint The constraint to enforce when the robot is within the region.
- */
- @SuppressWarnings("ParameterName")
- public EllipticalRegionConstraint(Translation2d center, double xWidth, double yWidth,
- Rotation2d rotation, TrajectoryConstraint constraint) {
- m_center = center;
- m_radii = new Translation2d(xWidth / 2.0, yWidth / 2.0).rotateBy(rotation);
- m_constraint = constraint;
- }
-
- @Override
- public double getMaxVelocityMetersPerSecond(Pose2d poseMeters, double curvatureRadPerMeter,
- double velocityMetersPerSecond) {
- if (isPoseInRegion(poseMeters)) {
- return m_constraint.getMaxVelocityMetersPerSecond(poseMeters, curvatureRadPerMeter,
- velocityMetersPerSecond);
- } else {
- return Double.POSITIVE_INFINITY;
- }
- }
-
- @Override
- public MinMax getMinMaxAccelerationMetersPerSecondSq(Pose2d poseMeters,
- double curvatureRadPerMeter,
- double velocityMetersPerSecond) {
- if (isPoseInRegion(poseMeters)) {
- return m_constraint.getMinMaxAccelerationMetersPerSecondSq(poseMeters,
- curvatureRadPerMeter, velocityMetersPerSecond);
- } else {
- return new MinMax();
- }
- }
-
- /**
- * Returns whether the specified robot pose is within the region that the constraint
- * is enforced in.
- *
- * @param robotPose The robot pose.
- * @return Whether the robot pose is within the constraint region.
- */
- public boolean isPoseInRegion(Pose2d robotPose) {
- // The region (disk) bounded by the ellipse is given by the equation:
- // ((x-h)^2)/Rx^2) + ((y-k)^2)/Ry^2) <= 1
- // If the inequality is satisfied, then it is inside the ellipse; otherwise
- // it is outside the ellipse.
- // Both sides have been multiplied by Rx^2 * Ry^2 for efficiency reasons.
- return Math.pow(robotPose.getX() - m_center.getX(), 2)
- * Math.pow(m_radii.getY(), 2)
- + Math.pow(robotPose.getY() - m_center.getY(), 2)
- * Math.pow(m_radii.getX(), 2) <= Math.pow(m_radii.getX(), 2) * Math.pow(m_radii.getY(), 2);
- }
-}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/trajectory/constraint/MaxVelocityConstraint.java b/wpimath/src/main/java/edu/wpi/first/wpilibj/trajectory/constraint/MaxVelocityConstraint.java
deleted file mode 100644
index 4d60623..0000000
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/trajectory/constraint/MaxVelocityConstraint.java
+++ /dev/null
@@ -1,40 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.trajectory.constraint;
-
-import edu.wpi.first.wpilibj.geometry.Pose2d;
-
-/**
- * Represents a constraint that enforces a max velocity. This can be composed with the
- * {@link EllipticalRegionConstraint} or {@link RectangularRegionConstraint} to enforce
- * a max velocity in a region.
- */
-public class MaxVelocityConstraint implements TrajectoryConstraint {
- private final double m_maxVelocity;
-
- /**
- * Constructs a new MaxVelocityConstraint.
- *
- * @param maxVelocityMetersPerSecond The max velocity.
- */
- public MaxVelocityConstraint(double maxVelocityMetersPerSecond) {
- m_maxVelocity = maxVelocityMetersPerSecond;
- }
-
- @Override
- public double getMaxVelocityMetersPerSecond(Pose2d poseMeters, double curvatureRadPerMeter,
- double velocityMetersPerSecond) {
- return m_maxVelocity;
- }
-
- @Override
- public TrajectoryConstraint.MinMax getMinMaxAccelerationMetersPerSecondSq(
- Pose2d poseMeters, double curvatureRadPerMeter, double velocityMetersPerSecond) {
- return new MinMax();
- }
-}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/trajectory/constraint/MecanumDriveKinematicsConstraint.java b/wpimath/src/main/java/edu/wpi/first/wpilibj/trajectory/constraint/MecanumDriveKinematicsConstraint.java
deleted file mode 100644
index 6758d3d..0000000
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/trajectory/constraint/MecanumDriveKinematicsConstraint.java
+++ /dev/null
@@ -1,84 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.trajectory.constraint;
-
-import edu.wpi.first.wpilibj.geometry.Pose2d;
-import edu.wpi.first.wpilibj.kinematics.ChassisSpeeds;
-import edu.wpi.first.wpilibj.kinematics.MecanumDriveKinematics;
-
-/**
- * A class that enforces constraints on the mecanum drive kinematics.
- * This can be used to ensure that the trajectory is constructed so that the
- * commanded velocities for all 4 wheels of the drivetrain stay below a certain
- * limit.
- */
-public class MecanumDriveKinematicsConstraint implements TrajectoryConstraint {
- private final double m_maxSpeedMetersPerSecond;
- private final MecanumDriveKinematics m_kinematics;
-
- /**
- * Constructs a mecanum drive dynamics constraint.
- *
- * @param maxSpeedMetersPerSecond The max speed that a side of the robot can travel at.
- */
- public MecanumDriveKinematicsConstraint(final MecanumDriveKinematics kinematics,
- double maxSpeedMetersPerSecond) {
- m_maxSpeedMetersPerSecond = maxSpeedMetersPerSecond;
- m_kinematics = kinematics;
- }
-
-
- /**
- * Returns the max velocity given the current pose and curvature.
- *
- * @param poseMeters The pose at the current point in the trajectory.
- * @param curvatureRadPerMeter The curvature at the current point in the trajectory.
- * @param velocityMetersPerSecond The velocity at the current point in the trajectory before
- * constraints are applied.
- * @return The absolute maximum velocity.
- */
- @Override
- public double getMaxVelocityMetersPerSecond(Pose2d poseMeters, double curvatureRadPerMeter,
- double velocityMetersPerSecond) {
- // Represents the velocity of the chassis in the x direction
- var xdVelocity = velocityMetersPerSecond * poseMeters.getRotation().getCos();
-
- // Represents the velocity of the chassis in the y direction
- var ydVelocity = velocityMetersPerSecond * poseMeters.getRotation().getSin();
-
- // Create an object to represent the current chassis speeds.
- var chassisSpeeds = new ChassisSpeeds(xdVelocity,
- ydVelocity, velocityMetersPerSecond * curvatureRadPerMeter);
-
- // Get the wheel speeds and normalize them to within the max velocity.
- var wheelSpeeds = m_kinematics.toWheelSpeeds(chassisSpeeds);
- wheelSpeeds.normalize(m_maxSpeedMetersPerSecond);
-
- // Convert normalized wheel speeds back to chassis speeds
- var normSpeeds = m_kinematics.toChassisSpeeds(wheelSpeeds);
-
- // Return the new linear chassis speed.
- return Math.hypot(normSpeeds.vxMetersPerSecond, normSpeeds.vyMetersPerSecond);
- }
-
- /**
- * Returns the minimum and maximum allowable acceleration for the trajectory
- * given pose, curvature, and speed.
- *
- * @param poseMeters The pose at the current point in the trajectory.
- * @param curvatureRadPerMeter The curvature at the current point in the trajectory.
- * @param velocityMetersPerSecond The speed at the current point in the trajectory.
- * @return The min and max acceleration bounds.
- */
- @Override
- public MinMax getMinMaxAccelerationMetersPerSecondSq(Pose2d poseMeters,
- double curvatureRadPerMeter,
- double velocityMetersPerSecond) {
- return new MinMax();
- }
-}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/trajectory/constraint/RectangularRegionConstraint.java b/wpimath/src/main/java/edu/wpi/first/wpilibj/trajectory/constraint/RectangularRegionConstraint.java
deleted file mode 100644
index c25c74c..0000000
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/trajectory/constraint/RectangularRegionConstraint.java
+++ /dev/null
@@ -1,73 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.trajectory.constraint;
-
-import edu.wpi.first.wpilibj.geometry.Pose2d;
-import edu.wpi.first.wpilibj.geometry.Translation2d;
-
-/**
- * Enforces a particular constraint only within a rectangular region.
- */
-public class RectangularRegionConstraint implements TrajectoryConstraint {
- private final Translation2d m_bottomLeftPoint;
- private final Translation2d m_topRightPoint;
- private final TrajectoryConstraint m_constraint;
-
- /**
- * Constructs a new RectangularRegionConstraint.
- *
- * @param bottomLeftPoint The bottom left point of the rectangular region in which to
- * enforce the constraint.
- * @param topRightPoint The top right point of the rectangular region in which to enforce
- * the constraint.
- * @param constraint The constraint to enforce when the robot is within the region.
- */
- public RectangularRegionConstraint(Translation2d bottomLeftPoint, Translation2d topRightPoint,
- TrajectoryConstraint constraint) {
- m_bottomLeftPoint = bottomLeftPoint;
- m_topRightPoint = topRightPoint;
- m_constraint = constraint;
- }
-
- @Override
- public double getMaxVelocityMetersPerSecond(Pose2d poseMeters, double curvatureRadPerMeter,
- double velocityMetersPerSecond) {
- if (isPoseInRegion(poseMeters)) {
- return m_constraint.getMaxVelocityMetersPerSecond(poseMeters, curvatureRadPerMeter,
- velocityMetersPerSecond);
- } else {
- return Double.POSITIVE_INFINITY;
- }
- }
-
- @Override
- public MinMax getMinMaxAccelerationMetersPerSecondSq(Pose2d poseMeters,
- double curvatureRadPerMeter,
- double velocityMetersPerSecond) {
- if (isPoseInRegion(poseMeters)) {
- return m_constraint.getMinMaxAccelerationMetersPerSecondSq(poseMeters,
- curvatureRadPerMeter, velocityMetersPerSecond);
- } else {
- return new MinMax();
- }
- }
-
- /**
- * Returns whether the specified robot pose is within the region that the constraint
- * is enforced in.
- *
- * @param robotPose The robot pose.
- * @return Whether the robot pose is within the constraint region.
- */
- public boolean isPoseInRegion(Pose2d robotPose) {
- return robotPose.getX() >= m_bottomLeftPoint.getX()
- && robotPose.getX() <= m_topRightPoint.getX()
- && robotPose.getY() >= m_bottomLeftPoint.getY()
- && robotPose.getY() <= m_topRightPoint.getY();
- }
-}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/trajectory/constraint/SwerveDriveKinematicsConstraint.java b/wpimath/src/main/java/edu/wpi/first/wpilibj/trajectory/constraint/SwerveDriveKinematicsConstraint.java
deleted file mode 100644
index 693bfd5..0000000
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/trajectory/constraint/SwerveDriveKinematicsConstraint.java
+++ /dev/null
@@ -1,84 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.trajectory.constraint;
-
-import edu.wpi.first.wpilibj.geometry.Pose2d;
-import edu.wpi.first.wpilibj.kinematics.ChassisSpeeds;
-import edu.wpi.first.wpilibj.kinematics.SwerveDriveKinematics;
-
-/**
- * A class that enforces constraints on the swerve drive kinematics.
- * This can be used to ensure that the trajectory is constructed so that the
- * commanded velocities for all 4 wheels of the drivetrain stay below a certain
- * limit.
- */
-public class SwerveDriveKinematicsConstraint implements TrajectoryConstraint {
- private final double m_maxSpeedMetersPerSecond;
- private final SwerveDriveKinematics m_kinematics;
-
- /**
- * Constructs a swerve drive dynamics constraint.
- *
- * @param maxSpeedMetersPerSecond The max speed that a side of the robot can travel at.
- */
- public SwerveDriveKinematicsConstraint(final SwerveDriveKinematics kinematics,
- double maxSpeedMetersPerSecond) {
- m_maxSpeedMetersPerSecond = maxSpeedMetersPerSecond;
- m_kinematics = kinematics;
- }
-
-
- /**
- * Returns the max velocity given the current pose and curvature.
- *
- * @param poseMeters The pose at the current point in the trajectory.
- * @param curvatureRadPerMeter The curvature at the current point in the trajectory.
- * @param velocityMetersPerSecond The velocity at the current point in the trajectory before
- * constraints are applied.
- * @return The absolute maximum velocity.
- */
- @Override
- public double getMaxVelocityMetersPerSecond(Pose2d poseMeters, double curvatureRadPerMeter,
- double velocityMetersPerSecond) {
- // Represents the velocity of the chassis in the x direction
- var xdVelocity = velocityMetersPerSecond * poseMeters.getRotation().getCos();
-
- // Represents the velocity of the chassis in the y direction
- var ydVelocity = velocityMetersPerSecond * poseMeters.getRotation().getSin();
-
- // Create an object to represent the current chassis speeds.
- var chassisSpeeds = new ChassisSpeeds(xdVelocity,
- ydVelocity, velocityMetersPerSecond * curvatureRadPerMeter);
-
- // Get the wheel speeds and normalize them to within the max velocity.
- var wheelSpeeds = m_kinematics.toSwerveModuleStates(chassisSpeeds);
- SwerveDriveKinematics.normalizeWheelSpeeds(wheelSpeeds, m_maxSpeedMetersPerSecond);
-
- // Convert normalized wheel speeds back to chassis speeds
- var normSpeeds = m_kinematics.toChassisSpeeds(wheelSpeeds);
-
- // Return the new linear chassis speed.
- return Math.hypot(normSpeeds.vxMetersPerSecond, normSpeeds.vyMetersPerSecond);
- }
-
- /**
- * Returns the minimum and maximum allowable acceleration for the trajectory
- * given pose, curvature, and speed.
- *
- * @param poseMeters The pose at the current point in the trajectory.
- * @param curvatureRadPerMeter The curvature at the current point in the trajectory.
- * @param velocityMetersPerSecond The speed at the current point in the trajectory.
- * @return The min and max acceleration bounds.
- */
- @Override
- public MinMax getMinMaxAccelerationMetersPerSecondSq(Pose2d poseMeters,
- double curvatureRadPerMeter,
- double velocityMetersPerSecond) {
- return new MinMax();
- }
-}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/trajectory/constraint/TrajectoryConstraint.java b/wpimath/src/main/java/edu/wpi/first/wpilibj/trajectory/constraint/TrajectoryConstraint.java
deleted file mode 100644
index 5962404..0000000
--- a/wpimath/src/main/java/edu/wpi/first/wpilibj/trajectory/constraint/TrajectoryConstraint.java
+++ /dev/null
@@ -1,68 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.trajectory.constraint;
-
-import edu.wpi.first.wpilibj.geometry.Pose2d;
-
-/**
- * An interface for defining user-defined velocity and acceleration constraints
- * while generating trajectories.
- */
-public interface TrajectoryConstraint {
-
- /**
- * Returns the max velocity given the current pose and curvature.
- *
- * @param poseMeters The pose at the current point in the trajectory.
- * @param curvatureRadPerMeter The curvature at the current point in the trajectory.
- * @param velocityMetersPerSecond The velocity at the current point in the trajectory before
- * constraints are applied.
- * @return The absolute maximum velocity.
- */
- double getMaxVelocityMetersPerSecond(Pose2d poseMeters, double curvatureRadPerMeter,
- double velocityMetersPerSecond);
-
- /**
- * Returns the minimum and maximum allowable acceleration for the trajectory
- * given pose, curvature, and speed.
- *
- * @param poseMeters The pose at the current point in the trajectory.
- * @param curvatureRadPerMeter The curvature at the current point in the trajectory.
- * @param velocityMetersPerSecond The speed at the current point in the trajectory.
- * @return The min and max acceleration bounds.
- */
- MinMax getMinMaxAccelerationMetersPerSecondSq(Pose2d poseMeters, double curvatureRadPerMeter,
- double velocityMetersPerSecond);
-
- /**
- * Represents a minimum and maximum acceleration.
- */
- @SuppressWarnings("MemberName")
- class MinMax {
- public double minAccelerationMetersPerSecondSq = -Double.MAX_VALUE;
- public double maxAccelerationMetersPerSecondSq = +Double.MAX_VALUE;
-
- /**
- * Constructs a MinMax.
- *
- * @param minAccelerationMetersPerSecondSq The minimum acceleration.
- * @param maxAccelerationMetersPerSecondSq The maximum acceleration.
- */
- public MinMax(double minAccelerationMetersPerSecondSq,
- double maxAccelerationMetersPerSecondSq) {
- this.minAccelerationMetersPerSecondSq = minAccelerationMetersPerSecondSq;
- this.maxAccelerationMetersPerSecondSq = maxAccelerationMetersPerSecondSq;
- }
-
- /**
- * Constructs a MinMax with default values.
- */
- public MinMax() {
- }
- }
-}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpiutil/math/MatBuilder.java b/wpimath/src/main/java/edu/wpi/first/wpiutil/math/MatBuilder.java
deleted file mode 100644
index d8b20a3..0000000
--- a/wpimath/src/main/java/edu/wpi/first/wpiutil/math/MatBuilder.java
+++ /dev/null
@@ -1,50 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpiutil.math;
-
-import java.util.Objects;
-
-import org.ejml.simple.SimpleMatrix;
-
-/**
- * A class for constructing arbitrary RxC matrices.
- *
- * @param <R> The number of rows of the desired matrix.
- * @param <C> The number of columns of the desired matrix.
- */
-public class MatBuilder<R extends Num, C extends Num> {
- final Nat<R> m_rows;
- final Nat<C> m_cols;
-
- /**
- * Fills the matrix with the given data, encoded in row major form.
- * (The matrix is filled row by row, left to right with the given data).
- *
- * @param data The data to fill the matrix with.
- * @return The constructed matrix.
- */
- @SuppressWarnings("LineLength")
- public final Matrix<R, C> fill(double... data) {
- if (Objects.requireNonNull(data).length != this.m_rows.getNum() * this.m_cols.getNum()) {
- throw new IllegalArgumentException("Invalid matrix data provided. Wanted " + this.m_rows.getNum()
- + " x " + this.m_cols.getNum() + " matrix, but got " + data.length + " elements");
- } else {
- return new Matrix<>(new SimpleMatrix(this.m_rows.getNum(), this.m_cols.getNum(), true, data));
- }
- }
-
- /**
- * Creates a new {@link MatBuilder} with the given dimensions.
- * @param rows The number of rows of the matrix.
- * @param cols The number of columns of the matrix.
- */
- public MatBuilder(Nat<R> rows, Nat<C> cols) {
- this.m_rows = Objects.requireNonNull(rows);
- this.m_cols = Objects.requireNonNull(cols);
- }
-}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpiutil/math/MathUtil.java b/wpimath/src/main/java/edu/wpi/first/wpiutil/math/MathUtil.java
deleted file mode 100644
index bd03f6b..0000000
--- a/wpimath/src/main/java/edu/wpi/first/wpiutil/math/MathUtil.java
+++ /dev/null
@@ -1,56 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpiutil.math;
-
-public final class MathUtil {
- private MathUtil() {
- throw new AssertionError("utility class");
- }
-
- /**
- * Returns value clamped between low and high boundaries.
- *
- * @param value Value to clamp.
- * @param low The lower boundary to which to clamp value.
- * @param high The higher boundary to which to clamp value.
- */
- public static int clamp(int value, int low, int high) {
- return Math.max(low, Math.min(value, high));
- }
-
- /**
- * Returns value clamped between low and high boundaries.
- *
- * @param value Value to clamp.
- * @param low The lower boundary to which to clamp value.
- * @param high The higher boundary to which to clamp value.
- */
- public static double clamp(double value, double low, double high) {
- return Math.max(low, Math.min(value, high));
- }
-
- /**
- * Constrains theta to within the range (-pi, pi].
- *
- * @param theta The angle to normalize.
- * @return The normalized angle.
- */
- @SuppressWarnings("LocalVariableName")
- public static double normalizeAngle(double theta) {
- // Constraint theta to within (-3pi, pi)
- int nPiPos = (int) ((theta + Math.PI) / 2.0 / Math.PI);
- theta -= nPiPos * 2.0 * Math.PI;
-
- // Cut off the bottom half of the above range to constrain within
- // (-pi, pi]
- int nPiNeg = (int) ((theta - Math.PI) / 2.0 / Math.PI);
- theta -= nPiNeg * 2.0 * Math.PI;
-
- return theta;
- }
-}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpiutil/math/Num.java b/wpimath/src/main/java/edu/wpi/first/wpiutil/math/Num.java
deleted file mode 100644
index 0b8a81f..0000000
--- a/wpimath/src/main/java/edu/wpi/first/wpiutil/math/Num.java
+++ /dev/null
@@ -1,20 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpiutil.math;
-
-/**
- * A number expressed as a java class.
- */
-public abstract class Num {
- /**
- * The number this is backing.
- *
- * @return The number represented by this class.
- */
- public abstract int getNum();
-}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpiutil/math/Pair.java b/wpimath/src/main/java/edu/wpi/first/wpiutil/math/Pair.java
deleted file mode 100644
index eafbcba..0000000
--- a/wpimath/src/main/java/edu/wpi/first/wpiutil/math/Pair.java
+++ /dev/null
@@ -1,31 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpiutil.math;
-
-public class Pair<A, B> {
- private final A m_first;
- private final B m_second;
-
- public Pair(A first, B second) {
- m_first = first;
- m_second = second;
- }
-
- public A getFirst() {
- return m_first;
- }
-
- public B getSecond() {
- return m_second;
- }
-
- @SuppressWarnings("ParameterName")
- public static <A, B> Pair<A, B> of(A a, B b) {
- return new Pair<>(a, b);
- }
-}
diff --git a/wpimath/src/main/java/edu/wpi/first/wpiutil/math/Vector.java b/wpimath/src/main/java/edu/wpi/first/wpiutil/math/Vector.java
deleted file mode 100644
index 04f46ba..0000000
--- a/wpimath/src/main/java/edu/wpi/first/wpiutil/math/Vector.java
+++ /dev/null
@@ -1,70 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpiutil.math;
-
-import org.ejml.simple.SimpleMatrix;
-
-import edu.wpi.first.wpiutil.math.numbers.N1;
-
-/**
- * A shape-safe wrapper over Efficient Java Matrix Library (EJML) matrices.
- *
- * <p>This class is intended to be used alongside the state space library.
- *
- * @param <R> The number of rows in this matrix.
- */
-public class Vector<R extends Num> extends Matrix<R, N1> {
-
- /**
- * Constructs an empty zero vector of the given dimensions.
- *
- * @param rows The number of rows of the vector.
- */
- public Vector(Nat<R> rows) {
- super(rows, Nat.N1());
- }
-
- /**
- * Constructs a new {@link Vector} with the given storage.
- * Caller should make sure that the provided generic bounds match
- * the shape of the provided {@link Vector}.
- *
- * <p>NOTE:It is not recommended to use this constructor unless the
- * {@link SimpleMatrix} API is absolutely necessary due to the desired
- * function not being accessible through the {@link Vector} wrapper.
- *
- * @param storage The {@link SimpleMatrix} to back this vector.
- */
- public Vector(SimpleMatrix storage) {
- super(storage);
- }
-
- /**
- * Constructs a new vector with the storage of the supplied matrix.
- *
- * @param other The {@link Vector} to copy the storage of.
- */
- public Vector(Matrix<R, N1> other) {
- super(other);
- }
-
- @Override
- public Vector<R> times(double value) {
- return new Vector<>(this.m_storage.scale(value));
- }
-
- @Override
- public Vector<R> div(int value) {
- return new Vector<>(this.m_storage.divide(value));
- }
-
- @Override
- public Vector<R> div(double value) {
- return new Vector<>(this.m_storage.divide(value));
- }
-}
diff --git a/wpimath/src/main/native/cpp/MathShared.cpp b/wpimath/src/main/native/cpp/MathShared.cpp
index 8a64f2e..5252e87 100644
--- a/wpimath/src/main/native/cpp/MathShared.cpp
+++ b/wpimath/src/main/native/cpp/MathShared.cpp
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include "wpimath/MathShared.h"
@@ -14,7 +11,9 @@
namespace {
class DefaultMathShared : public MathShared {
public:
- void ReportError(const wpi::Twine& error) override {}
+ void ReportErrorV(fmt::string_view format, fmt::format_args args) override {}
+ void ReportWarningV(fmt::string_view format, fmt::format_args args) override {
+ }
void ReportUsage(MathUsageId id, int count) override {}
};
} // namespace
@@ -24,7 +23,9 @@
MathShared& MathSharedStore::GetMathShared() {
std::scoped_lock lock(setLock);
- if (!mathShared) mathShared = std::make_unique<DefaultMathShared>();
+ if (!mathShared) {
+ mathShared = std::make_unique<DefaultMathShared>();
+ }
return *mathShared;
}
diff --git a/wpimath/src/main/native/cpp/MathUtil.cpp b/wpimath/src/main/native/cpp/MathUtil.cpp
new file mode 100644
index 0000000..19cd214
--- /dev/null
+++ b/wpimath/src/main/native/cpp/MathUtil.cpp
@@ -0,0 +1,23 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+#include "frc/MathUtil.h"
+
+#include <cmath>
+
+namespace frc {
+
+double ApplyDeadband(double value, double deadband) {
+  if (std::abs(value) > deadband) {
+    // Move the value toward zero by the deadband, then rescale by
+    // (1 - deadband) so that, for deadband < 1, +/-1 still maps to +/-1.
+    const double shifted = value > 0.0 ? value - deadband : value + deadband;
+    return shifted / (1.0 - deadband);
+  }
+
+  // |value| <= deadband (or value is NaN): report no output.
+  return 0.0;
+}
+
+} // namespace frc
diff --git a/wpimath/src/main/native/cpp/StateSpaceUtil.cpp b/wpimath/src/main/native/cpp/StateSpaceUtil.cpp
index d828f30..8f1145d 100644
--- a/wpimath/src/main/native/cpp/StateSpaceUtil.cpp
+++ b/wpimath/src/main/native/cpp/StateSpaceUtil.cpp
@@ -1,14 +1,23 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include "frc/StateSpaceUtil.h"
namespace frc {
+Eigen::Vector<double, 3> PoseTo3dVector(const Pose2d& pose) {
+ return Eigen::Vector<double, 3>{pose.Translation().X().value(),
+ pose.Translation().Y().value(),
+ pose.Rotation().Radians().value()};
+}
+
+Eigen::Vector<double, 4> PoseTo4dVector(const Pose2d& pose) {
+ return Eigen::Vector<double, 4>{pose.Translation().X().value(),
+ pose.Translation().Y().value(),
+ pose.Rotation().Cos(), pose.Rotation().Sin()};
+}
+
template <>
bool IsStabilizable<1, 1>(const Eigen::Matrix<double, 1, 1>& A,
const Eigen::Matrix<double, 1, 1>& B) {
@@ -21,9 +30,9 @@
return detail::IsStabilizableImpl<2, 1>(A, B);
}
-Eigen::Matrix<double, 3, 1> PoseToVector(const Pose2d& pose) {
- return frc::MakeMatrix<3, 1>(pose.X().to<double>(), pose.Y().to<double>(),
- pose.Rotation().Radians().to<double>());
+Eigen::Vector<double, 3> PoseToVector(const Pose2d& pose) {
+ return Eigen::Vector<double, 3>{pose.X().value(), pose.Y().value(),
+ pose.Rotation().Radians().value()};
}
} // namespace frc
diff --git a/wpimath/src/main/native/cpp/controller/HolonomicDriveController.cpp b/wpimath/src/main/native/cpp/controller/HolonomicDriveController.cpp
new file mode 100644
index 0000000..23b22cd
--- /dev/null
+++ b/wpimath/src/main/native/cpp/controller/HolonomicDriveController.cpp
@@ -0,0 +1,78 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+#include "frc/controller/HolonomicDriveController.h"
+
+#include <utility>
+
+#include "units/angular_velocity.h"
+
+using namespace frc;
+
+HolonomicDriveController::HolonomicDriveController(
+ frc2::PIDController xController, frc2::PIDController yController,
+ ProfiledPIDController<units::radian> thetaController)
+ : m_xController(std::move(xController)),
+ m_yController(std::move(yController)),
+ m_thetaController(std::move(thetaController)) {}
+
+bool HolonomicDriveController::AtReference() const {
+ const auto& eTranslate = m_poseError.Translation();
+ const auto& eRotate = m_rotationError;
+ const auto& tolTranslate = m_poseTolerance.Translation();
+ const auto& tolRotate = m_poseTolerance.Rotation();
+ return units::math::abs(eTranslate.X()) < tolTranslate.X() &&
+ units::math::abs(eTranslate.Y()) < tolTranslate.Y() &&
+ units::math::abs(eRotate.Radians()) < tolRotate.Radians();
+}
+
+void HolonomicDriveController::SetTolerance(const Pose2d& tolerance) {
+ m_poseTolerance = tolerance;
+}
+
+ChassisSpeeds HolonomicDriveController::Calculate(
+ const Pose2d& currentPose, const Pose2d& poseRef,
+ units::meters_per_second_t linearVelocityRef, const Rotation2d& angleRef) {
+ // If this is the first run, then we need to reset the theta controller to the
+ // current pose's heading.
+ if (m_firstRun) {
+ m_thetaController.Reset(currentPose.Rotation().Radians());
+ m_firstRun = false;
+ }
+
+ // Calculate feedforward velocities (field-relative)
+ auto xFF = linearVelocityRef * poseRef.Rotation().Cos();
+ auto yFF = linearVelocityRef * poseRef.Rotation().Sin();
+ auto thetaFF = units::radians_per_second_t(m_thetaController.Calculate(
+ currentPose.Rotation().Radians(), angleRef.Radians()));
+
+ m_poseError = poseRef.RelativeTo(currentPose);
+ m_rotationError = angleRef - currentPose.Rotation();
+
+ if (!m_enabled) {
+ return ChassisSpeeds::FromFieldRelativeSpeeds(xFF, yFF, thetaFF,
+ currentPose.Rotation());
+ }
+
+ // Calculate feedback velocities (based on position error).
+ auto xFeedback = units::meters_per_second_t(
+ m_xController.Calculate(currentPose.X().value(), poseRef.X().value()));
+ auto yFeedback = units::meters_per_second_t(
+ m_yController.Calculate(currentPose.Y().value(), poseRef.Y().value()));
+
+ // Return next output.
+ return ChassisSpeeds::FromFieldRelativeSpeeds(
+ xFF + xFeedback, yFF + yFeedback, thetaFF, currentPose.Rotation());
+}
+
+ChassisSpeeds HolonomicDriveController::Calculate(
+ const Pose2d& currentPose, const Trajectory::State& desiredState,
+ const Rotation2d& angleRef) {
+ return Calculate(currentPose, desiredState.pose, desiredState.velocity,
+ angleRef);
+}
+
+void HolonomicDriveController::SetEnabled(bool enabled) {
+ m_enabled = enabled;
+}
diff --git a/wpimath/src/main/native/cpp/controller/LinearQuadraticRegulator.cpp b/wpimath/src/main/native/cpp/controller/LinearQuadraticRegulator.cpp
index ae58440..4d2fbe9 100644
--- a/wpimath/src/main/native/cpp/controller/LinearQuadraticRegulator.cpp
+++ b/wpimath/src/main/native/cpp/controller/LinearQuadraticRegulator.cpp
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include "frc/controller/LinearQuadraticRegulator.h"
@@ -11,7 +8,7 @@
LinearQuadraticRegulator<1, 1>::LinearQuadraticRegulator(
const Eigen::Matrix<double, 1, 1>& A, const Eigen::Matrix<double, 1, 1>& B,
- const std::array<double, 1>& Qelems, const std::array<double, 1>& Relems,
+ const wpi::array<double, 1>& Qelems, const wpi::array<double, 1>& Relems,
units::second_t dt)
: LinearQuadraticRegulator(A, B, MakeCostMatrix(Qelems),
MakeCostMatrix(Relems), dt) {}
@@ -22,9 +19,15 @@
units::second_t dt)
: detail::LinearQuadraticRegulatorImpl<1, 1>(A, B, Q, R, dt) {}
+LinearQuadraticRegulator<1, 1>::LinearQuadraticRegulator(
+ const Eigen::Matrix<double, 1, 1>& A, const Eigen::Matrix<double, 1, 1>& B,
+ const Eigen::Matrix<double, 1, 1>& Q, const Eigen::Matrix<double, 1, 1>& R,
+ const Eigen::Matrix<double, 1, 1>& N, units::second_t dt)
+ : detail::LinearQuadraticRegulatorImpl<1, 1>(A, B, Q, R, N, dt) {}
+
LinearQuadraticRegulator<2, 1>::LinearQuadraticRegulator(
const Eigen::Matrix<double, 2, 2>& A, const Eigen::Matrix<double, 2, 1>& B,
- const std::array<double, 2>& Qelems, const std::array<double, 1>& Relems,
+ const wpi::array<double, 2>& Qelems, const wpi::array<double, 1>& Relems,
units::second_t dt)
: LinearQuadraticRegulator(A, B, MakeCostMatrix(Qelems),
MakeCostMatrix(Relems), dt) {}
@@ -35,4 +38,29 @@
units::second_t dt)
: detail::LinearQuadraticRegulatorImpl<2, 1>(A, B, Q, R, dt) {}
+LinearQuadraticRegulator<2, 1>::LinearQuadraticRegulator(
+ const Eigen::Matrix<double, 2, 2>& A, const Eigen::Matrix<double, 2, 1>& B,
+ const Eigen::Matrix<double, 2, 2>& Q, const Eigen::Matrix<double, 1, 1>& R,
+ const Eigen::Matrix<double, 2, 1>& N, units::second_t dt)
+ : detail::LinearQuadraticRegulatorImpl<2, 1>(A, B, Q, R, N, dt) {}
+
+LinearQuadraticRegulator<2, 2>::LinearQuadraticRegulator(
+ const Eigen::Matrix<double, 2, 2>& A, const Eigen::Matrix<double, 2, 2>& B,
+ const wpi::array<double, 2>& Qelems, const wpi::array<double, 2>& Relems,
+ units::second_t dt)
+ : LinearQuadraticRegulator(A, B, MakeCostMatrix(Qelems),
+ MakeCostMatrix(Relems), dt) {}
+
+LinearQuadraticRegulator<2, 2>::LinearQuadraticRegulator(
+ const Eigen::Matrix<double, 2, 2>& A, const Eigen::Matrix<double, 2, 2>& B,
+ const Eigen::Matrix<double, 2, 2>& Q, const Eigen::Matrix<double, 2, 2>& R,
+ units::second_t dt)
+ : detail::LinearQuadraticRegulatorImpl<2, 2>(A, B, Q, R, dt) {}
+
+LinearQuadraticRegulator<2, 2>::LinearQuadraticRegulator(
+ const Eigen::Matrix<double, 2, 2>& A, const Eigen::Matrix<double, 2, 2>& B,
+ const Eigen::Matrix<double, 2, 2>& Q, const Eigen::Matrix<double, 2, 2>& R,
+ const Eigen::Matrix<double, 2, 2>& N, units::second_t dt)
+ : detail::LinearQuadraticRegulatorImpl<2, 2>(A, B, Q, R, N, dt) {}
+
} // namespace frc
diff --git a/wpimath/src/main/native/cpp/controller/PIDController.cpp b/wpimath/src/main/native/cpp/controller/PIDController.cpp
new file mode 100644
index 0000000..34af2aa
--- /dev/null
+++ b/wpimath/src/main/native/cpp/controller/PIDController.cpp
@@ -0,0 +1,175 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+#include "frc/controller/PIDController.h"
+
+#include <algorithm>
+#include <cmath>
+
+#include <wpi/sendable/SendableBuilder.h>
+#include <wpi/sendable/SendableRegistry.h>
+
+#include "frc/MathUtil.h"
+#include "wpimath/MathShared.h"
+
+using namespace frc2;
+
+PIDController::PIDController(double Kp, double Ki, double Kd,
+ units::second_t period)
+ : m_Kp(Kp), m_Ki(Ki), m_Kd(Kd), m_period(period) {
+ if (period <= 0_s) {
+ wpi::math::MathSharedStore::ReportError(
+ "Controller period must be a non-zero positive number, got {}!",
+ period.value());
+ m_period = 20_ms;
+ wpi::math::MathSharedStore::ReportWarning(
+ "{}", "Controller period defaulted to 20ms.");
+ }
+ static int instances = 0;
+ instances++;
+
+ wpi::math::MathSharedStore::ReportUsage(
+ wpi::math::MathUsageId::kController_PIDController2, instances);
+ wpi::SendableRegistry::Add(this, "PIDController", instances);
+}
+
+void PIDController::SetPID(double Kp, double Ki, double Kd) {
+ m_Kp = Kp;
+ m_Ki = Ki;
+ m_Kd = Kd;
+}
+
+void PIDController::SetP(double Kp) {
+ m_Kp = Kp;
+}
+
+void PIDController::SetI(double Ki) {
+ m_Ki = Ki;
+}
+
+void PIDController::SetD(double Kd) {
+ m_Kd = Kd;
+}
+
+double PIDController::GetP() const {
+ return m_Kp;
+}
+
+double PIDController::GetI() const {
+ return m_Ki;
+}
+
+double PIDController::GetD() const {
+ return m_Kd;
+}
+
+units::second_t PIDController::GetPeriod() const {
+ return units::second_t(m_period);
+}
+
+void PIDController::SetSetpoint(double setpoint) {
+ m_setpoint = setpoint;
+}
+
+double PIDController::GetSetpoint() const {
+ return m_setpoint;
+}
+
+bool PIDController::AtSetpoint() const {
+ double positionError;
+ if (m_continuous) {
+ double errorBound = (m_maximumInput - m_minimumInput) / 2.0;
+ positionError =
+ frc::InputModulus(m_setpoint - m_measurement, -errorBound, errorBound);
+ } else {
+ positionError = m_setpoint - m_measurement;
+ }
+
+ double velocityError = (positionError - m_prevError) / m_period.value();
+
+ return std::abs(positionError) < m_positionTolerance &&
+ std::abs(velocityError) < m_velocityTolerance;
+}
+
+void PIDController::EnableContinuousInput(double minimumInput,
+ double maximumInput) {
+ m_continuous = true;
+ m_minimumInput = minimumInput;
+ m_maximumInput = maximumInput;
+}
+
+void PIDController::DisableContinuousInput() {
+ m_continuous = false;
+}
+
+bool PIDController::IsContinuousInputEnabled() const {
+ return m_continuous;
+}
+
+void PIDController::SetIntegratorRange(double minimumIntegral,
+ double maximumIntegral) {
+ m_minimumIntegral = minimumIntegral;
+ m_maximumIntegral = maximumIntegral;
+}
+
+void PIDController::SetTolerance(double positionTolerance,
+ double velocityTolerance) {
+ m_positionTolerance = positionTolerance;
+ m_velocityTolerance = velocityTolerance;
+}
+
+double PIDController::GetPositionError() const {
+ return m_positionError;
+}
+
+double PIDController::GetVelocityError() const {
+ return m_velocityError;
+}
+
+double PIDController::Calculate(double measurement) {
+  m_measurement = measurement;
+  m_prevError = m_positionError;
+  if (m_continuous) {
+    double errorBound = (m_maximumInput - m_minimumInput) / 2.0;
+    m_positionError =
+        frc::InputModulus(m_setpoint - m_measurement, -errorBound, errorBound);
+  } else {
+    m_positionError = m_setpoint - m_measurement;
+  }
+  m_velocityError = (m_positionError - m_prevError) / m_period.value();
+
+  if (m_Ki != 0) {
+    // Dividing by Ki < 0 swaps the bounds, but std::clamp requires lo <= hi.
+    const auto [lo, hi] =
+        std::minmax({m_minimumIntegral / m_Ki, m_maximumIntegral / m_Ki});
+    m_totalError =
+        std::clamp(m_totalError + m_positionError * m_period.value(), lo, hi);
+  }
+
+  return m_Kp * m_positionError + m_Ki * m_totalError + m_Kd * m_velocityError;
+}
+
+double PIDController::Calculate(double measurement, double setpoint) {
+ // Set setpoint to provided value
+ SetSetpoint(setpoint);
+ return Calculate(measurement);
+}
+
+void PIDController::Reset() {
+ m_prevError = 0;
+ m_totalError = 0;
+}
+
+void PIDController::InitSendable(wpi::SendableBuilder& builder) {
+ builder.SetSmartDashboardType("PIDController");
+ builder.AddDoubleProperty(
+ "p", [this] { return GetP(); }, [this](double value) { SetP(value); });
+ builder.AddDoubleProperty(
+ "i", [this] { return GetI(); }, [this](double value) { SetI(value); });
+ builder.AddDoubleProperty(
+ "d", [this] { return GetD(); }, [this](double value) { SetD(value); });
+ builder.AddDoubleProperty(
+ "setpoint", [this] { return GetSetpoint(); },
+ [this](double value) { SetSetpoint(value); });
+}
diff --git a/wpimath/src/main/native/cpp/controller/ProfiledPIDController.cpp b/wpimath/src/main/native/cpp/controller/ProfiledPIDController.cpp
new file mode 100644
index 0000000..fa58427
--- /dev/null
+++ b/wpimath/src/main/native/cpp/controller/ProfiledPIDController.cpp
@@ -0,0 +1,12 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+#include "frc/controller/ProfiledPIDController.h"
+
+void frc::detail::ReportProfiledPIDController() {
+ static int instances = 0;
+ ++instances;
+ wpi::math::MathSharedStore::ReportUsage(
+ wpi::math::MathUsageId::kController_ProfiledPIDController, instances);
+}
diff --git a/wpimath/src/main/native/cpp/controller/RamseteController.cpp b/wpimath/src/main/native/cpp/controller/RamseteController.cpp
new file mode 100644
index 0000000..8aa16d8
--- /dev/null
+++ b/wpimath/src/main/native/cpp/controller/RamseteController.cpp
@@ -0,0 +1,77 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+#include "frc/controller/RamseteController.h"
+
+#include <cmath>
+
+#include "units/math.h"
+
+using namespace frc;
+
+/**
+ * Computes sinc(x) = sin(x) / x.
+ *
+ * For |x| below 1e-9 the second-order Taylor expansion 1 - x^2 / 6 is used
+ * instead, which avoids dividing by a (near-)zero x.
+ *
+ * @param x Value of which to take sinc(x).
+ */
+static double Sinc(double x) {
+  const bool nearZero = std::abs(x) < 1e-9;
+  return nearZero ? 1.0 - 1.0 / 6.0 * x * x : std::sin(x) / x;
+}
+
+RamseteController::RamseteController(double b, double zeta)
+ : m_b{b}, m_zeta{zeta} {}
+
+bool RamseteController::AtReference() const {
+ const auto& eTranslate = m_poseError.Translation();
+ const auto& eRotate = m_poseError.Rotation();
+ const auto& tolTranslate = m_poseTolerance.Translation();
+ const auto& tolRotate = m_poseTolerance.Rotation();
+ return units::math::abs(eTranslate.X()) < tolTranslate.X() &&
+ units::math::abs(eTranslate.Y()) < tolTranslate.Y() &&
+ units::math::abs(eRotate.Radians()) < tolRotate.Radians();
+}
+
+void RamseteController::SetTolerance(const Pose2d& poseTolerance) {
+ m_poseTolerance = poseTolerance;
+}
+
+ChassisSpeeds RamseteController::Calculate(
+    const Pose2d& currentPose, const Pose2d& poseRef,
+    units::meters_per_second_t linearVelocityRef,
+    units::radians_per_second_t angularVelocityRef) {
+  if (!m_enabled) {
+    return ChassisSpeeds{linearVelocityRef, 0_mps, angularVelocityRef};
+  }
+
+  m_poseError = poseRef.RelativeTo(currentPose);
+
+  // Aliases for equation readability
+  const double eX = m_poseError.X().value();
+  const double eY = m_poseError.Y().value();
+  const double eTheta = m_poseError.Rotation().Radians().value();
+  const double vRef = linearVelocityRef.value();
+  const double omegaRef = angularVelocityRef.value();
+  // k = 2 * zeta * sqrt(omegaRef^2 + b * vRef^2); squares as plain products.
+  const double k = 2.0 * m_zeta *
+                   std::sqrt(omegaRef * omegaRef + m_b * (vRef * vRef));
+
+  units::meters_per_second_t v{vRef * m_poseError.Rotation().Cos() + k * eX};
+  units::radians_per_second_t omega{omegaRef + k * eTheta +
+                                    m_b * vRef * Sinc(eTheta) * eY};
+  return ChassisSpeeds{v, 0_mps, omega};
+}
+
+ChassisSpeeds RamseteController::Calculate(
+ const Pose2d& currentPose, const Trajectory::State& desiredState) {
+ return Calculate(currentPose, desiredState.pose, desiredState.velocity,
+ desiredState.velocity * desiredState.curvature);
+}
+
+void RamseteController::SetEnabled(bool enabled) {
+ m_enabled = enabled;
+}
diff --git a/wpimath/src/main/native/cpp/drake/math/discrete_algebraic_riccati_equation.cpp b/wpimath/src/main/native/cpp/drake/math/discrete_algebraic_riccati_equation.cpp
index e80cadc..20ea2b7 100644
--- a/wpimath/src/main/native/cpp/drake/math/discrete_algebraic_riccati_equation.cpp
+++ b/wpimath/src/main/native/cpp/drake/math/discrete_algebraic_riccati_equation.cpp
@@ -1,14 +1,11 @@
#include "drake/math/discrete_algebraic_riccati_equation.h"
-#include "drake/common/drake_assert.h"
-#include "drake/common/drake_throw.h"
-#include "drake/common/is_approx_equal_abstol.h"
-
#include <Eigen/Eigenvalues>
#include <Eigen/QR>
-// This code has https://github.com/RobotLocomotion/drake/pull/11118 applied to
-// fix an infinite loop in reorder_eigen().
+#include "drake/common/drake_assert.h"
+#include "drake/common/drake_throw.h"
+#include "drake/common/is_approx_equal_abstol.h"
namespace drake {
namespace math {
@@ -385,12 +382,11 @@
* DiscreteAlgebraicRiccatiEquation function
* computes the unique stabilizing solution X to the discrete-time algebraic
* Riccati equation:
- * \f[
- * A'XA - X - A'XB(B'XB+R)^{-1}B'XA + Q = 0
- * \f]
*
- * @throws std::runtime_error if Q is not positive semi-definite.
- * @throws std::runtime_error if R is not positive definite.
+ * AᵀXA − X − AᵀXB(BᵀXB + R)⁻¹BᵀXA + Q = 0
+ *
+ * @throws std::exception if Q is not positive semi-definite.
+ * @throws std::exception if R is not positive definite.
*
* Based on the Schur Vector approach outlined in this paper:
* "On the Numerical Solution of the Discrete-Time Algebraic Riccati Equation"
@@ -399,9 +395,9 @@
*
* Note: When, for example, n = 100, m = 80, and entries of A, B, Q_half,
* R_half are sampled from standard normal distributions, where
- * Q = Q_half'*Q_half and similar for R, the absolute error of the solution
- * is 10^{-6}, while the absolute error of the solution computed by Matlab is
- * 10^{-8}.
+ * Q = Q_halfᵀ Q_half and similar for R, the absolute error of the solution
+ * is 10⁻⁶, while the absolute error of the solution computed by Matlab is
+ * 10⁻⁸.
*
* TODO(weiqiao.han): I may overwrite the RealQZ function to improve the
* accuracy, together with more thorough tests.
@@ -459,5 +455,21 @@
return X;
}
+Eigen::MatrixXd DiscreteAlgebraicRiccatiEquation(
+    const Eigen::Ref<const Eigen::MatrixXd>& A,
+    const Eigen::Ref<const Eigen::MatrixXd>& B,
+    const Eigen::Ref<const Eigen::MatrixXd>& Q,
+    const Eigen::Ref<const Eigen::MatrixXd>& R,
+    const Eigen::Ref<const Eigen::MatrixXd>& N) {
+  DRAKE_DEMAND(N.rows() == B.rows() && N.cols() == B.cols());
+  // Change of variables so the DARE with cross-term N fits the Q/R-only form:
+  // A₂ = A − BR⁻¹Nᵀ and Q₂ = Q − NR⁻¹Nᵀ. Factor R and solve R⁻¹Nᵀ only once.
+  const Eigen::LLT<Eigen::MatrixXd> R_llt{R};
+  const Eigen::MatrixXd R_inv_Nt = R_llt.solve(N.transpose());
+  Eigen::MatrixXd A2 = A - B * R_inv_Nt;
+  Eigen::MatrixXd Q2 = Q - N * R_inv_Nt;
+  return DiscreteAlgebraicRiccatiEquation(A2, B, Q2, R);
+}
+
} // namespace math
} // namespace drake
diff --git a/wpimath/src/main/native/cpp/estimator/DifferentialDrivePoseEstimator.cpp b/wpimath/src/main/native/cpp/estimator/DifferentialDrivePoseEstimator.cpp
new file mode 100644
index 0000000..c5ed7a1
--- /dev/null
+++ b/wpimath/src/main/native/cpp/estimator/DifferentialDrivePoseEstimator.cpp
@@ -0,0 +1,145 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+#include "frc/estimator/DifferentialDrivePoseEstimator.h"
+
+#include <wpi/timestamp.h>
+
+#include "frc/StateSpaceUtil.h"
+#include "frc/estimator/AngleStatistics.h"
+
+using namespace frc;
+
+DifferentialDrivePoseEstimator::DifferentialDrivePoseEstimator(
+ const Rotation2d& gyroAngle, const Pose2d& initialPose,
+ const wpi::array<double, 5>& stateStdDevs,
+ const wpi::array<double, 3>& localMeasurementStdDevs,
+ const wpi::array<double, 3>& visionMeasurmentStdDevs,
+ units::second_t nominalDt)
+ : m_observer(
+ &DifferentialDrivePoseEstimator::F,
+ [](const Eigen::Vector<double, 5>& x,
+ const Eigen::Vector<double, 3>& u) {
+ return Eigen::Vector<double, 3>{x(3, 0), x(4, 0), x(2, 0)};  // y = [left dist, right dist, heading]
+ },
+ stateStdDevs, localMeasurementStdDevs, frc::AngleMean<5, 5>(2),
+ frc::AngleMean<3, 5>(2), frc::AngleResidual<5>(2),
+ frc::AngleResidual<3>(2), frc::AngleAdd<5>(2), nominalDt),
+ m_nominalDt(nominalDt) {
+ SetVisionMeasurementStdDevs(visionMeasurmentStdDevs);  // sets m_visionContR
+ // Vision correction: measurement model h(x) = [x, y, heading] (states 0-2).
+ // NOTE(review): [&] captures this; copying/moving the estimator dangles it.
+ m_visionCorrect = [&](const Eigen::Vector<double, 3>& u,
+ const Eigen::Vector<double, 3>& y) {
+ m_observer.Correct<3>(
+ u, y,
+ [](const Eigen::Vector<double, 5>& x, const Eigen::Vector<double, 3>&) {
+ return x.block<3, 1>(0, 0);
+ },
+ m_visionContR, frc::AngleMean<3, 5>(2), frc::AngleResidual<3>(2),
+ frc::AngleResidual<5>(2), frc::AngleAdd<5>(2));
+ };
+ // Offset added to raw gyro angles so they match the pose heading.
+ m_gyroOffset = initialPose.Rotation() - gyroAngle;
+ m_previousAngle = initialPose.Rotation();
+ m_observer.SetXhat(FillStateVector(initialPose, 0_m, 0_m));
+}
+
+void DifferentialDrivePoseEstimator::SetVisionMeasurementStdDevs(
+    const wpi::array<double, 3>& visionMeasurmentStdDevs) {
+  // Vision measurement noise covariance R, built from the std-devs.
+  m_visionContR = frc::MakeCovMatrix(visionMeasurmentStdDevs);
+}
+
+void DifferentialDrivePoseEstimator::ResetPosition(
+    const Pose2d& pose, const Rotation2d& gyroAngle) {
+  // Clear filter state/covariance and the latency-compensation history.
+  m_observer.Reset();
+  m_latencyCompensator.Reset();
+  m_observer.SetXhat(FillStateVector(pose, 0_m, 0_m));
+  // Offset that maps raw gyro readings onto the estimated heading.
+  const Rotation2d estimatedHeading = GetEstimatedPosition().Rotation();
+  m_gyroOffset = estimatedHeading - gyroAngle;
+  m_previousAngle = pose.Rotation();
+}
+
+Pose2d DifferentialDrivePoseEstimator::GetEstimatedPosition() const {
+  const units::meter_t x{m_observer.Xhat(0)};  // states 0..2: x, y, heading
+  const units::meter_t y{m_observer.Xhat(1)};
+  return Pose2d{x, y, Rotation2d{units::radian_t{m_observer.Xhat(2)}}};
+}
+
+void DifferentialDrivePoseEstimator::AddVisionMeasurement(
+    const Pose2d& visionRobotPose, units::second_t timestamp) {
+  m_latencyCompensator.ApplyPastGlobalMeasurement<3>(  // 3 rows: x, y, θ
+      &m_observer, m_nominalDt, PoseTo3dVector(visionRobotPose),
+      m_visionCorrect, timestamp);
+}
+
+Pose2d DifferentialDrivePoseEstimator::Update(
+    const Rotation2d& gyroAngle,
+    const DifferentialDriveWheelSpeeds& wheelSpeeds,
+    units::meter_t leftDistance, units::meter_t rightDistance) {
+  return UpdateWithTime(units::second_t{units::microsecond_t(wpi::Now())},
+                        gyroAngle, wheelSpeeds, leftDistance, rightDistance);
+}
+
+Pose2d DifferentialDrivePoseEstimator::UpdateWithTime(
+    units::second_t currentTime, const Rotation2d& gyroAngle,
+    const DifferentialDriveWheelSpeeds& wheelSpeeds,
+    units::meter_t leftDistance, units::meter_t rightDistance) {
+  // m_prevTime < 0 s means no previous update; fall back to nominal dt.
+  auto dt = m_prevTime >= 0_s ? currentTime - m_prevTime : m_nominalDt;
+  m_prevTime = currentTime;
+  // m_previousAngle stores the offset-corrected heading, so difference that;
+  // using raw gyroAngle here biased omega by -m_gyroOffset / dt.
+  auto angle = gyroAngle + m_gyroOffset;
+  auto omega = (angle - m_previousAngle).Radians() / dt;
+  m_previousAngle = angle;
+  // Input u = [average wheel speed, 0, omega].
+  auto u = Eigen::Vector<double, 3>{
+      (wheelSpeeds.left + wheelSpeeds.right).value() / 2.0, 0.0, omega.value()};
+  // Local measurement y = [left distance, right distance, heading].
+  auto localY = Eigen::Vector<double, 3>{
+      leftDistance.value(), rightDistance.value(), angle.Radians().value()};
+  // Record the state for latency compensation, then step the filter.
+  m_latencyCompensator.AddObserverState(m_observer, u, localY, currentTime);
+  m_observer.Predict(u, dt);
+  m_observer.Correct(u, localY);
+  return GetEstimatedPosition();
+}
+
+Eigen::Vector<double, 5> DifferentialDrivePoseEstimator::F(
+ const Eigen::Vector<double, 5>& x, const Eigen::Vector<double, 3>& u) {
+ // Returns the rate of change only; the Runge-Kutta integrator adds x itself.
+ // Rotate [u0, u1] into the field frame; rows 2-4 copy u2, u0, u1 as-is.
+ auto& theta = x(2);
+ Eigen::Matrix<double, 5, 5> toFieldRotation{
+ {std::cos(theta), -std::sin(theta), 0.0, 0.0, 0.0},
+ {std::sin(theta), std::cos(theta), 0.0, 0.0, 0.0},
+ {0.0, 0.0, 1.0, 0.0, 0.0},
+ {0.0, 0.0, 0.0, 1.0, 0.0},
+ {0.0, 0.0, 0.0, 0.0, 1.0}};
+ return toFieldRotation *  // NOTE(review): dr/dt = u(1), and UpdateWithTime passes u(1) = 0
+ Eigen::Vector<double, 5>{u(0, 0), u(1, 0), u(2, 0), u(0, 0), u(1, 0)};
+}
+
+template <int Dim>  // NOTE(review): not instantiated in this file
+wpi::array<double, Dim> DifferentialDrivePoseEstimator::StdDevMatrixToArray(
+    const Eigen::Vector<double, Dim>& stdDevs) {
+  wpi::array<double, Dim> result;
+  for (int row = 0; row < Dim; ++row) {
+    result[row] = stdDevs(row);
+  }
+  return result;
+}
+
+Eigen::Vector<double, 5> DifferentialDrivePoseEstimator::FillStateVector(
+    const Pose2d& pose, units::meter_t leftDistance,
+    units::meter_t rightDistance) {  // -> [x, y, θ, left dist, right dist]
+  Eigen::Vector<double, 5> state{pose.X().value(), pose.Y().value(),
+                                 pose.Rotation().Radians().value(),
+                                 leftDistance.value(), rightDistance.value()};
+  return state;
+}
diff --git a/wpimath/src/main/native/cpp/estimator/KalmanFilter.cpp b/wpimath/src/main/native/cpp/estimator/KalmanFilter.cpp
index a1747ab..1209eae 100644
--- a/wpimath/src/main/native/cpp/estimator/KalmanFilter.cpp
+++ b/wpimath/src/main/native/cpp/estimator/KalmanFilter.cpp
@@ -1,23 +1,20 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include "frc/estimator/KalmanFilter.h"
namespace frc {
KalmanFilter<1, 1, 1>::KalmanFilter(
- LinearSystem<1, 1, 1>& plant, const std::array<double, 1>& stateStdDevs,
- const std::array<double, 1>& measurementStdDevs, units::second_t dt)
+ LinearSystem<1, 1, 1>& plant, const wpi::array<double, 1>& stateStdDevs,
+ const wpi::array<double, 1>& measurementStdDevs, units::second_t dt)
: detail::KalmanFilterImpl<1, 1, 1>{plant, stateStdDevs, measurementStdDevs,
dt} {}
KalmanFilter<2, 1, 1>::KalmanFilter(
- LinearSystem<2, 1, 1>& plant, const std::array<double, 2>& stateStdDevs,
- const std::array<double, 1>& measurementStdDevs, units::second_t dt)
+ LinearSystem<2, 1, 1>& plant, const wpi::array<double, 2>& stateStdDevs,
+ const wpi::array<double, 1>& measurementStdDevs, units::second_t dt)
: detail::KalmanFilterImpl<2, 1, 1>{plant, stateStdDevs, measurementStdDevs,
dt} {}
diff --git a/wpimath/src/main/native/cpp/estimator/MecanumDrivePoseEstimator.cpp b/wpimath/src/main/native/cpp/estimator/MecanumDrivePoseEstimator.cpp
new file mode 100644
index 0000000..9d93647
--- /dev/null
+++ b/wpimath/src/main/native/cpp/estimator/MecanumDrivePoseEstimator.cpp
@@ -0,0 +1,116 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+#include "frc/estimator/MecanumDrivePoseEstimator.h"
+
+#include <wpi/timestamp.h>
+
+#include "frc/StateSpaceUtil.h"
+#include "frc/estimator/AngleStatistics.h"
+
+using namespace frc;
+
+frc::MecanumDrivePoseEstimator::MecanumDrivePoseEstimator(
+ const Rotation2d& gyroAngle, const Pose2d& initialPose,
+ MecanumDriveKinematics kinematics,
+ const wpi::array<double, 3>& stateStdDevs,
+ const wpi::array<double, 1>& localMeasurementStdDevs,
+ const wpi::array<double, 3>& visionMeasurementStdDevs,
+ units::second_t nominalDt)
+ : m_observer(
+ [](const Eigen::Vector<double, 3>& x,
+ const Eigen::Vector<double, 3>& u) { return u; },  // dx/dt = u (u is field-relative)
+ [](const Eigen::Vector<double, 3>& x,
+ const Eigen::Vector<double, 3>& u) { return x.block<1, 1>(2, 0); },  // y = heading
+ stateStdDevs, localMeasurementStdDevs, frc::AngleMean<3, 3>(2),
+ frc::AngleMean<1, 3>(0), frc::AngleResidual<3>(2),
+ frc::AngleResidual<1>(0), frc::AngleAdd<3>(2), nominalDt),
+ m_kinematics(kinematics),
+ m_nominalDt(nominalDt) {
+ SetVisionMeasurementStdDevs(visionMeasurementStdDevs);  // sets m_visionContR
+ // Vision correction: the camera observes the full state [x, y, heading].
+ // NOTE(review): [&] captures this; copying/moving the estimator dangles it.
+ m_visionCorrect = [&](const Eigen::Vector<double, 3>& u,
+ const Eigen::Vector<double, 3>& y) {
+ m_observer.Correct<3>(
+ u, y,
+ [](const Eigen::Vector<double, 3>& x, const Eigen::Vector<double, 3>&) {
+ return x;
+ },
+ m_visionContR, frc::AngleMean<3, 3>(2), frc::AngleResidual<3>(2),
+ frc::AngleResidual<3>(2), frc::AngleAdd<3>(2));
+ };
+
+ // Seed the estimate at the initial pose.
+ m_observer.SetXhat(PoseTo3dVector(initialPose));
+
+ // Offset added to raw gyro angles so they match the pose heading.
+ m_gyroOffset = initialPose.Rotation() - gyroAngle;
+ m_previousAngle = initialPose.Rotation();
+}
+
+void frc::MecanumDrivePoseEstimator::SetVisionMeasurementStdDevs(
+    const wpi::array<double, 3>& visionMeasurmentStdDevs) {
+  // Vision measurement noise covariance R, built from the std-devs.
+  m_visionContR = frc::MakeCovMatrix(visionMeasurmentStdDevs);
+}
+
+void frc::MecanumDrivePoseEstimator::ResetPosition(
+    const Pose2d& pose, const Rotation2d& gyroAngle) {
+  // Clear filter state/covariance and the latency-compensation history.
+  m_observer.Reset();
+  m_latencyCompensator.Reset();
+  m_observer.SetXhat(PoseTo3dVector(pose));
+  // Re-anchor the gyro so that gyroAngle now maps to the pose's heading.
+  const Rotation2d heading = pose.Rotation();
+  m_gyroOffset = heading - gyroAngle;
+  m_previousAngle = heading;
+}
+
+Pose2d frc::MecanumDrivePoseEstimator::GetEstimatedPosition() const {
+  return Pose2d{1_m * m_observer.Xhat(0), 1_m * m_observer.Xhat(1),
+                Rotation2d{units::radian_t{m_observer.Xhat(2)}}};
+}
+
+void frc::MecanumDrivePoseEstimator::AddVisionMeasurement(
+    const Pose2d& visionRobotPose, units::second_t timestamp) {
+  m_latencyCompensator.ApplyPastGlobalMeasurement<3>(  // 3 rows: x, y, θ
+      &m_observer, m_nominalDt, PoseTo3dVector(visionRobotPose),
+      m_visionCorrect, timestamp);
+}
+
+Pose2d frc::MecanumDrivePoseEstimator::Update(
+    const Rotation2d& gyroAngle, const MecanumDriveWheelSpeeds& wheelSpeeds) {
+  const units::second_t now = units::microsecond_t(wpi::Now());  // µs clock
+  return UpdateWithTime(now, gyroAngle, wheelSpeeds);
+}
+
+Pose2d frc::MecanumDrivePoseEstimator::UpdateWithTime(
+    units::second_t currentTime, const Rotation2d& gyroAngle,
+    const MecanumDriveWheelSpeeds& wheelSpeeds) {
+  // m_prevTime < 0 s means no previous update; fall back to nominal dt.
+  const auto dt = m_prevTime >= 0_s ? currentTime - m_prevTime : m_nominalDt;
+  m_prevTime = currentTime;
+
+  // Offset-corrected heading and its finite-difference rate of change.
+  const auto angle = gyroAngle + m_gyroOffset;
+  const auto omega = (angle - m_previousAngle).Radians() / dt;
+  m_previousAngle = angle;
+
+  // Robot-relative chassis velocity, rotated into the field frame.
+  const auto speeds = m_kinematics.ToChassisSpeeds(wheelSpeeds);
+  const auto fieldVelocity =
+      Translation2d{speeds.vx * 1_s, speeds.vy * 1_s}.RotateBy(angle);
+
+  // Input u = [field vx, field vy, omega]; local measurement y = [heading].
+  Eigen::Vector<double, 3> u{fieldVelocity.X().value(),
+                             fieldVelocity.Y().value(), omega.value()};
+  Eigen::Vector<double, 1> localY{angle.Radians().value()};
+
+  // Record history for latency compensation, then step the filter.
+  m_latencyCompensator.AddObserverState(m_observer, u, localY, currentTime);
+  m_observer.Predict(u, dt);
+  m_observer.Correct(u, localY);
+  return GetEstimatedPosition();
+}
diff --git a/wpimath/src/main/native/cpp/geometry/Pose2d.cpp b/wpimath/src/main/native/cpp/geometry/Pose2d.cpp
index 57dad44..b7176cd 100644
--- a/wpimath/src/main/native/cpp/geometry/Pose2d.cpp
+++ b/wpimath/src/main/native/cpp/geometry/Pose2d.cpp
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include "frc/geometry/Pose2d.h"
@@ -23,12 +20,6 @@
return TransformBy(other);
}
-Pose2d& Pose2d::operator+=(const Transform2d& other) {
- m_translation += other.Translation().RotateBy(m_rotation);
- m_rotation += other.Rotation();
- return *this;
-}
-
Transform2d Pose2d::operator-(const Pose2d& other) const {
const auto pose = this->RelativeTo(other);
return Transform2d(pose.Translation(), pose.Rotation());
@@ -55,7 +46,7 @@
Pose2d Pose2d::Exp(const Twist2d& twist) const {
const auto dx = twist.dx;
const auto dy = twist.dy;
- const auto dtheta = twist.dtheta.to<double>();
+ const auto dtheta = twist.dtheta.value();
const auto sinTheta = std::sin(dtheta);
const auto cosTheta = std::cos(dtheta);
@@ -77,7 +68,7 @@
Twist2d Pose2d::Log(const Pose2d& end) const {
const auto transform = end.RelativeTo(*this);
- const auto dtheta = transform.Rotation().Radians().to<double>();
+ const auto dtheta = transform.Rotation().Radians().value();
const auto halfDtheta = dtheta / 2.0;
const auto cosMinusOne = transform.Rotation().Cos() - 1;
diff --git a/wpimath/src/main/native/cpp/geometry/Rotation2d.cpp b/wpimath/src/main/native/cpp/geometry/Rotation2d.cpp
index 32a9b40..27af5ed 100644
--- a/wpimath/src/main/native/cpp/geometry/Rotation2d.cpp
+++ b/wpimath/src/main/native/cpp/geometry/Rotation2d.cpp
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include "frc/geometry/Rotation2d.h"
@@ -41,32 +38,20 @@
return RotateBy(other);
}
-Rotation2d& Rotation2d::operator+=(const Rotation2d& other) {
- double cos = Cos() * other.Cos() - Sin() * other.Sin();
- double sin = Cos() * other.Sin() + Sin() * other.Cos();
- m_cos = cos;
- m_sin = sin;
- m_value = units::radian_t(std::atan2(m_sin, m_cos));
- return *this;
-}
-
Rotation2d Rotation2d::operator-(const Rotation2d& other) const {
return *this + -other;
}
-Rotation2d& Rotation2d::operator-=(const Rotation2d& other) {
- *this += -other;
- return *this;
+Rotation2d Rotation2d::operator-() const {
+ return Rotation2d(-m_value);
}
-Rotation2d Rotation2d::operator-() const { return Rotation2d(-m_value); }
-
Rotation2d Rotation2d::operator*(double scalar) const {
return Rotation2d(m_value * scalar);
}
bool Rotation2d::operator==(const Rotation2d& other) const {
- return units::math::abs(m_value - other.m_value) < 1E-9_rad;
+ return std::hypot(m_cos - other.m_cos, m_sin - other.m_sin) < 1E-9;
}
bool Rotation2d::operator!=(const Rotation2d& other) const {
@@ -79,7 +64,7 @@
}
void frc::to_json(wpi::json& json, const Rotation2d& rotation) {
- json = wpi::json{{"radians", rotation.Radians().to<double>()}};
+ json = wpi::json{{"radians", rotation.Radians().value()}};
}
void frc::from_json(const wpi::json& json, Rotation2d& rotation) {
diff --git a/wpimath/src/main/native/cpp/geometry/Transform2d.cpp b/wpimath/src/main/native/cpp/geometry/Transform2d.cpp
index eb5e2ec..0808f35 100644
--- a/wpimath/src/main/native/cpp/geometry/Transform2d.cpp
+++ b/wpimath/src/main/native/cpp/geometry/Transform2d.cpp
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include "frc/geometry/Transform2d.h"
@@ -31,6 +28,10 @@
return Transform2d{(-Translation()).RotateBy(-Rotation()), -Rotation()};
}
+Transform2d Transform2d::operator+(const Transform2d& other) const {
+  return Transform2d{Pose2d{}, Pose2d{} + *this + other};  // this, then other
+}
+
bool Transform2d::operator==(const Transform2d& other) const {
return m_translation == other.m_translation && m_rotation == other.m_rotation;
}
diff --git a/wpimath/src/main/native/cpp/geometry/Translation2d.cpp b/wpimath/src/main/native/cpp/geometry/Translation2d.cpp
index 6f4551c..5a30ec2 100644
--- a/wpimath/src/main/native/cpp/geometry/Translation2d.cpp
+++ b/wpimath/src/main/native/cpp/geometry/Translation2d.cpp
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include "frc/geometry/Translation2d.h"
@@ -36,33 +33,18 @@
return {X() + other.X(), Y() + other.Y()};
}
-Translation2d& Translation2d::operator+=(const Translation2d& other) {
- m_x += other.m_x;
- m_y += other.m_y;
- return *this;
-}
-
Translation2d Translation2d::operator-(const Translation2d& other) const {
return *this + -other;
}
-Translation2d& Translation2d::operator-=(const Translation2d& other) {
- *this += -other;
- return *this;
+Translation2d Translation2d::operator-() const {
+ return {-m_x, -m_y};
}
-Translation2d Translation2d::operator-() const { return {-m_x, -m_y}; }
-
Translation2d Translation2d::operator*(double scalar) const {
return {scalar * m_x, scalar * m_y};
}
-Translation2d& Translation2d::operator*=(double scalar) {
- m_x *= scalar;
- m_y *= scalar;
- return *this;
-}
-
Translation2d Translation2d::operator/(double scalar) const {
return *this * (1.0 / scalar);
}
@@ -76,14 +58,9 @@
return !operator==(other);
}
-Translation2d& Translation2d::operator/=(double scalar) {
- *this *= (1.0 / scalar);
- return *this;
-}
-
void frc::to_json(wpi::json& json, const Translation2d& translation) {
- json = wpi::json{{"x", translation.X().to<double>()},
- {"y", translation.Y().to<double>()}};
+ json =
+ wpi::json{{"x", translation.X().value()}, {"y", translation.Y().value()}};
}
void frc::from_json(const wpi::json& json, Translation2d& translation) {
diff --git a/wpimath/src/main/native/cpp/jni/WPIMathJNI.cpp b/wpimath/src/main/native/cpp/jni/WPIMathJNI.cpp
index 26d57e2..b036833 100644
--- a/wpimath/src/main/native/cpp/jni/WPIMathJNI.cpp
+++ b/wpimath/src/main/native/cpp/jni/WPIMathJNI.cpp
@@ -1,12 +1,11 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include <jni.h>
+#include <exception>
+
#include <wpi/jni_util.h>
#include "Eigen/Core"
@@ -14,28 +13,38 @@
#include "Eigen/QR"
#include "drake/math/discrete_algebraic_riccati_equation.h"
#include "edu_wpi_first_math_WPIMathJNI.h"
+#include "frc/trajectory/TrajectoryUtil.h"
#include "unsupported/Eigen/MatrixFunctions"
using namespace wpi::java;
+/**
+ * Returns true if (A, B) is a stabilizable pair.
+ *
+ * (A, B) is stabilizable if and only if the uncontrollable eigenvalues of A, if
+ * any, have absolute values less than one, where an eigenvalue is
+ * uncontrollable if rank(λI - A, B) < n where n is the number of states.
+ *
+ * @param A System matrix.
+ * @param B Input matrix.
+ */
bool check_stabilizable(const Eigen::Ref<const Eigen::MatrixXd>& A,
const Eigen::Ref<const Eigen::MatrixXd>& B) {
- // This function checks if (A,B) is a stabilizable pair.
- // (A,B) is stabilizable if and only if the uncontrollable eigenvalues of
- // A, if any, have absolute values less than one, where an eigenvalue is
- // uncontrollable if Rank[lambda * I - A, B] < n.
- int n = B.rows(), m = B.cols();
- Eigen::EigenSolver<Eigen::MatrixXd> es(A);
- for (int i = 0; i < n; i++) {
+ int states = B.rows();
+ int inputs = B.cols();
+ Eigen::EigenSolver<Eigen::MatrixXd> es{A};
+ for (int i = 0; i < states; ++i) {
if (es.eigenvalues()[i].real() * es.eigenvalues()[i].real() +
es.eigenvalues()[i].imag() * es.eigenvalues()[i].imag() <
- 1)
+ 1) {
continue;
+ }
- Eigen::MatrixXcd E(n, n + m);
- E << es.eigenvalues()[i] * Eigen::MatrixXcd::Identity(n, n) - A, B;
- Eigen::ColPivHouseholderQR<Eigen::MatrixXcd> qr(E);
- if (qr.rank() != n) {
+ Eigen::MatrixXcd E{states, states + inputs};
+ E << es.eigenvalues()[i] * Eigen::MatrixXcd::Identity(states, states) - A,
+ B;
+ Eigen::ColPivHouseholderQR<Eigen::MatrixXcd> qr{E};
+ if (qr.rank() < states) {
return false;
}
}
@@ -43,6 +52,46 @@
return true;
}
+std::vector<double> GetElementsFromTrajectory(
+    const frc::Trajectory& trajectory) {
+  // Flatten each state into 7 doubles, in this order:
+  // [t, velocity, acceleration, x, y, heading (rad), curvature].
+  const auto& states = trajectory.States();
+  std::vector<double> elements;
+  elements.reserve(states.size() * 7);
+  for (const auto& state : states) {
+    elements.insert(elements.end(),
+                    {state.t.value(), state.velocity.value(),
+                     state.acceleration.value(), state.pose.X().value(),
+                     state.pose.Y().value(),
+                     state.pose.Rotation().Radians().value(),
+                     state.curvature.value()});
+  }
+  return elements;
+}
+
+frc::Trajectory CreateTrajectoryFromElements(wpi::span<const double> elements) {
+  // Each state is packed as 7 doubles (see GetElementsFromTrajectory). Throw
+  // instead of assert: assert compiles out in release builds, and a ragged
+  // array would then read past the end. JNI callers translate the exception.
+  if (elements.size() % 7 != 0) {
+    throw std::runtime_error("trajectory element count not a multiple of 7");
+  }
+  std::vector<frc::Trajectory::State> states;
+  states.reserve(elements.size() / 7);
+  for (size_t i = 0; i < elements.size(); i += 7) {
+    states.emplace_back(frc::Trajectory::State{
+        units::second_t{elements[i]},
+        units::meters_per_second_t{elements[i + 1]},
+        units::meters_per_second_squared_t{elements[i + 2]},
+        frc::Pose2d{units::meter_t{elements[i + 3]},
+                    units::meter_t{elements[i + 4]},
+                    units::radian_t{elements[i + 5]}},
+        units::curvature_t{elements[i + 6]}});
+  }
+  return frc::Trajectory(states);
+}
+
extern "C" {
/*
@@ -73,15 +122,22 @@
Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
Rmat{nativeR, inputs, inputs};
- Eigen::MatrixXd result =
- drake::math::DiscreteAlgebraicRiccatiEquation(Amat, Bmat, Qmat, Rmat);
+ try {
+ Eigen::MatrixXd result =
+ drake::math::DiscreteAlgebraicRiccatiEquation(Amat, Bmat, Qmat, Rmat);
- env->ReleaseDoubleArrayElements(A, nativeA, 0);
- env->ReleaseDoubleArrayElements(B, nativeB, 0);
- env->ReleaseDoubleArrayElements(Q, nativeQ, 0);
- env->ReleaseDoubleArrayElements(R, nativeR, 0);
+ env->ReleaseDoubleArrayElements(A, nativeA, 0);
+ env->ReleaseDoubleArrayElements(B, nativeB, 0);
+ env->ReleaseDoubleArrayElements(Q, nativeQ, 0);
+ env->ReleaseDoubleArrayElements(R, nativeR, 0);
- env->SetDoubleArrayRegion(S, 0, states * states, result.data());
+ env->SetDoubleArrayRegion(S, 0, states * states, result.data());
+ } catch (const std::runtime_error& e) {
+ jclass cls = env->FindClass("java/lang/RuntimeException");
+ if (cls) {
+ env->ThrowNew(cls, e.what());
+ }
+ }
}
/*
@@ -107,6 +163,28 @@
/*
* Class: edu_wpi_first_math_WPIMathJNI
+ * Method: pow
+ * Signature: ([DID[D)V
+ */
+JNIEXPORT void JNICALL
+Java_edu_wpi_first_math_WPIMathJNI_pow
+  (JNIEnv* env, jclass, jdoubleArray src, jint rows, jdouble exponent,
+   jdoubleArray dst)
+{
+  jdouble* arrayBody = env->GetDoubleArrayElements(src, nullptr);
+  if (!arrayBody) {
+    return;  // Pin/copy failed; the JVM already has an exception pending.
+  }
+  using RowMajorMatrix =
+      Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
+  Eigen::Map<RowMajorMatrix> Amat{arrayBody, rows, rows};  // NOLINT
+  RowMajorMatrix Apow = Amat.pow(exponent);
+  env->ReleaseDoubleArrayElements(src, arrayBody, JNI_ABORT);  // read-only
+  env->SetDoubleArrayRegion(dst, 0, rows * rows, Apow.data());
+}
+
+/*
+ * Class: edu_wpi_first_math_WPIMathJNI
* Method: isStabilizable
* Signature: (II[D[D)Z
*/
@@ -134,4 +212,99 @@
return isStabilizable;
}
+/*
+ * Class: edu_wpi_first_math_WPIMathJNI
+ * Method: fromPathweaverJson
+ * Signature: (Ljava/lang/String;)[D
+ */
+JNIEXPORT jdoubleArray JNICALL
+Java_edu_wpi_first_math_WPIMathJNI_fromPathweaverJson
+  (JNIEnv* env, jclass, jstring path)
+{
+  try {
+    const frc::Trajectory trajectory =
+        frc::TrajectoryUtil::FromPathweaverJson(JStringRef{env, path}.c_str());
+    const std::vector<double> elements = GetElementsFromTrajectory(trajectory);
+    return MakeJDoubleArray(env, elements);
+  } catch (const std::exception& e) {
+    // Report any failure to Java as an IOException with the same message.
+    if (jclass cls = env->FindClass("java/io/IOException")) {
+      env->ThrowNew(cls, e.what());
+    }
+    return nullptr;
+  }
+}
+
+/*
+ * Class: edu_wpi_first_math_WPIMathJNI
+ * Method: toPathweaverJson
+ * Signature: ([DLjava/lang/String;)V
+ */
+JNIEXPORT void JNICALL
+Java_edu_wpi_first_math_WPIMathJNI_toPathweaverJson
+  (JNIEnv* env, jclass, jdoubleArray elements, jstring path)
+{
+  try {
+    const frc::Trajectory trajectory =
+        CreateTrajectoryFromElements(JDoubleArrayRef{env, elements});
+    frc::TrajectoryUtil::ToPathweaverJson(trajectory,
+                                          JStringRef{env, path}.c_str());
+  } catch (const std::exception& e) {
+    // Any failure is rethrown on the Java side as an IOException.
+    if (jclass cls = env->FindClass("java/io/IOException")) {
+      env->ThrowNew(cls, e.what());
+    }
+  }
+}
+
+/*
+ * Class: edu_wpi_first_math_WPIMathJNI
+ * Method: deserializeTrajectory
+ * Signature: (Ljava/lang/String;)[D
+ */
+JNIEXPORT jdoubleArray JNICALL
+Java_edu_wpi_first_math_WPIMathJNI_deserializeTrajectory
+  (JNIEnv* env, jclass, jstring json)
+{
+  try {
+    const frc::Trajectory trajectory =
+        frc::TrajectoryUtil::DeserializeTrajectory(
+            JStringRef{env, json}.c_str());
+    return MakeJDoubleArray(env, GetElementsFromTrajectory(trajectory));
+  } catch (const std::exception& e) {
+    // Any failure becomes a Java TrajectorySerializationException.
+    if (jclass cls = env->FindClass(
+            "edu/wpi/first/math/trajectory/TrajectoryUtil$"
+            "TrajectorySerializationException")) {
+      env->ThrowNew(cls, e.what());
+    }
+    return nullptr;
+  }
+}
+
+/*
+ * Class: edu_wpi_first_math_WPIMathJNI
+ * Method: serializeTrajectory
+ * Signature: ([D)Ljava/lang/String;
+ */
+JNIEXPORT jstring JNICALL
+Java_edu_wpi_first_math_WPIMathJNI_serializeTrajectory
+  (JNIEnv* env, jclass, jdoubleArray elements)
+{
+  try {
+    const frc::Trajectory trajectory =
+        CreateTrajectoryFromElements(JDoubleArrayRef{env, elements});
+    const auto json = frc::TrajectoryUtil::SerializeTrajectory(trajectory);
+    return MakeJString(env, json);
+  } catch (const std::exception& e) {
+    // Any failure becomes a Java TrajectorySerializationException.
+    if (jclass cls = env->FindClass(
+            "edu/wpi/first/math/trajectory/TrajectoryUtil$"
+            "TrajectorySerializationException")) {
+      env->ThrowNew(cls, e.what());
+    }
+    return nullptr;
+  }
+}
+
} // extern "C"
diff --git a/wpimath/src/main/native/cpp/kinematics/DifferentialDriveOdometry.cpp b/wpimath/src/main/native/cpp/kinematics/DifferentialDriveOdometry.cpp
index 25b6c51..c4a4311 100644
--- a/wpimath/src/main/native/cpp/kinematics/DifferentialDriveOdometry.cpp
+++ b/wpimath/src/main/native/cpp/kinematics/DifferentialDriveOdometry.cpp
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include "frc/kinematics/DifferentialDriveOdometry.h"
diff --git a/wpimath/src/main/native/cpp/kinematics/DifferentialDriveWheelSpeeds.cpp b/wpimath/src/main/native/cpp/kinematics/DifferentialDriveWheelSpeeds.cpp
index 36a4952..42d3018 100644
--- a/wpimath/src/main/native/cpp/kinematics/DifferentialDriveWheelSpeeds.cpp
+++ b/wpimath/src/main/native/cpp/kinematics/DifferentialDriveWheelSpeeds.cpp
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include "frc/kinematics/DifferentialDriveWheelSpeeds.h"
diff --git a/wpimath/src/main/native/cpp/kinematics/MecanumDriveKinematics.cpp b/wpimath/src/main/native/cpp/kinematics/MecanumDriveKinematics.cpp
index de1b2d0..c4a71fd 100644
--- a/wpimath/src/main/native/cpp/kinematics/MecanumDriveKinematics.cpp
+++ b/wpimath/src/main/native/cpp/kinematics/MecanumDriveKinematics.cpp
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include "frc/kinematics/MecanumDriveKinematics.h"
@@ -24,33 +21,31 @@
m_previousCoR = centerOfRotation;
}
- Eigen::Vector3d chassisSpeedsVector;
- chassisSpeedsVector << chassisSpeeds.vx.to<double>(),
- chassisSpeeds.vy.to<double>(), chassisSpeeds.omega.to<double>();
+ Eigen::Vector3d chassisSpeedsVector{chassisSpeeds.vx.value(),
+ chassisSpeeds.vy.value(),
+ chassisSpeeds.omega.value()};
- Eigen::Matrix<double, 4, 1> wheelsMatrix =
+ Eigen::Vector<double, 4> wheelsVector =
m_inverseKinematics * chassisSpeedsVector;
MecanumDriveWheelSpeeds wheelSpeeds;
- wheelSpeeds.frontLeft = units::meters_per_second_t{wheelsMatrix(0, 0)};
- wheelSpeeds.frontRight = units::meters_per_second_t{wheelsMatrix(1, 0)};
- wheelSpeeds.rearLeft = units::meters_per_second_t{wheelsMatrix(2, 0)};
- wheelSpeeds.rearRight = units::meters_per_second_t{wheelsMatrix(3, 0)};
+ wheelSpeeds.frontLeft = units::meters_per_second_t{wheelsVector(0)};
+ wheelSpeeds.frontRight = units::meters_per_second_t{wheelsVector(1)};
+ wheelSpeeds.rearLeft = units::meters_per_second_t{wheelsVector(2)};
+ wheelSpeeds.rearRight = units::meters_per_second_t{wheelsVector(3)};
return wheelSpeeds;
}
ChassisSpeeds MecanumDriveKinematics::ToChassisSpeeds(
const MecanumDriveWheelSpeeds& wheelSpeeds) const {
- Eigen::Matrix<double, 4, 1> wheelSpeedsMatrix;
- // clang-format off
- wheelSpeedsMatrix << wheelSpeeds.frontLeft.to<double>(), wheelSpeeds.frontRight.to<double>(),
- wheelSpeeds.rearLeft.to<double>(), wheelSpeeds.rearRight.to<double>();
- // clang-format on
+ Eigen::Vector<double, 4> wheelSpeedsVector{
+ wheelSpeeds.frontLeft.value(), wheelSpeeds.frontRight.value(),
+ wheelSpeeds.rearLeft.value(), wheelSpeeds.rearRight.value()};
Eigen::Vector3d chassisSpeedsVector =
- m_forwardKinematics.solve(wheelSpeedsMatrix);
+ m_forwardKinematics.solve(wheelSpeedsVector);
- return {units::meters_per_second_t{chassisSpeedsVector(0)},
+ return {units::meters_per_second_t{chassisSpeedsVector(0)}, // NOLINT
units::meters_per_second_t{chassisSpeedsVector(1)},
units::radians_per_second_t{chassisSpeedsVector(2)}};
}
@@ -59,11 +54,9 @@
Translation2d fr,
Translation2d rl,
Translation2d rr) const {
- // clang-format off
- m_inverseKinematics << 1, -1, (-(fl.X() + fl.Y())).template to<double>(),
- 1, 1, (fr.X() - fr.Y()).template to<double>(),
- 1, 1, (rl.X() - rl.Y()).template to<double>(),
- 1, -1, (-(rr.X() + rr.Y())).template to<double>();
- // clang-format on
- m_inverseKinematics /= std::sqrt(2);
+ m_inverseKinematics =
+ Eigen::Matrix<double, 4, 3>{{1, -1, (-(fl.X() + fl.Y())).value()},
+ {1, 1, (fr.X() - fr.Y()).value()},
+ {1, 1, (rl.X() - rl.Y()).value()},
+ {1, -1, (-(rr.X() + rr.Y())).value()}};
}
diff --git a/wpimath/src/main/native/cpp/kinematics/MecanumDriveOdometry.cpp b/wpimath/src/main/native/cpp/kinematics/MecanumDriveOdometry.cpp
index 7534fc1..bbeee58 100644
--- a/wpimath/src/main/native/cpp/kinematics/MecanumDriveOdometry.cpp
+++ b/wpimath/src/main/native/cpp/kinematics/MecanumDriveOdometry.cpp
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include "frc/kinematics/MecanumDriveOdometry.h"
diff --git a/wpimath/src/main/native/cpp/kinematics/MecanumDriveWheelSpeeds.cpp b/wpimath/src/main/native/cpp/kinematics/MecanumDriveWheelSpeeds.cpp
index b20dddf..fc47461 100644
--- a/wpimath/src/main/native/cpp/kinematics/MecanumDriveWheelSpeeds.cpp
+++ b/wpimath/src/main/native/cpp/kinematics/MecanumDriveWheelSpeeds.cpp
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include "frc/kinematics/MecanumDriveWheelSpeeds.h"
diff --git a/wpimath/src/main/native/cpp/spline/CubicHermiteSpline.cpp b/wpimath/src/main/native/cpp/spline/CubicHermiteSpline.cpp
index c00a362..b643849 100644
--- a/wpimath/src/main/native/cpp/spline/CubicHermiteSpline.cpp
+++ b/wpimath/src/main/native/cpp/spline/CubicHermiteSpline.cpp
@@ -1,19 +1,16 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include "frc/spline/CubicHermiteSpline.h"
using namespace frc;
CubicHermiteSpline::CubicHermiteSpline(
- std::array<double, 2> xInitialControlVector,
- std::array<double, 2> xFinalControlVector,
- std::array<double, 2> yInitialControlVector,
- std::array<double, 2> yFinalControlVector) {
+ wpi::array<double, 2> xInitialControlVector,
+ wpi::array<double, 2> xFinalControlVector,
+ wpi::array<double, 2> yInitialControlVector,
+ wpi::array<double, 2> yFinalControlVector) {
const auto hermite = MakeHermiteBasis();
const auto x =
ControlVectorFromArrays(xInitialControlVector, xFinalControlVector);
diff --git a/wpimath/src/main/native/cpp/spline/QuinticHermiteSpline.cpp b/wpimath/src/main/native/cpp/spline/QuinticHermiteSpline.cpp
index 5b34cdb..5362b7c 100644
--- a/wpimath/src/main/native/cpp/spline/QuinticHermiteSpline.cpp
+++ b/wpimath/src/main/native/cpp/spline/QuinticHermiteSpline.cpp
@@ -1,19 +1,16 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include "frc/spline/QuinticHermiteSpline.h"
using namespace frc;
QuinticHermiteSpline::QuinticHermiteSpline(
- std::array<double, 3> xInitialControlVector,
- std::array<double, 3> xFinalControlVector,
- std::array<double, 3> yInitialControlVector,
- std::array<double, 3> yFinalControlVector) {
+ wpi::array<double, 3> xInitialControlVector,
+ wpi::array<double, 3> xFinalControlVector,
+ wpi::array<double, 3> yInitialControlVector,
+ wpi::array<double, 3> yFinalControlVector) {
const auto hermite = MakeHermiteBasis();
const auto x =
ControlVectorFromArrays(xInitialControlVector, xFinalControlVector);
diff --git a/wpimath/src/main/native/cpp/spline/SplineHelper.cpp b/wpimath/src/main/native/cpp/spline/SplineHelper.cpp
index 58f7537..cec620c 100644
--- a/wpimath/src/main/native/cpp/spline/SplineHelper.cpp
+++ b/wpimath/src/main/native/cpp/spline/SplineHelper.cpp
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include "frc/spline/SplineHelper.h"
@@ -16,10 +13,10 @@
const Spline<3>::ControlVector& end) {
std::vector<CubicHermiteSpline> splines;
- std::array<double, 2> xInitial = start.x;
- std::array<double, 2> yInitial = start.y;
- std::array<double, 2> xFinal = end.x;
- std::array<double, 2> yFinal = end.y;
+ wpi::array<double, 2> xInitial = start.x;
+ wpi::array<double, 2> yInitial = start.y;
+ wpi::array<double, 2> xFinal = end.x;
+ wpi::array<double, 2> yFinal = end.y;
if (waypoints.size() > 1) {
waypoints.emplace(waypoints.begin(),
@@ -55,29 +52,27 @@
c.emplace_back(0);
// populate rhs vectors
- dx.emplace_back(
- 3 * (waypoints[2].X().to<double>() - waypoints[0].X().to<double>()) -
- xInitial[1]);
- dy.emplace_back(
- 3 * (waypoints[2].Y().to<double>() - waypoints[0].Y().to<double>()) -
- yInitial[1]);
+ dx.emplace_back(3 * (waypoints[2].X().value() - waypoints[0].X().value()) -
+ xInitial[1]);
+ dy.emplace_back(3 * (waypoints[2].Y().value() - waypoints[0].Y().value()) -
+ yInitial[1]);
if (waypoints.size() > 4) {
for (size_t i = 1; i <= waypoints.size() - 4; ++i) {
// dx and dy represent the derivatives of the internal waypoints. The
// derivative of the second internal waypoint should involve the third
// and first internal waypoint, which have indices of 1 and 3 in the
// waypoints list (which contains ALL waypoints).
- dx.emplace_back(3 * (waypoints[i + 2].X().to<double>() -
- waypoints[i].X().to<double>()));
- dy.emplace_back(3 * (waypoints[i + 2].Y().to<double>() -
- waypoints[i].Y().to<double>()));
+ dx.emplace_back(
+ 3 * (waypoints[i + 2].X().value() - waypoints[i].X().value()));
+ dy.emplace_back(
+ 3 * (waypoints[i + 2].Y().value() - waypoints[i].Y().value()));
}
}
- dx.emplace_back(3 * (waypoints[waypoints.size() - 1].X().to<double>() -
- waypoints[waypoints.size() - 3].X().to<double>()) -
+ dx.emplace_back(3 * (waypoints[waypoints.size() - 1].X().value() -
+ waypoints[waypoints.size() - 3].X().value()) -
xFinal[1]);
- dy.emplace_back(3 * (waypoints[waypoints.size() - 1].Y().to<double>() -
- waypoints[waypoints.size() - 3].Y().to<double>()) -
+ dy.emplace_back(3 * (waypoints[waypoints.size() - 1].Y().value() -
+ waypoints[waypoints.size() - 3].Y().value()) -
yFinal[1]);
// Compute solution to tridiagonal system
@@ -92,10 +87,10 @@
for (size_t i = 0; i < fx.size() - 1; ++i) {
// Create the spline.
const CubicHermiteSpline spline{
- {waypoints[i].X().to<double>(), fx[i]},
- {waypoints[i + 1].X().to<double>(), fx[i + 1]},
- {waypoints[i].Y().to<double>(), fy[i]},
- {waypoints[i + 1].Y().to<double>(), fy[i + 1]}};
+ {waypoints[i].X().value(), fx[i]},
+ {waypoints[i + 1].X().value(), fx[i + 1]},
+ {waypoints[i].Y().value(), fy[i]},
+ {waypoints[i + 1].Y().value(), fy[i + 1]}};
splines.push_back(spline);
}
@@ -105,10 +100,8 @@
const double yDeriv =
(3 * (yFinal[0] - yInitial[0]) - yFinal[1] - yInitial[1]) / 4.0;
- std::array<double, 2> midXControlVector{waypoints[0].X().to<double>(),
- xDeriv};
- std::array<double, 2> midYControlVector{waypoints[0].Y().to<double>(),
- yDeriv};
+ wpi::array<double, 2> midXControlVector{waypoints[0].X().value(), xDeriv};
+ wpi::array<double, 2> midYControlVector{waypoints[0].Y().value(), yDeriv};
splines.emplace_back(xInitial, midXControlVector, yInitial,
midYControlVector);
@@ -137,22 +130,20 @@
return splines;
}
-std::array<Spline<3>::ControlVector, 2>
+wpi::array<Spline<3>::ControlVector, 2>
SplineHelper::CubicControlVectorsFromWaypoints(
const Pose2d& start, const std::vector<Translation2d>& interiorWaypoints,
const Pose2d& end) {
double scalar;
if (interiorWaypoints.empty()) {
- scalar = 1.2 * start.Translation().Distance(end.Translation()).to<double>();
+ scalar = 1.2 * start.Translation().Distance(end.Translation()).value();
} else {
scalar =
- 1.2 *
- start.Translation().Distance(interiorWaypoints.front()).to<double>();
+ 1.2 * start.Translation().Distance(interiorWaypoints.front()).value();
}
const auto initialCV = CubicControlVector(scalar, start);
if (!interiorWaypoints.empty()) {
- scalar =
- 1.2 * end.Translation().Distance(interiorWaypoints.back()).to<double>();
+ scalar = 1.2 * end.Translation().Distance(interiorWaypoints.back()).value();
}
const auto finalCV = CubicControlVector(scalar, end);
return {initialCV, finalCV};
@@ -168,7 +159,7 @@
// This just makes the splines look better.
const auto scalar =
- 1.2 * p0.Translation().Distance(p1.Translation()).to<double>();
+ 1.2 * p0.Translation().Distance(p1.Translation()).value();
auto controlVectorA = QuinticControlVector(scalar, p0);
auto controlVectorB = QuinticControlVector(scalar, p1);
diff --git a/wpimath/src/main/native/cpp/spline/SplineParameterizer.cpp b/wpimath/src/main/native/cpp/spline/SplineParameterizer.cpp
index b7e7f9e..73c475b 100644
--- a/wpimath/src/main/native/cpp/spline/SplineParameterizer.cpp
+++ b/wpimath/src/main/native/cpp/spline/SplineParameterizer.cpp
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include "frc/spline/SplineParameterizer.h"
diff --git a/wpimath/src/main/native/cpp/trajectory/Trajectory.cpp b/wpimath/src/main/native/cpp/trajectory/Trajectory.cpp
index 067b8de..db419f7 100644
--- a/wpimath/src/main/native/cpp/trajectory/Trajectory.cpp
+++ b/wpimath/src/main/native/cpp/trajectory/Trajectory.cpp
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include "frc/trajectory/Trajectory.h"
@@ -34,7 +31,9 @@
const auto deltaT = newT - t;
// If delta time is negative, flip the order of interpolation.
- if (deltaT < 0_s) return endValue.Interpolate(*this, 1.0 - i);
+ if (deltaT < 0_s) {
+ return endValue.Interpolate(*this, 1.0 - i);
+ }
// Check whether the robot is reversing at this stage.
const auto reversing =
@@ -68,8 +67,12 @@
}
Trajectory::State Trajectory::Sample(units::second_t t) const {
- if (t <= m_states.front().t) return m_states.front();
- if (t >= m_totalTime) return m_states.back();
+ if (t <= m_states.front().t) {
+ return m_states.front();
+ }
+ if (t >= m_totalTime) {
+ return m_states.back();
+ }
// Use binary search to get the element with a timestamp no less than the
// requested timestamp. This starts at 1 because we use the previous state
@@ -121,12 +124,33 @@
return Trajectory(newStates);
}
+Trajectory Trajectory::operator+(const Trajectory& other) const {
+ // If this is a default constructed trajectory with no states, then we can
+ // simply return the rhs trajectory.
+ if (m_states.empty()) {
+ return other;
+ }
+
+ auto states = m_states;
+ auto otherStates = other.States();
+ for (auto& otherState : otherStates) {
+ otherState.t += m_totalTime;
+ }
+
+ // Here we omit the first state of the other trajectory because we don't want
+ // two time points with different states. Sample() will automatically
+ // interpolate between the end of this trajectory and the second state of the
+ // other trajectory.
+ states.insert(states.end(), otherStates.begin() + 1, otherStates.end());
+ return Trajectory(states);
+}
+
void frc::to_json(wpi::json& json, const Trajectory::State& state) {
- json = wpi::json{{"time", state.t.to<double>()},
- {"velocity", state.velocity.to<double>()},
- {"acceleration", state.acceleration.to<double>()},
+ json = wpi::json{{"time", state.t.value()},
+ {"velocity", state.velocity.value()},
+ {"acceleration", state.acceleration.value()},
{"pose", state.pose},
- {"curvature", state.curvature.to<double>()}};
+ {"curvature", state.curvature.value()}};
}
void frc::from_json(const wpi::json& json, Trajectory::State& state) {
diff --git a/wpimath/src/main/native/cpp/trajectory/TrajectoryGenerator.cpp b/wpimath/src/main/native/cpp/trajectory/TrajectoryGenerator.cpp
index 6cb3f7a..2e2771d 100644
--- a/wpimath/src/main/native/cpp/trajectory/TrajectoryGenerator.cpp
+++ b/wpimath/src/main/native/cpp/trajectory/TrajectoryGenerator.cpp
@@ -1,15 +1,12 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include "frc/trajectory/TrajectoryGenerator.h"
#include <utility>
-#include <wpi/raw_ostream.h>
+#include <fmt/format.h>
#include "frc/spline/SplineHelper.h"
#include "frc/spline/SplineParameterizer.h"
@@ -22,10 +19,11 @@
std::function<void(const char*)> TrajectoryGenerator::s_errorFunc;
void TrajectoryGenerator::ReportError(const char* error) {
- if (s_errorFunc)
+ if (s_errorFunc) {
s_errorFunc(error);
- else
- wpi::errs() << "TrajectoryGenerator error: " << error << "\n";
+ } else {
+ fmt::print(stderr, "TrajectoryGenerator error: {}\n", error);
+ }
}
Trajectory TrajectoryGenerator::GenerateTrajectory(
@@ -115,8 +113,11 @@
const std::vector<Pose2d>& waypoints, const TrajectoryConfig& config) {
auto newWaypoints = waypoints;
const Transform2d flip{Translation2d(), Rotation2d(180_deg)};
- if (config.IsReversed())
- for (auto& waypoint : newWaypoints) waypoint += flip;
+ if (config.IsReversed()) {
+ for (auto& waypoint : newWaypoints) {
+ waypoint = waypoint + flip;
+ }
+ }
std::vector<SplineParameterizer::PoseWithCurvature> points;
try {
diff --git a/wpimath/src/main/native/cpp/trajectory/TrajectoryParameterizer.cpp b/wpimath/src/main/native/cpp/trajectory/TrajectoryParameterizer.cpp
index 0e78a15..d397d0c 100644
--- a/wpimath/src/main/native/cpp/trajectory/TrajectoryParameterizer.cpp
+++ b/wpimath/src/main/native/cpp/trajectory/TrajectoryParameterizer.cpp
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
/*
* MIT License
@@ -31,7 +28,7 @@
#include "frc/trajectory/TrajectoryParameterizer.h"
-#include <string>
+#include <fmt/format.h>
#include "units/math.h"
@@ -88,7 +85,9 @@
// Now enforce all acceleration limits.
EnforceAccelerationLimits(reversed, constraints, &constrainedState);
- if (ds.to<double>() < kEpsilon) break;
+ if (ds.value() < kEpsilon) {
+ break;
+ }
// If the actual acceleration for this state is higher than the max
// acceleration that we applied, then we need to reduce the max
@@ -133,14 +132,18 @@
successor.minAcceleration * ds * 2.0);
// No more limits to impose! This state can be finalized.
- if (newMaxVelocity >= constrainedState.maxVelocity) break;
+ if (newMaxVelocity >= constrainedState.maxVelocity) {
+ break;
+ }
constrainedState.maxVelocity = newMaxVelocity;
// Check all acceleration constraints with the new max velocity.
EnforceAccelerationLimits(reversed, constraints, &constrainedState);
- if (ds.to<double>() > -kEpsilon) break;
+ if (ds.value() > -kEpsilon) {
+ break;
+ }
// If the actual acceleration for this state is lower than the min
// acceleration, then we need to lower the min acceleration of the
@@ -190,9 +193,9 @@
// delta_x = v * t
dt = ds / v;
} else {
- throw std::runtime_error("Something went wrong at iteration " +
- std::to_string(i) +
- " of time parameterization.");
+ throw std::runtime_error(fmt::format(
+ "Something went wrong at iteration {} of time parameterization.",
+ i));
}
}
diff --git a/wpimath/src/main/native/cpp/trajectory/TrajectoryUtil.cpp b/wpimath/src/main/native/cpp/trajectory/TrajectoryUtil.cpp
index 0ae43f5..169b642 100644
--- a/wpimath/src/main/native/cpp/trajectory/TrajectoryUtil.cpp
+++ b/wpimath/src/main/native/cpp/trajectory/TrajectoryUtil.cpp
@@ -1,14 +1,12 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include "frc/trajectory/TrajectoryUtil.h"
#include <system_error>
+#include <fmt/format.h>
#include <wpi/SmallString.h>
#include <wpi/json.h>
#include <wpi/raw_istream.h>
@@ -17,13 +15,12 @@
using namespace frc;
void TrajectoryUtil::ToPathweaverJson(const Trajectory& trajectory,
- const wpi::Twine& path) {
+ std::string_view path) {
std::error_code error_code;
- wpi::SmallString<128> buf;
- wpi::raw_fd_ostream output{path.toStringRef(buf), error_code};
+ wpi::raw_fd_ostream output{path, error_code};
if (error_code) {
- throw std::runtime_error(("Cannot open file: " + path).str());
+ throw std::runtime_error(fmt::format("Cannot open file: {}", path));
}
wpi::json json = trajectory.States();
@@ -31,13 +28,12 @@
output.flush();
}
-Trajectory TrajectoryUtil::FromPathweaverJson(const wpi::Twine& path) {
+Trajectory TrajectoryUtil::FromPathweaverJson(std::string_view path) {
std::error_code error_code;
- wpi::SmallString<128> buf;
- wpi::raw_fd_istream input{path.toStringRef(buf), error_code};
+ wpi::raw_fd_istream input{path, error_code};
if (error_code) {
- throw std::runtime_error(("Cannot open file: " + path).str());
+ throw std::runtime_error(fmt::format("Cannot open file: {}", path));
}
wpi::json json;
@@ -51,8 +47,7 @@
return json.dump();
}
-Trajectory TrajectoryUtil::DeserializeTrajectory(
- const wpi::StringRef json_str) {
- wpi::json json = wpi::json::parse(json_str);
+Trajectory TrajectoryUtil::DeserializeTrajectory(std::string_view jsonStr) {
+ wpi::json json = wpi::json::parse(jsonStr);
return Trajectory{json.get<std::vector<Trajectory::State>>()};
}
diff --git a/wpimath/src/main/native/cpp/trajectory/constraint/CentripetalAccelerationConstraint.cpp b/wpimath/src/main/native/cpp/trajectory/constraint/CentripetalAccelerationConstraint.cpp
index f04b9d6..738d243 100644
--- a/wpimath/src/main/native/cpp/trajectory/constraint/CentripetalAccelerationConstraint.cpp
+++ b/wpimath/src/main/native/cpp/trajectory/constraint/CentripetalAccelerationConstraint.cpp
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include "frc/trajectory/constraint/CentripetalAccelerationConstraint.h"
diff --git a/wpimath/src/main/native/cpp/trajectory/constraint/DifferentialDriveKinematicsConstraint.cpp b/wpimath/src/main/native/cpp/trajectory/constraint/DifferentialDriveKinematicsConstraint.cpp
index c3380e5..d1c2f6f 100644
--- a/wpimath/src/main/native/cpp/trajectory/constraint/DifferentialDriveKinematicsConstraint.cpp
+++ b/wpimath/src/main/native/cpp/trajectory/constraint/DifferentialDriveKinematicsConstraint.cpp
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include "frc/trajectory/constraint/DifferentialDriveKinematicsConstraint.h"
diff --git a/wpimath/src/main/native/cpp/trajectory/constraint/DifferentialDriveVoltageConstraint.cpp b/wpimath/src/main/native/cpp/trajectory/constraint/DifferentialDriveVoltageConstraint.cpp
index f12ce75..7c10201 100644
--- a/wpimath/src/main/native/cpp/trajectory/constraint/DifferentialDriveVoltageConstraint.cpp
+++ b/wpimath/src/main/native/cpp/trajectory/constraint/DifferentialDriveVoltageConstraint.cpp
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include "frc/trajectory/constraint/DifferentialDriveVoltageConstraint.h"
diff --git a/wpimath/src/main/native/cpp/trajectory/constraint/MaxVelocityConstraint.cpp b/wpimath/src/main/native/cpp/trajectory/constraint/MaxVelocityConstraint.cpp
new file mode 100644
index 0000000..9e6e712
--- /dev/null
+++ b/wpimath/src/main/native/cpp/trajectory/constraint/MaxVelocityConstraint.cpp
@@ -0,0 +1,23 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+#include "frc/trajectory/constraint/MaxVelocityConstraint.h"
+
+using namespace frc;
+
+MaxVelocityConstraint::MaxVelocityConstraint(
+ units::meters_per_second_t maxVelocity)
+ : m_maxVelocity(units::math::abs(maxVelocity)) {}
+
+units::meters_per_second_t MaxVelocityConstraint::MaxVelocity(
+ const Pose2d& pose, units::curvature_t curvature,
+ units::meters_per_second_t velocity) const {
+ return m_maxVelocity;
+}
+
+TrajectoryConstraint::MinMax MaxVelocityConstraint::MinMaxAcceleration(
+ const Pose2d& pose, units::curvature_t curvature,
+ units::meters_per_second_t speed) const {
+ return {};
+}
diff --git a/wpimath/src/main/native/cpp/trajectory/constraint/MecanumDriveKinematicsConstraint.cpp b/wpimath/src/main/native/cpp/trajectory/constraint/MecanumDriveKinematicsConstraint.cpp
index 7d95803..418a904 100644
--- a/wpimath/src/main/native/cpp/trajectory/constraint/MecanumDriveKinematicsConstraint.cpp
+++ b/wpimath/src/main/native/cpp/trajectory/constraint/MecanumDriveKinematicsConstraint.cpp
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include "frc/trajectory/constraint/MecanumDriveKinematicsConstraint.h"
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/Cholesky b/wpimath/src/main/native/eigeninclude/Eigen/Cholesky
index 1332b54..ef249de 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/Cholesky
+++ b/wpimath/src/main/native/eigeninclude/Eigen/Cholesky
@@ -33,14 +33,13 @@
#include "src/Cholesky/LDLT.h"
#ifdef EIGEN_USE_LAPACKE
#ifdef EIGEN_USE_MKL
-#include "mkl_lapacke.h"
+// #include "mkl_lapacke.h"
#else
-#include "src/misc/lapacke.h"
+// #include "src/misc/lapacke.h"
#endif
-#include "src/Cholesky/LLT_LAPACKE.h"
+// #include "src/Cholesky/LLT_LAPACKE.h"
#endif
#include "src/Core/util/ReenableStupidWarnings.h"
#endif // EIGEN_CHOLESKY_MODULE_H
-/* vim: set filetype=cpp et sw=2 ts=2 ai: */
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/Core b/wpimath/src/main/native/eigeninclude/Eigen/Core
index bd892c1..fd5e098 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/Core
+++ b/wpimath/src/main/native/eigeninclude/Eigen/Core
@@ -11,255 +11,55 @@
#ifndef EIGEN_CORE_H
#define EIGEN_CORE_H
-#if __GNUC__ >= 9
-#pragma GCC diagnostic ignored "-Wdeprecated-copy"
-#endif
-
-// first thing Eigen does: stop the compiler from committing suicide
+// first thing Eigen does: stop the compiler from reporting useless warnings.
#include "src/Core/util/DisableStupidWarnings.h"
-#if defined(__CUDACC__) && !defined(EIGEN_NO_CUDA)
- #define EIGEN_CUDACC __CUDACC__
+// then include this file where all our macros are defined. It's really important to do it first because
+// it's where we do all the compiler/OS/arch detections and define most defaults.
+#include "src/Core/util/Macros.h"
+
+// This detects SSE/AVX/NEON/etc. and configure alignment settings
+#include "src/Core/util/ConfigureVectorization.h"
+
+// We need cuda_runtime.h/hip_runtime.h to ensure that
+// the EIGEN_USING_STD macro works properly on the device side
+#if defined(EIGEN_CUDACC)
+ #include <cuda_runtime.h>
+#elif defined(EIGEN_HIPCC)
+ #include <hip/hip_runtime.h>
#endif
-#if defined(__CUDA_ARCH__) && !defined(EIGEN_NO_CUDA)
- #define EIGEN_CUDA_ARCH __CUDA_ARCH__
-#endif
-
-#if defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 9)
-#define EIGEN_CUDACC_VER ((__CUDACC_VER_MAJOR__ * 10000) + (__CUDACC_VER_MINOR__ * 100))
-#elif defined(__CUDACC_VER__)
-#define EIGEN_CUDACC_VER __CUDACC_VER__
-#else
-#define EIGEN_CUDACC_VER 0
-#endif
-
-// Handle NVCC/CUDA/SYCL
-#if defined(__CUDACC__) || defined(__SYCL_DEVICE_ONLY__)
- // Do not try asserts on CUDA and SYCL!
- #ifndef EIGEN_NO_DEBUG
- #define EIGEN_NO_DEBUG
- #endif
-
- #ifdef EIGEN_INTERNAL_DEBUGGING
- #undef EIGEN_INTERNAL_DEBUGGING
- #endif
-
- #ifdef EIGEN_EXCEPTIONS
- #undef EIGEN_EXCEPTIONS
- #endif
-
- // All functions callable from CUDA code must be qualified with __device__
- #ifdef __CUDACC__
- // Do not try to vectorize on CUDA and SYCL!
- #ifndef EIGEN_DONT_VECTORIZE
- #define EIGEN_DONT_VECTORIZE
- #endif
-
- #define EIGEN_DEVICE_FUNC __host__ __device__
- // We need cuda_runtime.h to ensure that that EIGEN_USING_STD_MATH macro
- // works properly on the device side
- #include <cuda_runtime.h>
- #else
- #define EIGEN_DEVICE_FUNC
- #endif
-
-#else
- #define EIGEN_DEVICE_FUNC
-
-#endif
-
-// When compiling CUDA device code with NVCC, pull in math functions from the
-// global namespace. In host mode, and when device doee with clang, use the
-// std versions.
-#if defined(__CUDA_ARCH__) && defined(__NVCC__)
- #define EIGEN_USING_STD_MATH(FUNC) using ::FUNC;
-#else
- #define EIGEN_USING_STD_MATH(FUNC) using std::FUNC;
-#endif
-
-#if (defined(_CPPUNWIND) || defined(__EXCEPTIONS)) && !defined(__CUDA_ARCH__) && !defined(EIGEN_EXCEPTIONS) && !defined(EIGEN_USE_SYCL)
- #define EIGEN_EXCEPTIONS
-#endif
#ifdef EIGEN_EXCEPTIONS
#include <new>
#endif
-// then include this file where all our macros are defined. It's really important to do it first because
-// it's where we do all the alignment settings (platform detection and honoring the user's will if he
-// defined e.g. EIGEN_DONT_ALIGN) so it needs to be done before we do anything with vectorization.
-#include "src/Core/util/Macros.h"
-
// Disable the ipa-cp-clone optimization flag with MinGW 6.x or newer (enabled by default with -O3)
// See http://eigen.tuxfamily.org/bz/show_bug.cgi?id=556 for details.
-#if EIGEN_COMP_MINGW && EIGEN_GNUC_AT_LEAST(4,6)
+#if EIGEN_COMP_MINGW && EIGEN_GNUC_AT_LEAST(4,6) && EIGEN_GNUC_AT_MOST(5,5)
#pragma GCC optimize ("-fno-ipa-cp-clone")
#endif
+// Prevent ICC from specializing std::complex operators that silently fail
+// on device. This allows us to use our own device-compatible specializations
+// instead.
+#if defined(EIGEN_COMP_ICC) && defined(EIGEN_GPU_COMPILE_PHASE) \
+ && !defined(_OVERRIDE_COMPLEX_SPECIALIZATION_)
+#define _OVERRIDE_COMPLEX_SPECIALIZATION_ 1
+#endif
#include <complex>
// this include file manages BLAS and MKL related macros
// and inclusion of their respective header files
// #include "src/Core/util/MKL_support.h"
-// if alignment is disabled, then disable vectorization. Note: EIGEN_MAX_ALIGN_BYTES is the proper check, it takes into
-// account both the user's will (EIGEN_MAX_ALIGN_BYTES,EIGEN_DONT_ALIGN) and our own platform checks
-#if EIGEN_MAX_ALIGN_BYTES==0
- #ifndef EIGEN_DONT_VECTORIZE
- #define EIGEN_DONT_VECTORIZE
- #endif
+
+#if defined(EIGEN_HAS_CUDA_FP16) || defined(EIGEN_HAS_HIP_FP16)
+ #define EIGEN_HAS_GPU_FP16
#endif
-#if EIGEN_COMP_MSVC
- #include <malloc.h> // for _aligned_malloc -- need it regardless of whether vectorization is enabled
- #if (EIGEN_COMP_MSVC >= 1500) // 2008 or later
- // Remember that usage of defined() in a #define is undefined by the standard.
- // a user reported that in 64-bit mode, MSVC doesn't care to define _M_IX86_FP.
- #if (defined(_M_IX86_FP) && (_M_IX86_FP >= 2)) || EIGEN_ARCH_x86_64
- #define EIGEN_SSE2_ON_MSVC_2008_OR_LATER
- #endif
- #endif
-#else
- // Remember that usage of defined() in a #define is undefined by the standard
- #if (defined __SSE2__) && ( (!EIGEN_COMP_GNUC) || EIGEN_COMP_ICC || EIGEN_GNUC_AT_LEAST(4,2) )
- #define EIGEN_SSE2_ON_NON_MSVC_BUT_NOT_OLD_GCC
- #endif
-#endif
-
-#ifndef EIGEN_DONT_VECTORIZE
-
- #if defined (EIGEN_SSE2_ON_NON_MSVC_BUT_NOT_OLD_GCC) || defined(EIGEN_SSE2_ON_MSVC_2008_OR_LATER)
-
- // Defines symbols for compile-time detection of which instructions are
- // used.
- // EIGEN_VECTORIZE_YY is defined if and only if the instruction set YY is used
- #define EIGEN_VECTORIZE
- #define EIGEN_VECTORIZE_SSE
- #define EIGEN_VECTORIZE_SSE2
-
- // Detect sse3/ssse3/sse4:
- // gcc and icc defines __SSE3__, ...
- // there is no way to know about this on msvc. You can define EIGEN_VECTORIZE_SSE* if you
- // want to force the use of those instructions with msvc.
- #ifdef __SSE3__
- #define EIGEN_VECTORIZE_SSE3
- #endif
- #ifdef __SSSE3__
- #define EIGEN_VECTORIZE_SSSE3
- #endif
- #ifdef __SSE4_1__
- #define EIGEN_VECTORIZE_SSE4_1
- #endif
- #ifdef __SSE4_2__
- #define EIGEN_VECTORIZE_SSE4_2
- #endif
- #ifdef __AVX__
- #define EIGEN_VECTORIZE_AVX
- #define EIGEN_VECTORIZE_SSE3
- #define EIGEN_VECTORIZE_SSSE3
- #define EIGEN_VECTORIZE_SSE4_1
- #define EIGEN_VECTORIZE_SSE4_2
- #endif
- #ifdef __AVX2__
- #define EIGEN_VECTORIZE_AVX2
- #endif
- #ifdef __FMA__
- #define EIGEN_VECTORIZE_FMA
- #endif
- #if defined(__AVX512F__) && defined(EIGEN_ENABLE_AVX512)
- #define EIGEN_VECTORIZE_AVX512
- #define EIGEN_VECTORIZE_AVX2
- #define EIGEN_VECTORIZE_AVX
- #define EIGEN_VECTORIZE_FMA
- #ifdef __AVX512DQ__
- #define EIGEN_VECTORIZE_AVX512DQ
- #endif
- #ifdef __AVX512ER__
- #define EIGEN_VECTORIZE_AVX512ER
- #endif
- #endif
-
- // include files
-
- // This extern "C" works around a MINGW-w64 compilation issue
- // https://sourceforge.net/tracker/index.php?func=detail&aid=3018394&group_id=202880&atid=983354
- // In essence, intrin.h is included by windows.h and also declares intrinsics (just as emmintrin.h etc. below do).
- // However, intrin.h uses an extern "C" declaration, and g++ thus complains of duplicate declarations
- // with conflicting linkage. The linkage for intrinsics doesn't matter, but at that stage the compiler doesn't know;
- // so, to avoid compile errors when windows.h is included after Eigen/Core, ensure intrinsics are extern "C" here too.
- // notice that since these are C headers, the extern "C" is theoretically needed anyways.
- extern "C" {
- // In theory we should only include immintrin.h and not the other *mmintrin.h header files directly.
- // Doing so triggers some issues with ICC. However old gcc versions seems to not have this file, thus:
- #if EIGEN_COMP_ICC >= 1110
- #include <immintrin.h>
- #else
- #include <mmintrin.h>
- #include <emmintrin.h>
- #include <xmmintrin.h>
- #ifdef EIGEN_VECTORIZE_SSE3
- #include <pmmintrin.h>
- #endif
- #ifdef EIGEN_VECTORIZE_SSSE3
- #include <tmmintrin.h>
- #endif
- #ifdef EIGEN_VECTORIZE_SSE4_1
- #include <smmintrin.h>
- #endif
- #ifdef EIGEN_VECTORIZE_SSE4_2
- #include <nmmintrin.h>
- #endif
- #if defined(EIGEN_VECTORIZE_AVX) || defined(EIGEN_VECTORIZE_AVX512)
- #include <immintrin.h>
- #endif
- #endif
- } // end extern "C"
- #elif defined __VSX__
- #define EIGEN_VECTORIZE
- #define EIGEN_VECTORIZE_VSX
- #include <altivec.h>
- // We need to #undef all these ugly tokens defined in <altivec.h>
- // => use __vector instead of vector
- #undef bool
- #undef vector
- #undef pixel
- #elif defined __ALTIVEC__
- #define EIGEN_VECTORIZE
- #define EIGEN_VECTORIZE_ALTIVEC
- #include <altivec.h>
- // We need to #undef all these ugly tokens defined in <altivec.h>
- // => use __vector instead of vector
- #undef bool
- #undef vector
- #undef pixel
- #elif (defined __ARM_NEON) || (defined __ARM_NEON__)
- #define EIGEN_VECTORIZE
- #define EIGEN_VECTORIZE_NEON
- #include <arm_neon.h>
- #elif (defined __s390x__ && defined __VEC__)
- #define EIGEN_VECTORIZE
- #define EIGEN_VECTORIZE_ZVECTOR
- #include <vecintrin.h>
- #endif
-#endif
-
-#if defined(__F16C__) && !defined(EIGEN_COMP_CLANG)
- // We can use the optimized fp16 to float and float to fp16 conversion routines
- #define EIGEN_HAS_FP16_C
-#endif
-
-#if defined __CUDACC__
- #define EIGEN_VECTORIZE_CUDA
- #include <vector_types.h>
- #if EIGEN_CUDACC_VER >= 70500
- #define EIGEN_HAS_CUDA_FP16
- #endif
-#endif
-
-#if defined EIGEN_HAS_CUDA_FP16
- #include <host_defines.h>
- #include <cuda_fp16.h>
+#if defined(EIGEN_HAS_CUDA_BF16) || defined(EIGEN_HAS_HIP_BF16)
+ #define EIGEN_HAS_GPU_BF16
#endif
#if (defined _OPENMP) && (!defined EIGEN_DONT_PARALLELIZE)
@@ -283,7 +83,10 @@
#include <cmath>
#include <cassert>
#include <functional>
-#include <iosfwd>
+#include <sstream>
+#ifndef EIGEN_NO_IO
+ #include <iosfwd>
+#endif
#include <cstring>
#include <string>
#include <limits>
@@ -291,6 +94,10 @@
// for min/max:
#include <algorithm>
+#if EIGEN_HAS_CXX11
+#include <array>
+#endif
+
// for std::is_nothrow_move_assignable
#ifdef EIGEN_INCLUDE_TYPE_TRAITS
#include <type_traits>
@@ -306,38 +113,25 @@
#include <intrin.h>
#endif
-/** \brief Namespace containing all symbols from the %Eigen library. */
-namespace Eigen {
-
-inline static const char *SimdInstructionSetsInUse(void) {
-#if defined(EIGEN_VECTORIZE_AVX512)
- return "AVX512, FMA, AVX2, AVX, SSE, SSE2, SSE3, SSSE3, SSE4.1, SSE4.2";
-#elif defined(EIGEN_VECTORIZE_AVX)
- return "AVX SSE, SSE2, SSE3, SSSE3, SSE4.1, SSE4.2";
-#elif defined(EIGEN_VECTORIZE_SSE4_2)
- return "SSE, SSE2, SSE3, SSSE3, SSE4.1, SSE4.2";
-#elif defined(EIGEN_VECTORIZE_SSE4_1)
- return "SSE, SSE2, SSE3, SSSE3, SSE4.1";
-#elif defined(EIGEN_VECTORIZE_SSSE3)
- return "SSE, SSE2, SSE3, SSSE3";
-#elif defined(EIGEN_VECTORIZE_SSE3)
- return "SSE, SSE2, SSE3";
-#elif defined(EIGEN_VECTORIZE_SSE2)
- return "SSE, SSE2";
-#elif defined(EIGEN_VECTORIZE_ALTIVEC)
- return "AltiVec";
-#elif defined(EIGEN_VECTORIZE_VSX)
- return "VSX";
-#elif defined(EIGEN_VECTORIZE_NEON)
- return "ARM NEON";
-#elif defined(EIGEN_VECTORIZE_ZVECTOR)
- return "S390X ZVECTOR";
-#else
- return "None";
+#if defined(EIGEN_USE_SYCL)
+ #undef min
+ #undef max
+ #undef isnan
+ #undef isinf
+ #undef isfinite
+ #include <CL/sycl.hpp>
+ #include <map>
+ #include <memory>
+ #include <utility>
+ #include <thread>
+ #ifndef EIGEN_SYCL_LOCAL_THREAD_DIM0
+ #define EIGEN_SYCL_LOCAL_THREAD_DIM0 16
+ #endif
+ #ifndef EIGEN_SYCL_LOCAL_THREAD_DIM1
+ #define EIGEN_SYCL_LOCAL_THREAD_DIM1 16
+ #endif
#endif
-}
-} // end namespace Eigen
#if defined EIGEN2_SUPPORT_STAGE40_FULL_EIGEN3_STRICTNESS || defined EIGEN2_SUPPORT_STAGE30_FULL_EIGEN3_API || defined EIGEN2_SUPPORT_STAGE20_RESOLVE_API_CONFLICTS || defined EIGEN2_SUPPORT_STAGE10_FULL_EIGEN2_API || defined EIGEN2_SUPPORT
// This will generate an error message:
@@ -346,7 +140,7 @@
namespace Eigen {
-// we use size_t frequently and we'll never remember to prepend it with std:: everytime just to
+// we use size_t frequently and we'll never remember to prepend it with std:: every time just to
// ensure QNX/QCC support
using std::size_t;
// gcc 4.6.0 wants std:: for ptrdiff_t
@@ -370,58 +164,90 @@
#include "src/Core/util/StaticAssert.h"
#include "src/Core/util/XprHelper.h"
#include "src/Core/util/Memory.h"
+#include "src/Core/util/IntegralConstant.h"
+#include "src/Core/util/SymbolicIndex.h"
#include "src/Core/NumTraits.h"
#include "src/Core/MathFunctions.h"
#include "src/Core/GenericPacketMath.h"
#include "src/Core/MathFunctionsImpl.h"
#include "src/Core/arch/Default/ConjHelper.h"
+// Generic half float support
+#include "src/Core/arch/Default/Half.h"
+#include "src/Core/arch/Default/BFloat16.h"
+#include "src/Core/arch/Default/TypeCasting.h"
+#include "src/Core/arch/Default/GenericPacketMathFunctionsFwd.h"
#if defined EIGEN_VECTORIZE_AVX512
#include "src/Core/arch/SSE/PacketMath.h"
+ #include "src/Core/arch/SSE/TypeCasting.h"
+ #include "src/Core/arch/SSE/Complex.h"
#include "src/Core/arch/AVX/PacketMath.h"
- #include "src/Core/arch/AVX512/PacketMath.h"
- #include "src/Core/arch/AVX512/MathFunctions.h"
+ #include "src/Core/arch/AVX/TypeCasting.h"
+ #include "src/Core/arch/AVX/Complex.h"
+ // #include "src/Core/arch/AVX512/PacketMath.h"
+ // #include "src/Core/arch/AVX512/TypeCasting.h"
+ // #include "src/Core/arch/AVX512/Complex.h"
+ #include "src/Core/arch/SSE/MathFunctions.h"
+ #include "src/Core/arch/AVX/MathFunctions.h"
+ // #include "src/Core/arch/AVX512/MathFunctions.h"
#elif defined EIGEN_VECTORIZE_AVX
// Use AVX for floats and doubles, SSE for integers
#include "src/Core/arch/SSE/PacketMath.h"
- #include "src/Core/arch/SSE/Complex.h"
- #include "src/Core/arch/SSE/MathFunctions.h"
- #include "src/Core/arch/AVX/PacketMath.h"
- #include "src/Core/arch/AVX/MathFunctions.h"
- #include "src/Core/arch/AVX/Complex.h"
- #include "src/Core/arch/AVX/TypeCasting.h"
#include "src/Core/arch/SSE/TypeCasting.h"
+ #include "src/Core/arch/SSE/Complex.h"
+ #include "src/Core/arch/AVX/PacketMath.h"
+ #include "src/Core/arch/AVX/TypeCasting.h"
+ #include "src/Core/arch/AVX/Complex.h"
+ #include "src/Core/arch/SSE/MathFunctions.h"
+ #include "src/Core/arch/AVX/MathFunctions.h"
#elif defined EIGEN_VECTORIZE_SSE
#include "src/Core/arch/SSE/PacketMath.h"
+ #include "src/Core/arch/SSE/TypeCasting.h"
#include "src/Core/arch/SSE/MathFunctions.h"
#include "src/Core/arch/SSE/Complex.h"
- #include "src/Core/arch/SSE/TypeCasting.h"
#elif defined(EIGEN_VECTORIZE_ALTIVEC) || defined(EIGEN_VECTORIZE_VSX)
- #include "src/Core/arch/AltiVec/PacketMath.h"
- #include "src/Core/arch/AltiVec/MathFunctions.h"
- #include "src/Core/arch/AltiVec/Complex.h"
+ // #include "src/Core/arch/AltiVec/PacketMath.h"
+ // #include "src/Core/arch/AltiVec/MathFunctions.h"
+ // #include "src/Core/arch/AltiVec/Complex.h"
#elif defined EIGEN_VECTORIZE_NEON
#include "src/Core/arch/NEON/PacketMath.h"
+ #include "src/Core/arch/NEON/TypeCasting.h"
#include "src/Core/arch/NEON/MathFunctions.h"
#include "src/Core/arch/NEON/Complex.h"
+#elif defined EIGEN_VECTORIZE_SVE
+ // #include "src/Core/arch/SVE/PacketMath.h"
+ // #include "src/Core/arch/SVE/TypeCasting.h"
+ // #include "src/Core/arch/SVE/MathFunctions.h"
#elif defined EIGEN_VECTORIZE_ZVECTOR
- #include "src/Core/arch/ZVector/PacketMath.h"
- #include "src/Core/arch/ZVector/MathFunctions.h"
- #include "src/Core/arch/ZVector/Complex.h"
+ // #include "src/Core/arch/ZVector/PacketMath.h"
+ // #include "src/Core/arch/ZVector/MathFunctions.h"
+ // #include "src/Core/arch/ZVector/Complex.h"
+#elif defined EIGEN_VECTORIZE_MSA
+ // #include "src/Core/arch/MSA/PacketMath.h"
+ // #include "src/Core/arch/MSA/MathFunctions.h"
+ // #include "src/Core/arch/MSA/Complex.h"
#endif
-// Half float support
-// #include "src/Core/arch/CUDA/Half.h"
-// #include "src/Core/arch/CUDA/PacketMathHalf.h"
-// #include "src/Core/arch/CUDA/TypeCasting.h"
+#if defined EIGEN_VECTORIZE_GPU
+ // #include "src/Core/arch/GPU/PacketMath.h"
+ // #include "src/Core/arch/GPU/MathFunctions.h"
+ // #include "src/Core/arch/GPU/TypeCasting.h"
+#endif
-#if defined EIGEN_VECTORIZE_CUDA
- #include "src/Core/arch/CUDA/PacketMath.h"
- #include "src/Core/arch/CUDA/MathFunctions.h"
+#if defined(EIGEN_USE_SYCL)
+ // #include "src/Core/arch/SYCL/SyclMemoryModel.h"
+ // #include "src/Core/arch/SYCL/InteropHeaders.h"
+#if !defined(EIGEN_DONT_VECTORIZE_SYCL)
+ // #include "src/Core/arch/SYCL/PacketMath.h"
+ // #include "src/Core/arch/SYCL/MathFunctions.h"
+ // #include "src/Core/arch/SYCL/TypeCasting.h"
+#endif
#endif
#include "src/Core/arch/Default/Settings.h"
+// This file provides generic implementations valid for scalar as well
+#include "src/Core/arch/Default/GenericPacketMathFunctions.h"
#include "src/Core/functors/TernaryFunctors.h"
#include "src/Core/functors/BinaryFunctors.h"
@@ -432,9 +258,16 @@
// Specialized functors to enable the processing of complex numbers
// on CUDA devices
+#ifdef EIGEN_CUDACC
// #include "src/Core/arch/CUDA/Complex.h"
+#endif
-#include "src/Core/IO.h"
+#include "src/Core/util/IndexedViewHelper.h"
+#include "src/Core/util/ReshapedHelper.h"
+#include "src/Core/ArithmeticSequence.h"
+#ifndef EIGEN_NO_IO
+ #include "src/Core/IO.h"
+#endif
#include "src/Core/DenseCoeffsBase.h"
#include "src/Core/DenseBase.h"
#include "src/Core/MatrixBase.h"
@@ -475,6 +308,8 @@
#include "src/Core/Ref.h"
#include "src/Core/Block.h"
#include "src/Core/VectorBlock.h"
+#include "src/Core/IndexedView.h"
+#include "src/Core/Reshaped.h"
#include "src/Core/Transpose.h"
#include "src/Core/DiagonalMatrix.h"
#include "src/Core/Diagonal.h"
@@ -511,27 +346,35 @@
#include "src/Core/CoreIterators.h"
#include "src/Core/ConditionEstimator.h"
+#if defined(EIGEN_VECTORIZE_ALTIVEC) || defined(EIGEN_VECTORIZE_VSX)
+ // #include "src/Core/arch/AltiVec/MatrixProduct.h"
+#elif defined EIGEN_VECTORIZE_NEON
+ #include "src/Core/arch/NEON/GeneralBlockPanelKernel.h"
+#endif
+
#include "src/Core/BooleanRedux.h"
#include "src/Core/Select.h"
#include "src/Core/VectorwiseOp.h"
+#include "src/Core/PartialReduxEvaluator.h"
#include "src/Core/Random.h"
#include "src/Core/Replicate.h"
#include "src/Core/Reverse.h"
#include "src/Core/ArrayWrapper.h"
+#include "src/Core/StlIterators.h"
#ifdef EIGEN_USE_BLAS
-#include "src/Core/products/GeneralMatrixMatrix_BLAS.h"
-#include "src/Core/products/GeneralMatrixVector_BLAS.h"
-#include "src/Core/products/GeneralMatrixMatrixTriangular_BLAS.h"
-#include "src/Core/products/SelfadjointMatrixMatrix_BLAS.h"
-#include "src/Core/products/SelfadjointMatrixVector_BLAS.h"
-#include "src/Core/products/TriangularMatrixMatrix_BLAS.h"
-#include "src/Core/products/TriangularMatrixVector_BLAS.h"
-#include "src/Core/products/TriangularSolverMatrix_BLAS.h"
+// #include "src/Core/products/GeneralMatrixMatrix_BLAS.h"
+// #include "src/Core/products/GeneralMatrixVector_BLAS.h"
+// #include "src/Core/products/GeneralMatrixMatrixTriangular_BLAS.h"
+// #include "src/Core/products/SelfadjointMatrixMatrix_BLAS.h"
+// #include "src/Core/products/SelfadjointMatrixVector_BLAS.h"
+// #include "src/Core/products/TriangularMatrixMatrix_BLAS.h"
+// #include "src/Core/products/TriangularMatrixVector_BLAS.h"
+// #include "src/Core/products/TriangularSolverMatrix_BLAS.h"
#endif // EIGEN_USE_BLAS
#ifdef EIGEN_USE_MKL_VML
-#include "src/Core/Assign_MKL.h"
+// #include "src/Core/Assign_MKL.h"
#endif
#include "src/Core/GlobalFunctions.h"
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/Eigen b/wpimath/src/main/native/eigeninclude/Eigen/Eigen
deleted file mode 100644
index 654c8dc..0000000
--- a/wpimath/src/main/native/eigeninclude/Eigen/Eigen
+++ /dev/null
@@ -1,2 +0,0 @@
-#include "Dense"
-#include "Sparse"
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/Eigenvalues b/wpimath/src/main/native/eigeninclude/Eigen/Eigenvalues
index 1ad6bcf..c6defe3 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/Eigenvalues
+++ b/wpimath/src/main/native/eigeninclude/Eigen/Eigenvalues
@@ -10,12 +10,13 @@
#include "Core"
-#include "src/Core/util/DisableStupidWarnings.h"
-
#include "Cholesky"
#include "Jacobi"
#include "Householder"
#include "LU"
+// #include "Geometry"
+
+#include "src/Core/util/DisableStupidWarnings.h"
/** \defgroup Eigenvalues_Module Eigenvalues module
*
@@ -45,16 +46,15 @@
#include "src/Eigenvalues/MatrixBaseEigenvalues.h"
#ifdef EIGEN_USE_LAPACKE
#ifdef EIGEN_USE_MKL
-#include "mkl_lapacke.h"
+// #include "mkl_lapacke.h"
#else
-#include "src/misc/lapacke.h"
+// #include "src/misc/lapacke.h"
#endif
-#include "src/Eigenvalues/RealSchur_LAPACKE.h"
-#include "src/Eigenvalues/ComplexSchur_LAPACKE.h"
-#include "src/Eigenvalues/SelfAdjointEigenSolver_LAPACKE.h"
+// #include "src/Eigenvalues/RealSchur_LAPACKE.h"
+// #include "src/Eigenvalues/ComplexSchur_LAPACKE.h"
+// #include "src/Eigenvalues/SelfAdjointEigenSolver_LAPACKE.h"
#endif
#include "src/Core/util/ReenableStupidWarnings.h"
#endif // EIGEN_EIGENVALUES_MODULE_H
-/* vim: set filetype=cpp et sw=2 ts=2 ai: */
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/Householder b/wpimath/src/main/native/eigeninclude/Eigen/Householder
index 89cd81b..f2fa799 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/Householder
+++ b/wpimath/src/main/native/eigeninclude/Eigen/Householder
@@ -27,4 +27,3 @@
#include "src/Core/util/ReenableStupidWarnings.h"
#endif // EIGEN_HOUSEHOLDER_MODULE_H
-/* vim: set filetype=cpp et sw=2 ts=2 ai: */
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/Jacobi b/wpimath/src/main/native/eigeninclude/Eigen/Jacobi
index 17c1d78..43edc7a 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/Jacobi
+++ b/wpimath/src/main/native/eigeninclude/Eigen/Jacobi
@@ -29,5 +29,4 @@
#include "src/Core/util/ReenableStupidWarnings.h"
#endif // EIGEN_JACOBI_MODULE_H
-/* vim: set filetype=cpp et sw=2 ts=2 ai: */
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/LU b/wpimath/src/main/native/eigeninclude/Eigen/LU
index 6418a86..a1b5d46 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/LU
+++ b/wpimath/src/main/native/eigeninclude/Eigen/LU
@@ -29,22 +29,19 @@
#include "src/LU/PartialPivLU.h"
#ifdef EIGEN_USE_LAPACKE
#ifdef EIGEN_USE_MKL
-#include "mkl_lapacke.h"
+// #include "mkl_lapacke.h"
#else
-#include "src/misc/lapacke.h"
+// #include "src/misc/lapacke.h"
#endif
-#include "src/LU/PartialPivLU_LAPACKE.h"
+// #include "src/LU/PartialPivLU_LAPACKE.h"
#endif
#include "src/LU/Determinant.h"
#include "src/LU/InverseImpl.h"
-// Use the SSE optimized version whenever possible. At the moment the
-// SSE version doesn't compile when AVX is enabled
-#if defined EIGEN_VECTORIZE_SSE && !defined EIGEN_VECTORIZE_AVX
- #include "src/LU/arch/Inverse_SSE.h"
+#if defined EIGEN_VECTORIZE_SSE || defined EIGEN_VECTORIZE_NEON
+ #include "src/LU/arch/InverseSize4.h"
#endif
#include "src/Core/util/ReenableStupidWarnings.h"
#endif // EIGEN_LU_MODULE_H
-/* vim: set filetype=cpp et sw=2 ts=2 ai: */
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/QR b/wpimath/src/main/native/eigeninclude/Eigen/QR
index c7e9144..42a3fa8 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/QR
+++ b/wpimath/src/main/native/eigeninclude/Eigen/QR
@@ -10,12 +10,12 @@
#include "Core"
-#include "src/Core/util/DisableStupidWarnings.h"
-
#include "Cholesky"
#include "Jacobi"
#include "Householder"
+#include "src/Core/util/DisableStupidWarnings.h"
+
/** \defgroup QR_Module QR module
*
*
@@ -37,15 +37,14 @@
#include "src/QR/CompleteOrthogonalDecomposition.h"
#ifdef EIGEN_USE_LAPACKE
#ifdef EIGEN_USE_MKL
-#include "mkl_lapacke.h"
+// #include "mkl_lapacke.h"
#else
-#include "src/misc/lapacke.h"
+// #include "src/misc/lapacke.h"
#endif
-#include "src/QR/HouseholderQR_LAPACKE.h"
-#include "src/QR/ColPivHouseholderQR_LAPACKE.h"
+// #include "src/QR/HouseholderQR_LAPACKE.h"
+// #include "src/QR/ColPivHouseholderQR_LAPACKE.h"
#endif
#include "src/Core/util/ReenableStupidWarnings.h"
#endif // EIGEN_QR_MODULE_H
-/* vim: set filetype=cpp et sw=2 ts=2 ai: */
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/SVD b/wpimath/src/main/native/eigeninclude/Eigen/SVD
index 5d0e75f..4441a38 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/SVD
+++ b/wpimath/src/main/native/eigeninclude/Eigen/SVD
@@ -38,14 +38,13 @@
#include "src/SVD/BDCSVD.h"
#if defined(EIGEN_USE_LAPACKE) && !defined(EIGEN_USE_LAPACKE_STRICT)
#ifdef EIGEN_USE_MKL
-#include "mkl_lapacke.h"
+// #include "mkl_lapacke.h"
#else
-#include "src/misc/lapacke.h"
+// #include "src/misc/lapacke.h"
#endif
-#include "src/SVD/JacobiSVD_LAPACKE.h"
+// #include "src/SVD/JacobiSVD_LAPACKE.h"
#endif
#include "src/Core/util/ReenableStupidWarnings.h"
#endif // EIGEN_SVD_MODULE_H
-/* vim: set filetype=cpp et sw=2 ts=2 ai: */
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Cholesky/LDLT.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Cholesky/LDLT.h
index 15ccf24..1013ca0 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Cholesky/LDLT.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Cholesky/LDLT.h
@@ -16,6 +16,15 @@
namespace Eigen {
namespace internal {
+ template<typename _MatrixType, int _UpLo> struct traits<LDLT<_MatrixType, _UpLo> >
+ : traits<_MatrixType>
+ {
+ typedef MatrixXpr XprKind;
+ typedef SolverStorage StorageKind;
+ typedef int StorageIndex;
+ enum { Flags = 0 };
+ };
+
template<typename MatrixType, int UpLo> struct LDLT_Traits;
// PositiveSemiDef means positive semi-definite and non-zero; same for NegativeSemiDef
@@ -36,7 +45,7 @@
* matrix \f$ A \f$ such that \f$ A = P^TLDL^*P \f$, where P is a permutation matrix, L
* is lower triangular with a unit diagonal and D is a diagonal matrix.
*
- * The decomposition uses pivoting to ensure stability, so that L will have
+ * The decomposition uses pivoting to ensure stability, so that D will have
* zeros in the bottom right rank(A) - n submatrix. Avoiding the square root
* on D also stabilizes the computation.
*
@@ -44,24 +53,23 @@
* decomposition to determine whether a system of equations has a solution.
*
* This class supports the \link InplaceDecomposition inplace decomposition \endlink mechanism.
- *
+ *
* \sa MatrixBase::ldlt(), SelfAdjointView::ldlt(), class LLT
*/
template<typename _MatrixType, int _UpLo> class LDLT
+ : public SolverBase<LDLT<_MatrixType, _UpLo> >
{
public:
typedef _MatrixType MatrixType;
+ typedef SolverBase<LDLT> Base;
+ friend class SolverBase<LDLT>;
+
+ EIGEN_GENERIC_PUBLIC_INTERFACE(LDLT)
enum {
- RowsAtCompileTime = MatrixType::RowsAtCompileTime,
- ColsAtCompileTime = MatrixType::ColsAtCompileTime,
MaxRowsAtCompileTime = MatrixType::MaxRowsAtCompileTime,
MaxColsAtCompileTime = MatrixType::MaxColsAtCompileTime,
UpLo = _UpLo
};
- typedef typename MatrixType::Scalar Scalar;
- typedef typename NumTraits<typename MatrixType::Scalar>::Real RealScalar;
- typedef Eigen::Index Index; ///< \deprecated since Eigen 3.3
- typedef typename MatrixType::StorageIndex StorageIndex;
typedef Matrix<Scalar, RowsAtCompileTime, 1, 0, MaxRowsAtCompileTime, 1> TmpMatrixType;
typedef Transpositions<RowsAtCompileTime, MaxRowsAtCompileTime> TranspositionType;
@@ -180,6 +188,7 @@
return m_sign == internal::NegativeSemiDef || m_sign == internal::ZeroSign;
}
+ #ifdef EIGEN_PARSED_BY_DOXYGEN
/** \returns a solution x of \f$ A x = b \f$ using the current decomposition of A.
*
* This function also supports in-place solves using the syntax <tt>x = decompositionObject.solve(x)</tt> .
@@ -191,19 +200,14 @@
* \f$ L^* y_4 = y_3 \f$ and \f$ P x = y_4 \f$ in succession. If the matrix \f$ A \f$ is singular, then
* \f$ D \f$ will also be singular (all the other matrices are invertible). In that case, the
* least-square solution of \f$ D y_3 = y_2 \f$ is computed. This does not mean that this function
- * computes the least-square solution of \f$ A x = b \f$ is \f$ A \f$ is singular.
+ * computes the least-square solution of \f$ A x = b \f$ if \f$ A \f$ is singular.
*
* \sa MatrixBase::ldlt(), SelfAdjointView::ldlt()
*/
template<typename Rhs>
inline const Solve<LDLT, Rhs>
- solve(const MatrixBase<Rhs>& b) const
- {
- eigen_assert(m_isInitialized && "LDLT is not initialized.");
- eigen_assert(m_matrix.rows()==b.rows()
- && "LDLT::solve(): invalid number of rows of the right hand side matrix b");
- return Solve<LDLT, Rhs>(*this, b.derived());
- }
+ solve(const MatrixBase<Rhs>& b) const;
+ #endif
template<typename Derived>
bool solveInPlace(MatrixBase<Derived> &bAndX) const;
@@ -242,12 +246,12 @@
*/
const LDLT& adjoint() const { return *this; };
- inline Index rows() const { return m_matrix.rows(); }
- inline Index cols() const { return m_matrix.cols(); }
+ EIGEN_DEVICE_FUNC inline EIGEN_CONSTEXPR Index rows() const EIGEN_NOEXCEPT { return m_matrix.rows(); }
+ EIGEN_DEVICE_FUNC inline EIGEN_CONSTEXPR Index cols() const EIGEN_NOEXCEPT { return m_matrix.cols(); }
/** \brief Reports whether previous computation was successful.
*
- * \returns \c Success if computation was succesful,
+ * \returns \c Success if computation was successful,
* \c NumericalIssue if the factorization failed because of a zero pivot.
*/
ComputationInfo info() const
@@ -258,8 +262,10 @@
#ifndef EIGEN_PARSED_BY_DOXYGEN
template<typename RhsType, typename DstType>
- EIGEN_DEVICE_FUNC
void _solve_impl(const RhsType &rhs, DstType &dst) const;
+
+ template<bool Conjugate, typename RhsType, typename DstType>
+ void _solve_impl_transposed(const RhsType &rhs, DstType &dst) const;
#endif
protected:
@@ -560,14 +566,22 @@
template<typename RhsType, typename DstType>
void LDLT<_MatrixType,_UpLo>::_solve_impl(const RhsType &rhs, DstType &dst) const
{
- eigen_assert(rhs.rows() == rows());
+ _solve_impl_transposed<true>(rhs, dst);
+}
+
+template<typename _MatrixType,int _UpLo>
+template<bool Conjugate, typename RhsType, typename DstType>
+void LDLT<_MatrixType,_UpLo>::_solve_impl_transposed(const RhsType &rhs, DstType &dst) const
+{
// dst = P b
dst = m_transpositions * rhs;
// dst = L^-1 (P b)
- matrixL().solveInPlace(dst);
+ // dst = L^-*T (P b)
+ matrixL().template conjugateIf<!Conjugate>().solveInPlace(dst);
- // dst = D^-1 (L^-1 P b)
+ // dst = D^-* (L^-1 P b)
+ // dst = D^-1 (L^-*T P b)
// more precisely, use pseudo-inverse of D (see bug 241)
using std::abs;
const typename Diagonal<const MatrixType>::RealReturnType vecD(vectorD());
@@ -579,7 +593,6 @@
// Moreover, Lapack's xSYTRS routines use 0 for the tolerance.
// Using numeric_limits::min() gives us more robustness to denormals.
RealScalar tolerance = (std::numeric_limits<RealScalar>::min)();
-
for (Index i = 0; i < vecD.size(); ++i)
{
if(abs(vecD(i)) > tolerance)
@@ -588,10 +601,12 @@
dst.row(i).setZero();
}
- // dst = L^-T (D^-1 L^-1 P b)
- matrixU().solveInPlace(dst);
+ // dst = L^-* (D^-* L^-1 P b)
+ // dst = L^-T (D^-1 L^-*T P b)
+ matrixL().transpose().template conjugateIf<Conjugate>().solveInPlace(dst);
- // dst = P^-1 (L^-T D^-1 L^-1 P b) = A^-1 b
+ // dst = P^T (L^-* D^-* L^-1 P b) = A^-1 b
+ // dst = P^-T (L^-T D^-1 L^-*T P b) = A^-1 b
dst = m_transpositions.transpose() * dst;
}
#endif
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Cholesky/LLT.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Cholesky/LLT.h
index e1624d2..8c9b2b3 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Cholesky/LLT.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Cholesky/LLT.h
@@ -13,6 +13,16 @@
namespace Eigen {
namespace internal{
+
+template<typename _MatrixType, int _UpLo> struct traits<LLT<_MatrixType, _UpLo> >
+ : traits<_MatrixType>
+{
+ typedef MatrixXpr XprKind;
+ typedef SolverStorage StorageKind;
+ typedef int StorageIndex;
+ enum { Flags = 0 };
+};
+
template<typename MatrixType, int UpLo> struct LLT_Traits;
}
@@ -54,18 +64,17 @@
* \sa MatrixBase::llt(), SelfAdjointView::llt(), class LDLT
*/
template<typename _MatrixType, int _UpLo> class LLT
+ : public SolverBase<LLT<_MatrixType, _UpLo> >
{
public:
typedef _MatrixType MatrixType;
+ typedef SolverBase<LLT> Base;
+ friend class SolverBase<LLT>;
+
+ EIGEN_GENERIC_PUBLIC_INTERFACE(LLT)
enum {
- RowsAtCompileTime = MatrixType::RowsAtCompileTime,
- ColsAtCompileTime = MatrixType::ColsAtCompileTime,
MaxColsAtCompileTime = MatrixType::MaxColsAtCompileTime
};
- typedef typename MatrixType::Scalar Scalar;
- typedef typename NumTraits<typename MatrixType::Scalar>::Real RealScalar;
- typedef Eigen::Index Index; ///< \deprecated since Eigen 3.3
- typedef typename MatrixType::StorageIndex StorageIndex;
enum {
PacketSize = internal::packet_traits<Scalar>::size,
@@ -100,7 +109,7 @@
compute(matrix.derived());
}
- /** \brief Constructs a LDLT factorization from a given matrix
+ /** \brief Constructs a LLT factorization from a given matrix
*
* This overloaded constructor is provided for \link InplaceDecomposition inplace decomposition \endlink when
* \c MatrixType is a Eigen::Ref.
@@ -129,6 +138,7 @@
return Traits::getL(m_matrix);
}
+ #ifdef EIGEN_PARSED_BY_DOXYGEN
/** \returns the solution x of \f$ A x = b \f$ using the current decomposition of A.
*
* Since this LLT class assumes anyway that the matrix A is invertible, the solution
@@ -141,13 +151,8 @@
*/
template<typename Rhs>
inline const Solve<LLT, Rhs>
- solve(const MatrixBase<Rhs>& b) const
- {
- eigen_assert(m_isInitialized && "LLT is not initialized.");
- eigen_assert(m_matrix.rows()==b.rows()
- && "LLT::solve(): invalid number of rows of the right hand side matrix b");
- return Solve<LLT, Rhs>(*this, b.derived());
- }
+ solve(const MatrixBase<Rhs>& b) const;
+ #endif
template<typename Derived>
void solveInPlace(const MatrixBase<Derived> &bAndX) const;
@@ -180,7 +185,7 @@
/** \brief Reports whether previous computation was successful.
*
- * \returns \c Success if computation was succesful,
+ * \returns \c Success if computation was successful,
* \c NumericalIssue if the matrix.appears not to be positive definite.
*/
ComputationInfo info() const
@@ -194,18 +199,20 @@
* This method is provided for compatibility with other matrix decompositions, thus enabling generic code such as:
* \code x = decomposition.adjoint().solve(b) \endcode
*/
- const LLT& adjoint() const { return *this; };
+ const LLT& adjoint() const EIGEN_NOEXCEPT { return *this; };
- inline Index rows() const { return m_matrix.rows(); }
- inline Index cols() const { return m_matrix.cols(); }
+ inline EIGEN_CONSTEXPR Index rows() const EIGEN_NOEXCEPT { return m_matrix.rows(); }
+ inline EIGEN_CONSTEXPR Index cols() const EIGEN_NOEXCEPT { return m_matrix.cols(); }
template<typename VectorType>
- LLT rankUpdate(const VectorType& vec, const RealScalar& sigma = 1);
+ LLT & rankUpdate(const VectorType& vec, const RealScalar& sigma = 1);
#ifndef EIGEN_PARSED_BY_DOXYGEN
template<typename RhsType, typename DstType>
- EIGEN_DEVICE_FUNC
void _solve_impl(const RhsType &rhs, DstType &dst) const;
+
+ template<bool Conjugate, typename RhsType, typename DstType>
+ void _solve_impl_transposed(const RhsType &rhs, DstType &dst) const;
#endif
protected:
@@ -459,7 +466,7 @@
*/
template<typename _MatrixType, int _UpLo>
template<typename VectorType>
-LLT<_MatrixType,_UpLo> LLT<_MatrixType,_UpLo>::rankUpdate(const VectorType& v, const RealScalar& sigma)
+LLT<_MatrixType,_UpLo> & LLT<_MatrixType,_UpLo>::rankUpdate(const VectorType& v, const RealScalar& sigma)
{
EIGEN_STATIC_ASSERT_VECTOR_ONLY(VectorType);
eigen_assert(v.size()==m_matrix.cols());
@@ -477,8 +484,17 @@
template<typename RhsType, typename DstType>
void LLT<_MatrixType,_UpLo>::_solve_impl(const RhsType &rhs, DstType &dst) const
{
- dst = rhs;
- solveInPlace(dst);
+ _solve_impl_transposed<true>(rhs, dst);
+}
+
+template<typename _MatrixType,int _UpLo>
+template<bool Conjugate, typename RhsType, typename DstType>
+void LLT<_MatrixType,_UpLo>::_solve_impl_transposed(const RhsType &rhs, DstType &dst) const
+{
+ dst = rhs;
+
+ matrixL().template conjugateIf<!Conjugate>().solveInPlace(dst);
+ matrixU().template conjugateIf<!Conjugate>().solveInPlace(dst);
}
#endif
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/ArithmeticSequence.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/ArithmeticSequence.h
new file mode 100644
index 0000000..b6200fa
--- /dev/null
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/ArithmeticSequence.h
@@ -0,0 +1,413 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2017 Gael Guennebaud <gael.guennebaud@inria.fr>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_ARITHMETIC_SEQUENCE_H
+#define EIGEN_ARITHMETIC_SEQUENCE_H
+
+namespace Eigen {
+
+namespace internal {
+
+#if (!EIGEN_HAS_CXX11) || !((!EIGEN_COMP_GNUC) || EIGEN_COMP_GNUC>=48)
+template<typename T> struct aseq_negate {};
+
+template<> struct aseq_negate<Index> {
+ typedef Index type;
+};
+
+template<int N> struct aseq_negate<FixedInt<N> > {
+ typedef FixedInt<-N> type;
+};
+
+// Compilation error in the following case:
+template<> struct aseq_negate<FixedInt<DynamicIndex> > {};
+
+template<typename FirstType,typename SizeType,typename IncrType,
+ bool FirstIsSymbolic=symbolic::is_symbolic<FirstType>::value,
+ bool SizeIsSymbolic =symbolic::is_symbolic<SizeType>::value>
+struct aseq_reverse_first_type {
+ typedef Index type;
+};
+
+template<typename FirstType,typename SizeType,typename IncrType>
+struct aseq_reverse_first_type<FirstType,SizeType,IncrType,true,true> {
+ typedef symbolic::AddExpr<FirstType,
+ symbolic::ProductExpr<symbolic::AddExpr<SizeType,symbolic::ValueExpr<FixedInt<-1> > >,
+ symbolic::ValueExpr<IncrType> >
+ > type;
+};
+
+template<typename SizeType,typename IncrType,typename EnableIf = void>
+struct aseq_reverse_first_type_aux {
+ typedef Index type;
+};
+
+template<typename SizeType,typename IncrType>
+struct aseq_reverse_first_type_aux<SizeType,IncrType,typename internal::enable_if<bool((SizeType::value+IncrType::value)|0x1)>::type> {
+ typedef FixedInt<(SizeType::value-1)*IncrType::value> type;
+};
+
+template<typename FirstType,typename SizeType,typename IncrType>
+struct aseq_reverse_first_type<FirstType,SizeType,IncrType,true,false> {
+ typedef typename aseq_reverse_first_type_aux<SizeType,IncrType>::type Aux;
+ typedef symbolic::AddExpr<FirstType,symbolic::ValueExpr<Aux> > type;
+};
+
+template<typename FirstType,typename SizeType,typename IncrType>
+struct aseq_reverse_first_type<FirstType,SizeType,IncrType,false,true> {
+ typedef symbolic::AddExpr<symbolic::ProductExpr<symbolic::AddExpr<SizeType,symbolic::ValueExpr<FixedInt<-1> > >,
+ symbolic::ValueExpr<IncrType> >,
+ symbolic::ValueExpr<> > type;
+};
+#endif
+
+// Helper to cleanup the type of the increment:
+template<typename T> struct cleanup_seq_incr {
+ typedef typename cleanup_index_type<T,DynamicIndex>::type type;
+};
+
+}
+
+//--------------------------------------------------------------------------------
+// seq(first,last,incr) and seqN(first,size,incr)
+//--------------------------------------------------------------------------------
+
+template<typename FirstType=Index,typename SizeType=Index,typename IncrType=internal::FixedInt<1> >
+class ArithmeticSequence;
+
+template<typename FirstType,typename SizeType,typename IncrType>
+ArithmeticSequence<typename internal::cleanup_index_type<FirstType>::type,
+ typename internal::cleanup_index_type<SizeType>::type,
+ typename internal::cleanup_seq_incr<IncrType>::type >
+seqN(FirstType first, SizeType size, IncrType incr);
+
+/** \class ArithmeticSequence
+ * \ingroup Core_Module
+ *
+ * This class represents an arithmetic progression \f$ a_0, a_1, a_2, ..., a_{n-1}\f$ defined by
+ * its \em first value \f$ a_0 \f$, its \em size (aka length) \em n, and the \em increment (aka stride)
+ * that is equal to \f$ a_{i+1}-a_{i}\f$ for any \em i.
+ *
+ * It is internally used as the return type of the Eigen::seq and Eigen::seqN functions, and as the input arguments
+ * of DenseBase::operator()(const RowIndices&, const ColIndices&), and most of the time this is the
+ * only way it is used.
+ *
+ * \tparam FirstType type of the first element, usually an Index,
+ * but internally it can be a symbolic expression
+ * \tparam SizeType type representing the size of the sequence, usually an Index
+ * or a compile time integral constant. Internally, it can also be a symbolic expression
+ * \tparam IncrType type of the increment, can be a runtime Index, or a compile time integral constant (default is compile-time 1)
+ *
+ * \sa Eigen::seq, Eigen::seqN, DenseBase::operator()(const RowIndices&, const ColIndices&), class IndexedView
+ */
+template<typename FirstType,typename SizeType,typename IncrType>
+class ArithmeticSequence
+{
+public:
+ ArithmeticSequence(FirstType first, SizeType size) : m_first(first), m_size(size) {}
+ ArithmeticSequence(FirstType first, SizeType size, IncrType incr) : m_first(first), m_size(size), m_incr(incr) {}
+
+ enum {
+ SizeAtCompileTime = internal::get_fixed_value<SizeType>::value,
+ IncrAtCompileTime = internal::get_fixed_value<IncrType,DynamicIndex>::value
+ };
+
+ /** \returns the size, i.e., number of elements, of the sequence */
+ Index size() const { return m_size; }
+
+ /** \returns the first element \f$ a_0 \f$ in the sequence */
+ Index first() const { return m_first; }
+
+ /** \returns the value \f$ a_i \f$ at index \a i in the sequence. */
+ Index operator[](Index i) const { return m_first + i * m_incr; }
+
+ const FirstType& firstObject() const { return m_first; }
+ const SizeType& sizeObject() const { return m_size; }
+ const IncrType& incrObject() const { return m_incr; }
+
+protected:
+ FirstType m_first;
+ SizeType m_size;
+ IncrType m_incr;
+
+public:
+
+#if EIGEN_HAS_CXX11 && ((!EIGEN_COMP_GNUC) || EIGEN_COMP_GNUC>=48)
+ auto reverse() const -> decltype(Eigen::seqN(m_first+(m_size+fix<-1>())*m_incr,m_size,-m_incr)) {
+ return seqN(m_first+(m_size+fix<-1>())*m_incr,m_size,-m_incr);
+ }
+#else
+protected:
+ typedef typename internal::aseq_negate<IncrType>::type ReverseIncrType;
+ typedef typename internal::aseq_reverse_first_type<FirstType,SizeType,IncrType>::type ReverseFirstType;
+public:
+ ArithmeticSequence<ReverseFirstType,SizeType,ReverseIncrType>
+ reverse() const {
+ return seqN(m_first+(m_size+fix<-1>())*m_incr,m_size,-m_incr);
+ }
+#endif
+};
+
+/** \returns an ArithmeticSequence starting at \a first, of length \a size, and increment \a incr
+ *
+ * \sa seqN(FirstType,SizeType), seq(FirstType,LastType,IncrType) */
+template<typename FirstType,typename SizeType,typename IncrType>
+ArithmeticSequence<typename internal::cleanup_index_type<FirstType>::type,typename internal::cleanup_index_type<SizeType>::type,typename internal::cleanup_seq_incr<IncrType>::type >
+seqN(FirstType first, SizeType size, IncrType incr) {
+ return ArithmeticSequence<typename internal::cleanup_index_type<FirstType>::type,typename internal::cleanup_index_type<SizeType>::type,typename internal::cleanup_seq_incr<IncrType>::type>(first,size,incr);
+}
+
+/** \returns an ArithmeticSequence starting at \a first, of length \a size, and unit increment
+ *
+ * \sa seqN(FirstType,SizeType,IncrType), seq(FirstType,LastType) */
+template<typename FirstType,typename SizeType>
+ArithmeticSequence<typename internal::cleanup_index_type<FirstType>::type,typename internal::cleanup_index_type<SizeType>::type >
+seqN(FirstType first, SizeType size) {
+ return ArithmeticSequence<typename internal::cleanup_index_type<FirstType>::type,typename internal::cleanup_index_type<SizeType>::type>(first,size);
+}
+
+#ifdef EIGEN_PARSED_BY_DOXYGEN
+
+/** \returns an ArithmeticSequence starting at \a f, up (or down) to \a l, and with positive (or negative) increment \a incr
+ *
+ * It is essentially an alias to:
+ * \code
+ * seqN(f, (l-f+incr)/incr, incr);
+ * \endcode
+ *
+ * \sa seqN(FirstType,SizeType,IncrType), seq(FirstType,LastType)
+ */
+template<typename FirstType,typename LastType, typename IncrType>
+auto seq(FirstType f, LastType l, IncrType incr);
+
+/** \returns an ArithmeticSequence starting at \a f, up (or down) to \a l, and unit increment
+ *
+ * It is essentially an alias to:
+ * \code
+ * seqN(f,l-f+1);
+ * \endcode
+ *
+ * \sa seqN(FirstType,SizeType), seq(FirstType,LastType,IncrType)
+ */
+template<typename FirstType,typename LastType>
+auto seq(FirstType f, LastType l);
+
+#else // EIGEN_PARSED_BY_DOXYGEN
+
+#if EIGEN_HAS_CXX11
+template<typename FirstType,typename LastType>
+auto seq(FirstType f, LastType l) -> decltype(seqN(typename internal::cleanup_index_type<FirstType>::type(f),
+ ( typename internal::cleanup_index_type<LastType>::type(l)
+ - typename internal::cleanup_index_type<FirstType>::type(f)+fix<1>())))
+{
+ return seqN(typename internal::cleanup_index_type<FirstType>::type(f),
+ (typename internal::cleanup_index_type<LastType>::type(l)
+ -typename internal::cleanup_index_type<FirstType>::type(f)+fix<1>()));
+}
+
+template<typename FirstType,typename LastType, typename IncrType>
+auto seq(FirstType f, LastType l, IncrType incr)
+ -> decltype(seqN(typename internal::cleanup_index_type<FirstType>::type(f),
+ ( typename internal::cleanup_index_type<LastType>::type(l)
+ - typename internal::cleanup_index_type<FirstType>::type(f)+typename internal::cleanup_seq_incr<IncrType>::type(incr)
+ ) / typename internal::cleanup_seq_incr<IncrType>::type(incr),
+ typename internal::cleanup_seq_incr<IncrType>::type(incr)))
+{
+ typedef typename internal::cleanup_seq_incr<IncrType>::type CleanedIncrType;
+ return seqN(typename internal::cleanup_index_type<FirstType>::type(f),
+ ( typename internal::cleanup_index_type<LastType>::type(l)
+ -typename internal::cleanup_index_type<FirstType>::type(f)+CleanedIncrType(incr)) / CleanedIncrType(incr),
+ CleanedIncrType(incr));
+}
+
+#else // EIGEN_HAS_CXX11
+
+template<typename FirstType,typename LastType>
+typename internal::enable_if<!(symbolic::is_symbolic<FirstType>::value || symbolic::is_symbolic<LastType>::value),
+ ArithmeticSequence<typename internal::cleanup_index_type<FirstType>::type,Index> >::type
+seq(FirstType f, LastType l)
+{
+ return seqN(typename internal::cleanup_index_type<FirstType>::type(f),
+ Index((typename internal::cleanup_index_type<LastType>::type(l)-typename internal::cleanup_index_type<FirstType>::type(f)+fix<1>())));
+}
+
+template<typename FirstTypeDerived,typename LastType>
+typename internal::enable_if<!symbolic::is_symbolic<LastType>::value,
+ ArithmeticSequence<FirstTypeDerived, symbolic::AddExpr<symbolic::AddExpr<symbolic::NegateExpr<FirstTypeDerived>,symbolic::ValueExpr<> >,
+ symbolic::ValueExpr<internal::FixedInt<1> > > > >::type
+seq(const symbolic::BaseExpr<FirstTypeDerived> &f, LastType l)
+{
+ return seqN(f.derived(),(typename internal::cleanup_index_type<LastType>::type(l)-f.derived()+fix<1>()));
+}
+
+template<typename FirstType,typename LastTypeDerived>
+typename internal::enable_if<!symbolic::is_symbolic<FirstType>::value,
+ ArithmeticSequence<typename internal::cleanup_index_type<FirstType>::type,
+ symbolic::AddExpr<symbolic::AddExpr<LastTypeDerived,symbolic::ValueExpr<> >,
+ symbolic::ValueExpr<internal::FixedInt<1> > > > >::type
+seq(FirstType f, const symbolic::BaseExpr<LastTypeDerived> &l)
+{
+ return seqN(typename internal::cleanup_index_type<FirstType>::type(f),(l.derived()-typename internal::cleanup_index_type<FirstType>::type(f)+fix<1>()));
+}
+
+template<typename FirstTypeDerived,typename LastTypeDerived>
+ArithmeticSequence<FirstTypeDerived,
+ symbolic::AddExpr<symbolic::AddExpr<LastTypeDerived,symbolic::NegateExpr<FirstTypeDerived> >,symbolic::ValueExpr<internal::FixedInt<1> > > >
+seq(const symbolic::BaseExpr<FirstTypeDerived> &f, const symbolic::BaseExpr<LastTypeDerived> &l)
+{
+ return seqN(f.derived(),(l.derived()-f.derived()+fix<1>()));
+}
+
+
+template<typename FirstType,typename LastType, typename IncrType>
+typename internal::enable_if<!(symbolic::is_symbolic<FirstType>::value || symbolic::is_symbolic<LastType>::value),
+ ArithmeticSequence<typename internal::cleanup_index_type<FirstType>::type,Index,typename internal::cleanup_seq_incr<IncrType>::type> >::type
+seq(FirstType f, LastType l, IncrType incr)
+{
+ typedef typename internal::cleanup_seq_incr<IncrType>::type CleanedIncrType;
+ return seqN(typename internal::cleanup_index_type<FirstType>::type(f),
+ Index((typename internal::cleanup_index_type<LastType>::type(l)-typename internal::cleanup_index_type<FirstType>::type(f)+CleanedIncrType(incr))/CleanedIncrType(incr)), incr);
+}
+
+template<typename FirstTypeDerived,typename LastType, typename IncrType>
+typename internal::enable_if<!symbolic::is_symbolic<LastType>::value,
+ ArithmeticSequence<FirstTypeDerived,
+ symbolic::QuotientExpr<symbolic::AddExpr<symbolic::AddExpr<symbolic::NegateExpr<FirstTypeDerived>,
+ symbolic::ValueExpr<> >,
+ symbolic::ValueExpr<typename internal::cleanup_seq_incr<IncrType>::type> >,
+ symbolic::ValueExpr<typename internal::cleanup_seq_incr<IncrType>::type> >,
+ typename internal::cleanup_seq_incr<IncrType>::type> >::type
+seq(const symbolic::BaseExpr<FirstTypeDerived> &f, LastType l, IncrType incr)
+{
+ typedef typename internal::cleanup_seq_incr<IncrType>::type CleanedIncrType;
+ return seqN(f.derived(),(typename internal::cleanup_index_type<LastType>::type(l)-f.derived()+CleanedIncrType(incr))/CleanedIncrType(incr), incr);
+}
+
+template<typename FirstType,typename LastTypeDerived, typename IncrType>
+typename internal::enable_if<!symbolic::is_symbolic<FirstType>::value,
+ ArithmeticSequence<typename internal::cleanup_index_type<FirstType>::type,
+ symbolic::QuotientExpr<symbolic::AddExpr<symbolic::AddExpr<LastTypeDerived,symbolic::ValueExpr<> >,
+ symbolic::ValueExpr<typename internal::cleanup_seq_incr<IncrType>::type> >,
+ symbolic::ValueExpr<typename internal::cleanup_seq_incr<IncrType>::type> >,
+ typename internal::cleanup_seq_incr<IncrType>::type> >::type
+seq(FirstType f, const symbolic::BaseExpr<LastTypeDerived> &l, IncrType incr)
+{
+ typedef typename internal::cleanup_seq_incr<IncrType>::type CleanedIncrType;
+ return seqN(typename internal::cleanup_index_type<FirstType>::type(f),
+ (l.derived()-typename internal::cleanup_index_type<FirstType>::type(f)+CleanedIncrType(incr))/CleanedIncrType(incr), incr);
+}
+
+template<typename FirstTypeDerived,typename LastTypeDerived, typename IncrType>
+ArithmeticSequence<FirstTypeDerived,
+ symbolic::QuotientExpr<symbolic::AddExpr<symbolic::AddExpr<LastTypeDerived,
+ symbolic::NegateExpr<FirstTypeDerived> >,
+ symbolic::ValueExpr<typename internal::cleanup_seq_incr<IncrType>::type> >,
+ symbolic::ValueExpr<typename internal::cleanup_seq_incr<IncrType>::type> >,
+ typename internal::cleanup_seq_incr<IncrType>::type>
+seq(const symbolic::BaseExpr<FirstTypeDerived> &f, const symbolic::BaseExpr<LastTypeDerived> &l, IncrType incr)
+{
+ typedef typename internal::cleanup_seq_incr<IncrType>::type CleanedIncrType;
+ return seqN(f.derived(),(l.derived()-f.derived()+CleanedIncrType(incr))/CleanedIncrType(incr), incr);
+}
+#endif // EIGEN_HAS_CXX11
+
+#endif // EIGEN_PARSED_BY_DOXYGEN
+
+
+#if EIGEN_HAS_CXX11 || defined(EIGEN_PARSED_BY_DOXYGEN)
+/** \cpp11
+ * \returns a symbolic ArithmeticSequence representing the last \a size elements with increment \a incr.
+ *
+ * It is a shortcut for: \code seqN(last-(size-fix<1>)*incr, size, incr) \endcode
+ *
+ * \sa lastN(SizeType), seqN(FirstType,SizeType), seq(FirstType,LastType,IncrType) */
+template<typename SizeType,typename IncrType>
+auto lastN(SizeType size, IncrType incr)
+-> decltype(seqN(Eigen::last-(size-fix<1>())*incr, size, incr))
+{
+ return seqN(Eigen::last-(size-fix<1>())*incr, size, incr);
+}
+
+/** \cpp11
+ * \returns a symbolic ArithmeticSequence representing the last \a size elements with a unit increment.
+ *
+ * It is a shortcut for: \code seq(last+fix<1>-size, last) \endcode
+ *
+ * \sa lastN(SizeType,IncrType, seqN(FirstType,SizeType), seq(FirstType,LastType) */
+template<typename SizeType>
+auto lastN(SizeType size)
+-> decltype(seqN(Eigen::last+fix<1>()-size, size))
+{
+ return seqN(Eigen::last+fix<1>()-size, size);
+}
+#endif
+
+namespace internal {
+
+// Convert a symbolic span into a usable one (i.e., remove last/end "keywords")
+template<typename T>
+struct make_size_type {
+ typedef typename internal::conditional<symbolic::is_symbolic<T>::value, Index, T>::type type;
+};
+
+template<typename FirstType,typename SizeType,typename IncrType,int XprSize>
+struct IndexedViewCompatibleType<ArithmeticSequence<FirstType,SizeType,IncrType>, XprSize> {
+ typedef ArithmeticSequence<Index,typename make_size_type<SizeType>::type,IncrType> type;
+};
+
+template<typename FirstType,typename SizeType,typename IncrType>
+ArithmeticSequence<Index,typename make_size_type<SizeType>::type,IncrType>
+makeIndexedViewCompatible(const ArithmeticSequence<FirstType,SizeType,IncrType>& ids, Index size,SpecializedType) {
+ return ArithmeticSequence<Index,typename make_size_type<SizeType>::type,IncrType>(
+ eval_expr_given_size(ids.firstObject(),size),eval_expr_given_size(ids.sizeObject(),size),ids.incrObject());
+}
+
+template<typename FirstType,typename SizeType,typename IncrType>
+struct get_compile_time_incr<ArithmeticSequence<FirstType,SizeType,IncrType> > {
+ enum { value = get_fixed_value<IncrType,DynamicIndex>::value };
+};
+
+} // end namespace internal
+
+/** \namespace Eigen::indexing
+ * \ingroup Core_Module
+ *
+ * The sole purpose of this namespace is to be able to import all functions
+ * and symbols that are expected to be used within operator() for indexing
+ * and slicing. If you already imported the whole Eigen namespace:
+ * \code using namespace Eigen; \endcode
+ * then you are already all set. Otherwise, if you don't want/cannot import
+ * the whole Eigen namespace, the following line:
+ * \code using namespace Eigen::indexing; \endcode
+ * is equivalent to:
+ * \code
+ using Eigen::all;
+ using Eigen::seq;
+ using Eigen::seqN;
+ using Eigen::lastN; // c++11 only
+ using Eigen::last;
+ using Eigen::lastp1;
+ using Eigen::fix;
+ \endcode
+ */
+namespace indexing {
+ using Eigen::all;
+ using Eigen::seq;
+ using Eigen::seqN;
+ #if EIGEN_HAS_CXX11
+ using Eigen::lastN;
+ #endif
+ using Eigen::last;
+ using Eigen::lastp1;
+ using Eigen::fix;
+}
+
+} // end namespace Eigen
+
+#endif // EIGEN_ARITHMETIC_SEQUENCE_H
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Array.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Array.h
index 16770fc..20c789b 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Array.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Array.h
@@ -117,7 +117,7 @@
{
return Base::_set(other);
}
-
+
/** Default constructor.
*
* For fixed-size matrices, does nothing.
@@ -157,11 +157,50 @@
EIGEN_DEVICE_FUNC
Array& operator=(Array&& other) EIGEN_NOEXCEPT_IF(std::is_nothrow_move_assignable<Scalar>::value)
{
- other.swap(*this);
+ Base::operator=(std::move(other));
return *this;
}
#endif
+ #if EIGEN_HAS_CXX11
+ /** \copydoc PlainObjectBase(const Scalar& a0, const Scalar& a1, const Scalar& a2, const Scalar& a3, const ArgTypes&... args)
+ *
+ * Example: \include Array_variadic_ctor_cxx11.cpp
+ * Output: \verbinclude Array_variadic_ctor_cxx11.out
+ *
+ * \sa Array(const std::initializer_list<std::initializer_list<Scalar>>&)
+ * \sa Array(const Scalar&), Array(const Scalar&,const Scalar&)
+ */
+ template <typename... ArgTypes>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ Array(const Scalar& a0, const Scalar& a1, const Scalar& a2, const Scalar& a3, const ArgTypes&... args)
+ : Base(a0, a1, a2, a3, args...) {}
+
+ /** \brief Constructs an array and initializes it from the coefficients given as initializer-lists grouped by row. \cpp11
+ *
+ * In the general case, the constructor takes a list of rows, each row being represented as a list of coefficients:
+ *
+ * Example: \include Array_initializer_list_23_cxx11.cpp
+ * Output: \verbinclude Array_initializer_list_23_cxx11.out
+ *
+ * Each of the inner initializer lists must contain the exact same number of elements, otherwise an assertion is triggered.
+ *
+ * In the case of a compile-time column 1D array, implicit transposition from a single row is allowed.
+ * Therefore <code> Array<int,Dynamic,1>{{1,2,3,4,5}}</code> is legal and the more verbose syntax
+ * <code>Array<int,Dynamic,1>{{1},{2},{3},{4},{5}}</code> can be avoided:
+ *
+ * Example: \include Array_initializer_list_vector_cxx11.cpp
+ * Output: \verbinclude Array_initializer_list_vector_cxx11.out
+ *
+ * In the case of fixed-sized arrays, the initializer list sizes must exactly match the array sizes,
+ * and implicit transposition is allowed for compile-time 1D arrays only.
+ *
+ * \sa Array(const Scalar& a0, const Scalar& a1, const Scalar& a2, const Scalar& a3, const ArgTypes&... args)
+ */
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE Array(const std::initializer_list<std::initializer_list<Scalar>>& list) : Base(list) {}
+ #endif // end EIGEN_HAS_CXX11
+
#ifndef EIGEN_PARSED_BY_DOXYGEN
template<typename T>
EIGEN_DEVICE_FUNC
@@ -178,6 +217,7 @@
Base::_check_template_params();
this->template _init2<T0,T1>(val0, val1);
}
+
#else
/** \brief Constructs a fixed-sized array initialized with coefficients starting at \a data */
EIGEN_DEVICE_FUNC explicit Array(const Scalar *data);
@@ -189,7 +229,8 @@
*/
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE explicit Array(Index dim);
- /** constructs an initialized 1x1 Array with the given coefficient */
+ /** constructs an initialized 1x1 Array with the given coefficient
+ * \sa const Scalar& a0, const Scalar& a1, const Scalar& a2, const Scalar& a3, const ArgTypes&... args */
Array(const Scalar& value);
/** constructs an uninitialized array with \a rows rows and \a cols columns.
*
@@ -197,11 +238,14 @@
* it is redundant to pass these parameters, so one should use the default constructor
* Array() instead. */
Array(Index rows, Index cols);
- /** constructs an initialized 2D vector with given coefficients */
+ /** constructs an initialized 2D vector with given coefficients
+ * \sa Array(const Scalar& a0, const Scalar& a1, const Scalar& a2, const Scalar& a3, const ArgTypes&... args) */
Array(const Scalar& val0, const Scalar& val1);
- #endif
+ #endif // end EIGEN_PARSED_BY_DOXYGEN
- /** constructs an initialized 3D vector with given coefficients */
+ /** constructs an initialized 3D vector with given coefficients
+ * \sa Array(const Scalar& a0, const Scalar& a1, const Scalar& a2, const Scalar& a3, const ArgTypes&... args)
+ */
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Array(const Scalar& val0, const Scalar& val1, const Scalar& val2)
{
@@ -211,7 +255,9 @@
m_storage.data()[1] = val1;
m_storage.data()[2] = val2;
}
- /** constructs an initialized 4D vector with given coefficients */
+ /** constructs an initialized 4D vector with given coefficients
+ * \sa Array(const Scalar& a0, const Scalar& a1, const Scalar& a2, const Scalar& a3, const ArgTypes&... args)
+ */
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Array(const Scalar& val0, const Scalar& val1, const Scalar& val2, const Scalar& val3)
{
@@ -242,8 +288,10 @@
: Base(other.derived())
{ }
- EIGEN_DEVICE_FUNC inline Index innerStride() const { return 1; }
- EIGEN_DEVICE_FUNC inline Index outerStride() const { return this->innerSize(); }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ inline Index innerStride() const EIGEN_NOEXCEPT{ return 1; }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ inline Index outerStride() const EIGEN_NOEXCEPT { return this->innerSize(); }
#ifdef EIGEN_ARRAY_PLUGIN
#include EIGEN_ARRAY_PLUGIN
@@ -258,7 +306,7 @@
/** \defgroup arraytypedefs Global array typedefs
* \ingroup Core_Module
*
- * Eigen defines several typedef shortcuts for most common 1D and 2D array types.
+ * %Eigen defines several typedef shortcuts for most common 1D and 2D array types.
*
* The general patterns are the following:
*
@@ -271,6 +319,12 @@
* There are also \c ArraySizeType which are self-explanatory. For example, \c Array4cf is
* a fixed-size 1D array of 4 complex floats.
*
+ * With \cpp11, template alias are also defined for common sizes.
+ * They follow the same pattern as above except that the scalar type suffix is replaced by a
+ * template parameter, i.e.:
+ * - `ArrayRowsCols<Type>` where `Rows` and `Cols` can be \c 2,\c 3,\c 4, or \c X for fixed or dynamic size.
+ * - `ArraySize<Type>` where `Size` can be \c 2,\c 3,\c 4 or \c X for fixed or dynamic size 1D arrays.
+ *
* \sa class Array
*/
@@ -303,8 +357,42 @@
#undef EIGEN_MAKE_ARRAY_TYPEDEFS_ALL_SIZES
#undef EIGEN_MAKE_ARRAY_TYPEDEFS
+#undef EIGEN_MAKE_ARRAY_FIXED_TYPEDEFS
-#undef EIGEN_MAKE_ARRAY_TYPEDEFS_LARGE
+#if EIGEN_HAS_CXX11
+
+#define EIGEN_MAKE_ARRAY_TYPEDEFS(Size, SizeSuffix) \
+/** \ingroup arraytypedefs */ \
+/** \brief \cpp11 */ \
+template <typename Type> \
+using Array##SizeSuffix##SizeSuffix = Array<Type, Size, Size>; \
+/** \ingroup arraytypedefs */ \
+/** \brief \cpp11 */ \
+template <typename Type> \
+using Array##SizeSuffix = Array<Type, Size, 1>;
+
+#define EIGEN_MAKE_ARRAY_FIXED_TYPEDEFS(Size) \
+/** \ingroup arraytypedefs */ \
+/** \brief \cpp11 */ \
+template <typename Type> \
+using Array##Size##X = Array<Type, Size, Dynamic>; \
+/** \ingroup arraytypedefs */ \
+/** \brief \cpp11 */ \
+template <typename Type> \
+using Array##X##Size = Array<Type, Dynamic, Size>;
+
+EIGEN_MAKE_ARRAY_TYPEDEFS(2, 2)
+EIGEN_MAKE_ARRAY_TYPEDEFS(3, 3)
+EIGEN_MAKE_ARRAY_TYPEDEFS(4, 4)
+EIGEN_MAKE_ARRAY_TYPEDEFS(Dynamic, X)
+EIGEN_MAKE_ARRAY_FIXED_TYPEDEFS(2)
+EIGEN_MAKE_ARRAY_FIXED_TYPEDEFS(3)
+EIGEN_MAKE_ARRAY_FIXED_TYPEDEFS(4)
+
+#undef EIGEN_MAKE_ARRAY_TYPEDEFS
+#undef EIGEN_MAKE_ARRAY_FIXED_TYPEDEFS
+
+#endif // EIGEN_HAS_CXX11
#define EIGEN_USING_ARRAY_TYPEDEFS_FOR_TYPE_AND_SIZE(TypeSuffix, SizeSuffix) \
using Eigen::Matrix##SizeSuffix##TypeSuffix; \
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/ArrayBase.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/ArrayBase.h
index 33f644e..ea3dd1c 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/ArrayBase.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/ArrayBase.h
@@ -69,6 +69,7 @@
using Base::coeff;
using Base::coeffRef;
using Base::lazyAssign;
+ using Base::operator-;
using Base::operator=;
using Base::operator+=;
using Base::operator-=;
@@ -88,7 +89,6 @@
#define EIGEN_CURRENT_STORAGE_BASE_CLASS Eigen::ArrayBase
#define EIGEN_DOC_UNARY_ADDONS(X,Y)
-# include "../plugins/CommonCwiseUnaryOps.h"
# include "../plugins/MatrixCwiseUnaryOps.h"
# include "../plugins/ArrayCwiseUnaryOps.h"
# include "../plugins/CommonCwiseBinaryOps.h"
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/ArrayWrapper.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/ArrayWrapper.h
index 688aadd..2e9555b 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/ArrayWrapper.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/ArrayWrapper.h
@@ -10,7 +10,7 @@
#ifndef EIGEN_ARRAYWRAPPER_H
#define EIGEN_ARRAYWRAPPER_H
-namespace Eigen {
+namespace Eigen {
/** \class ArrayWrapper
* \ingroup Core_Module
@@ -60,14 +60,14 @@
EIGEN_DEVICE_FUNC
explicit EIGEN_STRONG_INLINE ArrayWrapper(ExpressionType& matrix) : m_expression(matrix) {}
- EIGEN_DEVICE_FUNC
- inline Index rows() const { return m_expression.rows(); }
- EIGEN_DEVICE_FUNC
- inline Index cols() const { return m_expression.cols(); }
- EIGEN_DEVICE_FUNC
- inline Index outerStride() const { return m_expression.outerStride(); }
- EIGEN_DEVICE_FUNC
- inline Index innerStride() const { return m_expression.innerStride(); }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ inline Index rows() const EIGEN_NOEXCEPT { return m_expression.rows(); }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ inline Index cols() const EIGEN_NOEXCEPT { return m_expression.cols(); }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ inline Index outerStride() const EIGEN_NOEXCEPT { return m_expression.outerStride(); }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ inline Index innerStride() const EIGEN_NOEXCEPT { return m_expression.innerStride(); }
EIGEN_DEVICE_FUNC
inline ScalarWithConstIfNotLvalue* data() { return m_expression.data(); }
@@ -90,9 +90,9 @@
EIGEN_DEVICE_FUNC
inline void evalTo(Dest& dst) const { dst = m_expression; }
- const typename internal::remove_all<NestedExpressionType>::type&
EIGEN_DEVICE_FUNC
- nestedExpression() const
+ const typename internal::remove_all<NestedExpressionType>::type&
+ nestedExpression() const
{
return m_expression;
}
@@ -158,14 +158,14 @@
EIGEN_DEVICE_FUNC
explicit inline MatrixWrapper(ExpressionType& matrix) : m_expression(matrix) {}
- EIGEN_DEVICE_FUNC
- inline Index rows() const { return m_expression.rows(); }
- EIGEN_DEVICE_FUNC
- inline Index cols() const { return m_expression.cols(); }
- EIGEN_DEVICE_FUNC
- inline Index outerStride() const { return m_expression.outerStride(); }
- EIGEN_DEVICE_FUNC
- inline Index innerStride() const { return m_expression.innerStride(); }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ inline Index rows() const EIGEN_NOEXCEPT { return m_expression.rows(); }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ inline Index cols() const EIGEN_NOEXCEPT { return m_expression.cols(); }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ inline Index outerStride() const EIGEN_NOEXCEPT { return m_expression.outerStride(); }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ inline Index innerStride() const EIGEN_NOEXCEPT { return m_expression.innerStride(); }
EIGEN_DEVICE_FUNC
inline ScalarWithConstIfNotLvalue* data() { return m_expression.data(); }
@@ -185,8 +185,8 @@
}
EIGEN_DEVICE_FUNC
- const typename internal::remove_all<NestedExpressionType>::type&
- nestedExpression() const
+ const typename internal::remove_all<NestedExpressionType>::type&
+ nestedExpression() const
{
return m_expression;
}
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Assign.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Assign.h
index 53806ba..655412e 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Assign.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Assign.h
@@ -16,7 +16,7 @@
template<typename Derived>
template<typename OtherDerived>
-EIGEN_STRONG_INLINE Derived& DenseBase<Derived>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& DenseBase<Derived>
::lazyAssign(const DenseBase<OtherDerived>& other)
{
enum{
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/AssignEvaluator.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/AssignEvaluator.h
index dbe435d..7d76f0c 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/AssignEvaluator.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/AssignEvaluator.h
@@ -17,24 +17,24 @@
// This implementation is based on Assign.h
namespace internal {
-
+
/***************************************************************************
* Part 1 : the logic deciding a strategy for traversal and unrolling *
***************************************************************************/
// copy_using_evaluator_traits is based on assign_traits
-template <typename DstEvaluator, typename SrcEvaluator, typename AssignFunc>
+template <typename DstEvaluator, typename SrcEvaluator, typename AssignFunc, int MaxPacketSize = -1>
struct copy_using_evaluator_traits
{
typedef typename DstEvaluator::XprType Dst;
typedef typename Dst::Scalar DstScalar;
-
+
enum {
DstFlags = DstEvaluator::Flags,
SrcFlags = SrcEvaluator::Flags
};
-
+
public:
enum {
DstAlignment = DstEvaluator::Alignment,
@@ -51,13 +51,15 @@
InnerMaxSize = int(Dst::IsVectorAtCompileTime) ? int(Dst::MaxSizeAtCompileTime)
: int(DstFlags)&RowMajorBit ? int(Dst::MaxColsAtCompileTime)
: int(Dst::MaxRowsAtCompileTime),
+ RestrictedInnerSize = EIGEN_SIZE_MIN_PREFER_FIXED(InnerSize,MaxPacketSize),
+ RestrictedLinearSize = EIGEN_SIZE_MIN_PREFER_FIXED(Dst::SizeAtCompileTime,MaxPacketSize),
OuterStride = int(outer_stride_at_compile_time<Dst>::ret),
MaxSizeAtCompileTime = Dst::SizeAtCompileTime
};
// TODO distinguish between linear traversal and inner-traversals
- typedef typename find_best_packet<DstScalar,Dst::SizeAtCompileTime>::type LinearPacketType;
- typedef typename find_best_packet<DstScalar,InnerSize>::type InnerPacketType;
+ typedef typename find_best_packet<DstScalar,RestrictedLinearSize>::type LinearPacketType;
+ typedef typename find_best_packet<DstScalar,RestrictedInnerSize>::type InnerPacketType;
enum {
LinearPacketSize = unpacket_traits<LinearPacketType>::size,
@@ -97,7 +99,8 @@
public:
enum {
- Traversal = int(MayLinearVectorize) && (LinearPacketSize>InnerPacketSize) ? int(LinearVectorizedTraversal)
+ Traversal = int(Dst::SizeAtCompileTime) == 0 ? int(AllAtOnceTraversal) // If compile-size is zero, traversing will fail at compile-time.
+ : (int(MayLinearVectorize) && (LinearPacketSize>InnerPacketSize)) ? int(LinearVectorizedTraversal)
: int(MayInnerVectorize) ? int(InnerVectorizedTraversal)
: int(MayLinearVectorize) ? int(LinearVectorizedTraversal)
: int(MaySliceVectorize) ? int(SliceVectorizedTraversal)
@@ -135,7 +138,7 @@
? int(CompleteUnrolling)
: int(NoUnrolling) )
: int(Traversal) == int(LinearTraversal)
- ? ( bool(MayUnrollCompletely) ? int(CompleteUnrolling)
+ ? ( bool(MayUnrollCompletely) ? int(CompleteUnrolling)
: int(NoUnrolling) )
#if EIGEN_UNALIGNED_VECTORIZE
: int(Traversal) == int(SliceVectorizedTraversal)
@@ -172,6 +175,8 @@
EIGEN_DEBUG_VAR(MaySliceVectorize)
std::cerr << "Traversal" << " = " << Traversal << " (" << demangle_traversal(Traversal) << ")" << std::endl;
EIGEN_DEBUG_VAR(SrcEvaluator::CoeffReadCost)
+ EIGEN_DEBUG_VAR(DstEvaluator::CoeffReadCost)
+ EIGEN_DEBUG_VAR(Dst::SizeAtCompileTime)
EIGEN_DEBUG_VAR(UnrollingLimit)
EIGEN_DEBUG_VAR(MayUnrollCompletely)
EIGEN_DEBUG_VAR(MayUnrollInner)
@@ -195,7 +200,7 @@
// FIXME: this is not very clean, perhaps this information should be provided by the kernel?
typedef typename Kernel::DstEvaluatorType DstEvaluatorType;
typedef typename DstEvaluatorType::XprType DstXprType;
-
+
enum {
outer = Index / DstXprType::InnerSizeAtCompileTime,
inner = Index % DstXprType::InnerSizeAtCompileTime
@@ -261,7 +266,7 @@
typedef typename Kernel::DstEvaluatorType DstEvaluatorType;
typedef typename DstEvaluatorType::XprType DstXprType;
typedef typename Kernel::PacketType PacketType;
-
+
enum {
outer = Index / DstXprType::InnerSizeAtCompileTime,
inner = Index % DstXprType::InnerSizeAtCompileTime,
@@ -313,6 +318,22 @@
struct dense_assignment_loop;
/************************
+***** Special Cases *****
+************************/
+
+// Zero-sized assignment is a no-op.
+template<typename Kernel, int Unrolling>
+struct dense_assignment_loop<Kernel, AllAtOnceTraversal, Unrolling>
+{
+ EIGEN_DEVICE_FUNC static void EIGEN_STRONG_INLINE run(Kernel& /*kernel*/)
+ {
+ typedef typename Kernel::DstEvaluatorType::XprType DstXprType;
+ EIGEN_STATIC_ASSERT(int(DstXprType::SizeAtCompileTime) == 0,
+ EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT)
+ }
+};
+
+/************************
*** Default traversal ***
************************/
@@ -426,10 +447,10 @@
{
typedef typename Kernel::DstEvaluatorType::XprType DstXprType;
typedef typename Kernel::PacketType PacketType;
-
+
enum { size = DstXprType::SizeAtCompileTime,
packetSize =unpacket_traits<PacketType>::size,
- alignedSize = (size/packetSize)*packetSize };
+ alignedSize = (int(size)/packetSize)*packetSize };
copy_using_evaluator_innervec_CompleteUnrolling<Kernel, 0, alignedSize>::run(kernel);
copy_using_evaluator_DefaultTraversal_CompleteUnrolling<Kernel, alignedSize, size>::run(kernel);
@@ -530,7 +551,7 @@
const Scalar *dst_ptr = kernel.dstDataPtr();
if((!bool(dstIsAligned)) && (UIntPtr(dst_ptr) % sizeof(Scalar))>0)
{
- // the pointer is not aligend-on scalar, so alignment is not possible
+ // the pointer is not aligned-on scalar, so alignment is not possible
return dense_assignment_loop<Kernel,DefaultTraversal,NoUnrolling>::run(kernel);
}
const Index packetAlignedMask = packetSize - 1;
@@ -568,14 +589,15 @@
typedef typename Kernel::DstEvaluatorType::XprType DstXprType;
typedef typename Kernel::PacketType PacketType;
- enum { size = DstXprType::InnerSizeAtCompileTime,
+ enum { innerSize = DstXprType::InnerSizeAtCompileTime,
packetSize =unpacket_traits<PacketType>::size,
- vectorizableSize = (size/packetSize)*packetSize };
+ vectorizableSize = (int(innerSize) / int(packetSize)) * int(packetSize),
+ size = DstXprType::SizeAtCompileTime };
for(Index outer = 0; outer < kernel.outerSize(); ++outer)
{
copy_using_evaluator_innervec_InnerUnrolling<Kernel, 0, vectorizableSize, 0, 0>::run(kernel, outer);
- copy_using_evaluator_DefaultTraversal_InnerUnrolling<Kernel, vectorizableSize, size>::run(kernel, outer);
+ copy_using_evaluator_DefaultTraversal_InnerUnrolling<Kernel, vectorizableSize, innerSize>::run(kernel, outer);
}
}
};
@@ -599,73 +621,74 @@
typedef typename DstEvaluatorTypeT::XprType DstXprType;
typedef typename SrcEvaluatorTypeT::XprType SrcXprType;
public:
-
+
typedef DstEvaluatorTypeT DstEvaluatorType;
typedef SrcEvaluatorTypeT SrcEvaluatorType;
typedef typename DstEvaluatorType::Scalar Scalar;
typedef copy_using_evaluator_traits<DstEvaluatorTypeT, SrcEvaluatorTypeT, Functor> AssignmentTraits;
typedef typename AssignmentTraits::PacketType PacketType;
-
-
- EIGEN_DEVICE_FUNC generic_dense_assignment_kernel(DstEvaluatorType &dst, const SrcEvaluatorType &src, const Functor &func, DstXprType& dstExpr)
+
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ generic_dense_assignment_kernel(DstEvaluatorType &dst, const SrcEvaluatorType &src, const Functor &func, DstXprType& dstExpr)
: m_dst(dst), m_src(src), m_functor(func), m_dstExpr(dstExpr)
{
#ifdef EIGEN_DEBUG_ASSIGN
AssignmentTraits::debug();
#endif
}
-
- EIGEN_DEVICE_FUNC Index size() const { return m_dstExpr.size(); }
- EIGEN_DEVICE_FUNC Index innerSize() const { return m_dstExpr.innerSize(); }
- EIGEN_DEVICE_FUNC Index outerSize() const { return m_dstExpr.outerSize(); }
- EIGEN_DEVICE_FUNC Index rows() const { return m_dstExpr.rows(); }
- EIGEN_DEVICE_FUNC Index cols() const { return m_dstExpr.cols(); }
- EIGEN_DEVICE_FUNC Index outerStride() const { return m_dstExpr.outerStride(); }
-
- EIGEN_DEVICE_FUNC DstEvaluatorType& dstEvaluator() { return m_dst; }
- EIGEN_DEVICE_FUNC const SrcEvaluatorType& srcEvaluator() const { return m_src; }
-
+
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR Index size() const EIGEN_NOEXCEPT { return m_dstExpr.size(); }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR Index innerSize() const EIGEN_NOEXCEPT { return m_dstExpr.innerSize(); }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR Index outerSize() const EIGEN_NOEXCEPT { return m_dstExpr.outerSize(); }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR Index rows() const EIGEN_NOEXCEPT { return m_dstExpr.rows(); }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR Index cols() const EIGEN_NOEXCEPT { return m_dstExpr.cols(); }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR Index outerStride() const EIGEN_NOEXCEPT { return m_dstExpr.outerStride(); }
+
+ EIGEN_DEVICE_FUNC DstEvaluatorType& dstEvaluator() EIGEN_NOEXCEPT { return m_dst; }
+ EIGEN_DEVICE_FUNC const SrcEvaluatorType& srcEvaluator() const EIGEN_NOEXCEPT { return m_src; }
+
/// Assign src(row,col) to dst(row,col) through the assignment functor.
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void assignCoeff(Index row, Index col)
{
m_functor.assignCoeff(m_dst.coeffRef(row,col), m_src.coeff(row,col));
}
-
+
/// \sa assignCoeff(Index,Index)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void assignCoeff(Index index)
{
m_functor.assignCoeff(m_dst.coeffRef(index), m_src.coeff(index));
}
-
+
/// \sa assignCoeff(Index,Index)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void assignCoeffByOuterInner(Index outer, Index inner)
{
- Index row = rowIndexByOuterInner(outer, inner);
- Index col = colIndexByOuterInner(outer, inner);
+ Index row = rowIndexByOuterInner(outer, inner);
+ Index col = colIndexByOuterInner(outer, inner);
assignCoeff(row, col);
}
-
-
+
+
template<int StoreMode, int LoadMode, typename PacketType>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void assignPacket(Index row, Index col)
{
m_functor.template assignPacket<StoreMode>(&m_dst.coeffRef(row,col), m_src.template packet<LoadMode,PacketType>(row,col));
}
-
+
template<int StoreMode, int LoadMode, typename PacketType>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void assignPacket(Index index)
{
m_functor.template assignPacket<StoreMode>(&m_dst.coeffRef(index), m_src.template packet<LoadMode,PacketType>(index));
}
-
+
template<int StoreMode, int LoadMode, typename PacketType>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void assignPacketByOuterInner(Index outer, Index inner)
{
- Index row = rowIndexByOuterInner(outer, inner);
+ Index row = rowIndexByOuterInner(outer, inner);
Index col = colIndexByOuterInner(outer, inner);
assignPacket<StoreMode,LoadMode,PacketType>(row, col);
}
-
+
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Index rowIndexByOuterInner(Index outer, Index inner)
{
typedef typename DstEvaluatorType::ExpressionTraits Traits;
@@ -688,7 +711,7 @@
{
return m_dstExpr.data();
}
-
+
protected:
DstEvaluatorType& m_dst;
const SrcEvaluatorType& m_src;
@@ -697,6 +720,27 @@
DstXprType& m_dstExpr;
};
+// Special kernel used when computing small products whose operands have dynamic dimensions. It ensures that the
+// PacketSize used is no larger than 4, thereby increasing the chance that vectorized instructions will be used
+// when computing the product.
+
+template<typename DstEvaluatorTypeT, typename SrcEvaluatorTypeT, typename Functor>
+class restricted_packet_dense_assignment_kernel : public generic_dense_assignment_kernel<DstEvaluatorTypeT, SrcEvaluatorTypeT, Functor, BuiltIn>
+{
+protected:
+ typedef generic_dense_assignment_kernel<DstEvaluatorTypeT, SrcEvaluatorTypeT, Functor, BuiltIn> Base;
+ public:
+ typedef typename Base::Scalar Scalar;
+ typedef typename Base::DstXprType DstXprType;
+ typedef copy_using_evaluator_traits<DstEvaluatorTypeT, SrcEvaluatorTypeT, Functor, 4> AssignmentTraits;
+ typedef typename AssignmentTraits::PacketType PacketType;
+
+ EIGEN_DEVICE_FUNC restricted_packet_dense_assignment_kernel(DstEvaluatorTypeT &dst, const SrcEvaluatorTypeT &src, const Functor &func, DstXprType& dstExpr)
+ : Base(dst, src, func, dstExpr)
+ {
+ }
+ };
+
/***************************************************************************
* Part 5 : Entry point for dense rectangular assignment
***************************************************************************/
@@ -734,13 +778,23 @@
resize_if_allowed(dst, src, func);
DstEvaluatorType dstEvaluator(dst);
-
+
typedef generic_dense_assignment_kernel<DstEvaluatorType,SrcEvaluatorType,Functor> Kernel;
Kernel kernel(dstEvaluator, srcEvaluator, func, dst.const_cast_derived());
dense_assignment_loop<Kernel>::run(kernel);
}
+// Specialization for filling the destination with a constant value.
+#ifndef EIGEN_GPU_COMPILE_PHASE
+template<typename DstXprType>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void call_dense_assignment_loop(DstXprType& dst, const Eigen::CwiseNullaryOp<Eigen::internal::scalar_constant_op<typename DstXprType::Scalar>, DstXprType>& src, const internal::assign_op<typename DstXprType::Scalar,typename DstXprType::Scalar>& func)
+{
+ resize_if_allowed(dst, src, func);
+ std::fill_n(dst.data(), dst.size(), src.functor()());
+}
+#endif
+
template<typename DstXprType, typename SrcXprType>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void call_dense_assignment_loop(DstXprType& dst, const SrcXprType& src)
{
@@ -756,13 +810,13 @@
// AssignmentKind must define a Kind typedef.
template<typename DstShape, typename SrcShape> struct AssignmentKind;
-// Assignement kind defined in this file:
+// Assignment kind defined in this file:
struct Dense2Dense {};
struct EigenBase2EigenBase {};
template<typename,typename> struct AssignmentKind { typedef EigenBase2EigenBase Kind; };
template<> struct AssignmentKind<DenseShape,DenseShape> { typedef Dense2Dense Kind; };
-
+
// This is the main assignment class
template< typename DstXprType, typename SrcXprType, typename Functor,
typename Kind = typename AssignmentKind< typename evaluator_traits<DstXprType>::Shape , typename evaluator_traits<SrcXprType>::Shape >::Kind,
@@ -787,7 +841,7 @@
{
call_assignment(dst, src, internal::assign_op<typename Dst::Scalar,typename Src::Scalar>());
}
-
+
// Deal with "assume-aliasing"
template<typename Dst, typename Src, typename Func>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
@@ -827,14 +881,35 @@
typedef typename internal::conditional<NeedToTranspose, Transpose<Dst>, Dst>::type ActualDstTypeCleaned;
typedef typename internal::conditional<NeedToTranspose, Transpose<Dst>, Dst&>::type ActualDstType;
ActualDstType actualDst(dst);
-
+
// TODO check whether this is the right place to perform these checks:
EIGEN_STATIC_ASSERT_LVALUE(Dst)
EIGEN_STATIC_ASSERT_SAME_MATRIX_SIZE(ActualDstTypeCleaned,Src)
EIGEN_CHECK_BINARY_COMPATIBILIY(Func,typename ActualDstTypeCleaned::Scalar,typename Src::Scalar);
-
+
Assignment<ActualDstTypeCleaned,Src,Func>::run(actualDst, src, func);
}
+
+template<typename Dst, typename Src, typename Func>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+void call_restricted_packet_assignment_no_alias(Dst& dst, const Src& src, const Func& func)
+{
+ typedef evaluator<Dst> DstEvaluatorType;
+ typedef evaluator<Src> SrcEvaluatorType;
+ typedef restricted_packet_dense_assignment_kernel<DstEvaluatorType,SrcEvaluatorType,Func> Kernel;
+
+ EIGEN_STATIC_ASSERT_LVALUE(Dst)
+ EIGEN_CHECK_BINARY_COMPATIBILIY(Func,typename Dst::Scalar,typename Src::Scalar);
+
+ SrcEvaluatorType srcEvaluator(src);
+ resize_if_allowed(dst, src, func);
+
+ DstEvaluatorType dstEvaluator(dst);
+ Kernel kernel(dstEvaluator, srcEvaluator, func, dst.const_cast_derived());
+
+ dense_assignment_loop<Kernel>::run(kernel);
+}
+
template<typename Dst, typename Src>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void call_assignment_no_alias(Dst& dst, const Src& src)
@@ -875,7 +950,7 @@
#ifndef EIGEN_NO_DEBUG
internal::check_for_aliasing(dst, src);
#endif
-
+
call_dense_assignment_loop(dst, src, func);
}
};
@@ -899,7 +974,7 @@
src.evalTo(dst);
}
- // NOTE The following two functions are templated to avoid their instanciation if not needed
+ // NOTE The following two functions are templated to avoid their instantiation if not needed
// This is needed because some expressions supports evalTo only and/or have 'void' as scalar type.
template<typename SrcScalarType>
EIGEN_DEVICE_FUNC
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Assign_MKL.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Assign_MKL.h
deleted file mode 100644
index 6866095..0000000
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Assign_MKL.h
+++ /dev/null
@@ -1,178 +0,0 @@
-/*
- Copyright (c) 2011, Intel Corporation. All rights reserved.
- Copyright (C) 2015 Gael Guennebaud <gael.guennebaud@inria.fr>
-
- Redistribution and use in source and binary forms, with or without modification,
- are permitted provided that the following conditions are met:
-
- * Redistributions of source code must retain the above copyright notice, this
- list of conditions and the following disclaimer.
- * Redistributions in binary form must reproduce the above copyright notice,
- this list of conditions and the following disclaimer in the documentation
- and/or other materials provided with the distribution.
- * Neither the name of Intel Corporation nor the names of its contributors may
- be used to endorse or promote products derived from this software without
- specific prior written permission.
-
- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
- ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
- WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
- DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
- ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
- (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
- LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
- ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
- (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
- SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
- ********************************************************************************
- * Content : Eigen bindings to Intel(R) MKL
- * MKL VML support for coefficient-wise unary Eigen expressions like a=b.sin()
- ********************************************************************************
-*/
-
-#ifndef EIGEN_ASSIGN_VML_H
-#define EIGEN_ASSIGN_VML_H
-
-namespace Eigen {
-
-namespace internal {
-
-template<typename Dst, typename Src>
-class vml_assign_traits
-{
- private:
- enum {
- DstHasDirectAccess = Dst::Flags & DirectAccessBit,
- SrcHasDirectAccess = Src::Flags & DirectAccessBit,
- StorageOrdersAgree = (int(Dst::IsRowMajor) == int(Src::IsRowMajor)),
- InnerSize = int(Dst::IsVectorAtCompileTime) ? int(Dst::SizeAtCompileTime)
- : int(Dst::Flags)&RowMajorBit ? int(Dst::ColsAtCompileTime)
- : int(Dst::RowsAtCompileTime),
- InnerMaxSize = int(Dst::IsVectorAtCompileTime) ? int(Dst::MaxSizeAtCompileTime)
- : int(Dst::Flags)&RowMajorBit ? int(Dst::MaxColsAtCompileTime)
- : int(Dst::MaxRowsAtCompileTime),
- MaxSizeAtCompileTime = Dst::SizeAtCompileTime,
-
- MightEnableVml = StorageOrdersAgree && DstHasDirectAccess && SrcHasDirectAccess && Src::InnerStrideAtCompileTime==1 && Dst::InnerStrideAtCompileTime==1,
- MightLinearize = MightEnableVml && (int(Dst::Flags) & int(Src::Flags) & LinearAccessBit),
- VmlSize = MightLinearize ? MaxSizeAtCompileTime : InnerMaxSize,
- LargeEnough = VmlSize==Dynamic || VmlSize>=EIGEN_MKL_VML_THRESHOLD
- };
- public:
- enum {
- EnableVml = MightEnableVml && LargeEnough,
- Traversal = MightLinearize ? LinearTraversal : DefaultTraversal
- };
-};
-
-#define EIGEN_PP_EXPAND(ARG) ARG
-#if !defined (EIGEN_FAST_MATH) || (EIGEN_FAST_MATH != 1)
-#define EIGEN_VMLMODE_EXPAND_LA , VML_HA
-#else
-#define EIGEN_VMLMODE_EXPAND_LA , VML_LA
-#endif
-
-#define EIGEN_VMLMODE_EXPAND__
-
-#define EIGEN_VMLMODE_PREFIX_LA vm
-#define EIGEN_VMLMODE_PREFIX__ v
-#define EIGEN_VMLMODE_PREFIX(VMLMODE) EIGEN_CAT(EIGEN_VMLMODE_PREFIX_,VMLMODE)
-
-#define EIGEN_MKL_VML_DECLARE_UNARY_CALL(EIGENOP, VMLOP, EIGENTYPE, VMLTYPE, VMLMODE) \
- template< typename DstXprType, typename SrcXprNested> \
- struct Assignment<DstXprType, CwiseUnaryOp<scalar_##EIGENOP##_op<EIGENTYPE>, SrcXprNested>, assign_op<EIGENTYPE,EIGENTYPE>, \
- Dense2Dense, typename enable_if<vml_assign_traits<DstXprType,SrcXprNested>::EnableVml>::type> { \
- typedef CwiseUnaryOp<scalar_##EIGENOP##_op<EIGENTYPE>, SrcXprNested> SrcXprType; \
- static void run(DstXprType &dst, const SrcXprType &src, const assign_op<EIGENTYPE,EIGENTYPE> &func) { \
- resize_if_allowed(dst, src, func); \
- eigen_assert(dst.rows() == src.rows() && dst.cols() == src.cols()); \
- if(vml_assign_traits<DstXprType,SrcXprNested>::Traversal==LinearTraversal) { \
- VMLOP(dst.size(), (const VMLTYPE*)src.nestedExpression().data(), \
- (VMLTYPE*)dst.data() EIGEN_PP_EXPAND(EIGEN_VMLMODE_EXPAND_##VMLMODE) ); \
- } else { \
- const Index outerSize = dst.outerSize(); \
- for(Index outer = 0; outer < outerSize; ++outer) { \
- const EIGENTYPE *src_ptr = src.IsRowMajor ? &(src.nestedExpression().coeffRef(outer,0)) : \
- &(src.nestedExpression().coeffRef(0, outer)); \
- EIGENTYPE *dst_ptr = dst.IsRowMajor ? &(dst.coeffRef(outer,0)) : &(dst.coeffRef(0, outer)); \
- VMLOP( dst.innerSize(), (const VMLTYPE*)src_ptr, \
- (VMLTYPE*)dst_ptr EIGEN_PP_EXPAND(EIGEN_VMLMODE_EXPAND_##VMLMODE)); \
- } \
- } \
- } \
- }; \
-
-
-#define EIGEN_MKL_VML_DECLARE_UNARY_CALLS_REAL(EIGENOP, VMLOP, VMLMODE) \
- EIGEN_MKL_VML_DECLARE_UNARY_CALL(EIGENOP, EIGEN_CAT(EIGEN_VMLMODE_PREFIX(VMLMODE),s##VMLOP), float, float, VMLMODE) \
- EIGEN_MKL_VML_DECLARE_UNARY_CALL(EIGENOP, EIGEN_CAT(EIGEN_VMLMODE_PREFIX(VMLMODE),d##VMLOP), double, double, VMLMODE)
-
-#define EIGEN_MKL_VML_DECLARE_UNARY_CALLS_CPLX(EIGENOP, VMLOP, VMLMODE) \
- EIGEN_MKL_VML_DECLARE_UNARY_CALL(EIGENOP, EIGEN_CAT(EIGEN_VMLMODE_PREFIX(VMLMODE),c##VMLOP), scomplex, MKL_Complex8, VMLMODE) \
- EIGEN_MKL_VML_DECLARE_UNARY_CALL(EIGENOP, EIGEN_CAT(EIGEN_VMLMODE_PREFIX(VMLMODE),z##VMLOP), dcomplex, MKL_Complex16, VMLMODE)
-
-#define EIGEN_MKL_VML_DECLARE_UNARY_CALLS(EIGENOP, VMLOP, VMLMODE) \
- EIGEN_MKL_VML_DECLARE_UNARY_CALLS_REAL(EIGENOP, VMLOP, VMLMODE) \
- EIGEN_MKL_VML_DECLARE_UNARY_CALLS_CPLX(EIGENOP, VMLOP, VMLMODE)
-
-
-EIGEN_MKL_VML_DECLARE_UNARY_CALLS(sin, Sin, LA)
-EIGEN_MKL_VML_DECLARE_UNARY_CALLS(asin, Asin, LA)
-EIGEN_MKL_VML_DECLARE_UNARY_CALLS(sinh, Sinh, LA)
-EIGEN_MKL_VML_DECLARE_UNARY_CALLS(cos, Cos, LA)
-EIGEN_MKL_VML_DECLARE_UNARY_CALLS(acos, Acos, LA)
-EIGEN_MKL_VML_DECLARE_UNARY_CALLS(cosh, Cosh, LA)
-EIGEN_MKL_VML_DECLARE_UNARY_CALLS(tan, Tan, LA)
-EIGEN_MKL_VML_DECLARE_UNARY_CALLS(atan, Atan, LA)
-EIGEN_MKL_VML_DECLARE_UNARY_CALLS(tanh, Tanh, LA)
-// EIGEN_MKL_VML_DECLARE_UNARY_CALLS(abs, Abs, _)
-EIGEN_MKL_VML_DECLARE_UNARY_CALLS(exp, Exp, LA)
-EIGEN_MKL_VML_DECLARE_UNARY_CALLS(log, Ln, LA)
-EIGEN_MKL_VML_DECLARE_UNARY_CALLS(log10, Log10, LA)
-EIGEN_MKL_VML_DECLARE_UNARY_CALLS(sqrt, Sqrt, _)
-
-EIGEN_MKL_VML_DECLARE_UNARY_CALLS_REAL(square, Sqr, _)
-EIGEN_MKL_VML_DECLARE_UNARY_CALLS_CPLX(arg, Arg, _)
-EIGEN_MKL_VML_DECLARE_UNARY_CALLS_REAL(round, Round, _)
-EIGEN_MKL_VML_DECLARE_UNARY_CALLS_REAL(floor, Floor, _)
-EIGEN_MKL_VML_DECLARE_UNARY_CALLS_REAL(ceil, Ceil, _)
-
-#define EIGEN_MKL_VML_DECLARE_POW_CALL(EIGENOP, VMLOP, EIGENTYPE, VMLTYPE, VMLMODE) \
- template< typename DstXprType, typename SrcXprNested, typename Plain> \
- struct Assignment<DstXprType, CwiseBinaryOp<scalar_##EIGENOP##_op<EIGENTYPE,EIGENTYPE>, SrcXprNested, \
- const CwiseNullaryOp<internal::scalar_constant_op<EIGENTYPE>,Plain> >, assign_op<EIGENTYPE,EIGENTYPE>, \
- Dense2Dense, typename enable_if<vml_assign_traits<DstXprType,SrcXprNested>::EnableVml>::type> { \
- typedef CwiseBinaryOp<scalar_##EIGENOP##_op<EIGENTYPE,EIGENTYPE>, SrcXprNested, \
- const CwiseNullaryOp<internal::scalar_constant_op<EIGENTYPE>,Plain> > SrcXprType; \
- static void run(DstXprType &dst, const SrcXprType &src, const assign_op<EIGENTYPE,EIGENTYPE> &func) { \
- resize_if_allowed(dst, src, func); \
- eigen_assert(dst.rows() == src.rows() && dst.cols() == src.cols()); \
- VMLTYPE exponent = reinterpret_cast<const VMLTYPE&>(src.rhs().functor().m_other); \
- if(vml_assign_traits<DstXprType,SrcXprNested>::Traversal==LinearTraversal) \
- { \
- VMLOP( dst.size(), (const VMLTYPE*)src.lhs().data(), exponent, \
- (VMLTYPE*)dst.data() EIGEN_PP_EXPAND(EIGEN_VMLMODE_EXPAND_##VMLMODE) ); \
- } else { \
- const Index outerSize = dst.outerSize(); \
- for(Index outer = 0; outer < outerSize; ++outer) { \
- const EIGENTYPE *src_ptr = src.IsRowMajor ? &(src.lhs().coeffRef(outer,0)) : \
- &(src.lhs().coeffRef(0, outer)); \
- EIGENTYPE *dst_ptr = dst.IsRowMajor ? &(dst.coeffRef(outer,0)) : &(dst.coeffRef(0, outer)); \
- VMLOP( dst.innerSize(), (const VMLTYPE*)src_ptr, exponent, \
- (VMLTYPE*)dst_ptr EIGEN_PP_EXPAND(EIGEN_VMLMODE_EXPAND_##VMLMODE)); \
- } \
- } \
- } \
- };
-
-EIGEN_MKL_VML_DECLARE_POW_CALL(pow, vmsPowx, float, float, LA)
-EIGEN_MKL_VML_DECLARE_POW_CALL(pow, vmdPowx, double, double, LA)
-EIGEN_MKL_VML_DECLARE_POW_CALL(pow, vmcPowx, scomplex, MKL_Complex8, LA)
-EIGEN_MKL_VML_DECLARE_POW_CALL(pow, vmzPowx, dcomplex, MKL_Complex16, LA)
-
-} // end namespace internal
-
-} // end namespace Eigen
-
-#endif // EIGEN_ASSIGN_VML_H
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/BandMatrix.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/BandMatrix.h
index 4978c91..878c024 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/BandMatrix.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/BandMatrix.h
@@ -10,7 +10,7 @@
#ifndef EIGEN_BANDMATRIX_H
#define EIGEN_BANDMATRIX_H
-namespace Eigen {
+namespace Eigen {
namespace internal {
@@ -45,7 +45,7 @@
};
public:
-
+
using Base::derived;
using Base::rows;
using Base::cols;
@@ -55,10 +55,10 @@
/** \returns the number of sub diagonals */
inline Index subs() const { return derived().subs(); }
-
+
/** \returns an expression of the underlying coefficient matrix */
inline const CoefficientsType& coeffs() const { return derived().coeffs(); }
-
+
/** \returns an expression of the underlying coefficient matrix */
inline CoefficientsType& coeffs() { return derived().coeffs(); }
@@ -67,7 +67,7 @@
* \warning the internal storage must be column major. */
inline Block<CoefficientsType,Dynamic,1> col(Index i)
{
- EIGEN_STATIC_ASSERT((Options&RowMajor)==0,THIS_METHOD_IS_ONLY_FOR_COLUMN_MAJOR_MATRICES);
+ EIGEN_STATIC_ASSERT((int(Options) & int(RowMajor)) == 0, THIS_METHOD_IS_ONLY_FOR_COLUMN_MAJOR_MATRICES);
Index start = 0;
Index len = coeffs().rows();
if (i<=supers())
@@ -90,7 +90,7 @@
template<int Index> struct DiagonalIntReturnType {
enum {
- ReturnOpposite = (Options&SelfAdjoint) && (((Index)>0 && Supers==0) || ((Index)<0 && Subs==0)),
+ ReturnOpposite = (int(Options) & int(SelfAdjoint)) && (((Index) > 0 && Supers == 0) || ((Index) < 0 && Subs == 0)),
Conjugate = ReturnOpposite && NumTraits<Scalar>::IsComplex,
ActualIndex = ReturnOpposite ? -Index : Index,
DiagonalSize = (RowsAtCompileTime==Dynamic || ColsAtCompileTime==Dynamic)
@@ -130,7 +130,7 @@
eigen_assert((i<0 && -i<=subs()) || (i>=0 && i<=supers()));
return Block<const CoefficientsType,1,Dynamic>(coeffs(), supers()-i, std::max<Index>(0,i), 1, diagonalLength(i));
}
-
+
template<typename Dest> inline void evalTo(Dest& dst) const
{
dst.resize(rows(),cols());
@@ -192,7 +192,7 @@
Options = _Options,
DataRowsAtCompileTime = ((Supers!=Dynamic) && (Subs!=Dynamic)) ? 1 + Supers + Subs : Dynamic
};
- typedef Matrix<Scalar,DataRowsAtCompileTime,ColsAtCompileTime,Options&RowMajor?RowMajor:ColMajor> CoefficientsType;
+ typedef Matrix<Scalar, DataRowsAtCompileTime, ColsAtCompileTime, int(Options) & int(RowMajor) ? RowMajor : ColMajor> CoefficientsType;
};
template<typename _Scalar, int Rows, int Cols, int Supers, int Subs, int Options>
@@ -211,16 +211,16 @@
}
/** \returns the number of columns */
- inline Index rows() const { return m_rows.value(); }
+ inline EIGEN_CONSTEXPR Index rows() const { return m_rows.value(); }
/** \returns the number of rows */
- inline Index cols() const { return m_coeffs.cols(); }
+ inline EIGEN_CONSTEXPR Index cols() const { return m_coeffs.cols(); }
/** \returns the number of super diagonals */
- inline Index supers() const { return m_supers.value(); }
+ inline EIGEN_CONSTEXPR Index supers() const { return m_supers.value(); }
/** \returns the number of sub diagonals */
- inline Index subs() const { return m_subs.value(); }
+ inline EIGEN_CONSTEXPR Index subs() const { return m_subs.value(); }
inline const CoefficientsType& coeffs() const { return m_coeffs; }
inline CoefficientsType& coeffs() { return m_coeffs; }
@@ -275,16 +275,16 @@
}
/** \returns the number of columns */
- inline Index rows() const { return m_rows.value(); }
+ inline EIGEN_CONSTEXPR Index rows() const { return m_rows.value(); }
/** \returns the number of rows */
- inline Index cols() const { return m_coeffs.cols(); }
+ inline EIGEN_CONSTEXPR Index cols() const { return m_coeffs.cols(); }
/** \returns the number of super diagonals */
- inline Index supers() const { return m_supers.value(); }
+ inline EIGEN_CONSTEXPR Index supers() const { return m_supers.value(); }
/** \returns the number of sub diagonals */
- inline Index subs() const { return m_subs.value(); }
+ inline EIGEN_CONSTEXPR Index subs() const { return m_subs.value(); }
inline const CoefficientsType& coeffs() const { return m_coeffs; }
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Block.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Block.h
index 11de45c..3206d66 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Block.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Block.h
@@ -11,7 +11,7 @@
#ifndef EIGEN_BLOCK_H
#define EIGEN_BLOCK_H
-namespace Eigen {
+namespace Eigen {
namespace internal {
template<typename XprType, int BlockRows, int BlockCols, bool InnerPanel>
@@ -52,7 +52,7 @@
FlagsRowMajorBit = IsRowMajor ? RowMajorBit : 0,
Flags = (traits<XprType>::Flags & (DirectAccessBit | (InnerPanel?CompressedAccessBit:0))) | FlagsLvalueBit | FlagsRowMajorBit,
// FIXME DirectAccessBit should not be handled by expressions
- //
+ //
// Alignment is needed by MapBase's assertions
// We can sefely set it to false here. Internal alignment errors will be detected by an eigen_internal_assert in the respective evaluator
Alignment = 0
@@ -61,7 +61,7 @@
template<typename XprType, int BlockRows=Dynamic, int BlockCols=Dynamic, bool InnerPanel = false,
bool HasDirectAccess = internal::has_direct_access<XprType>::ret> class BlockImpl_dense;
-
+
} // end namespace internal
template<typename XprType, int BlockRows, int BlockCols, bool InnerPanel, typename StorageKind> class BlockImpl;
@@ -109,13 +109,13 @@
typedef Impl Base;
EIGEN_GENERIC_PUBLIC_INTERFACE(Block)
EIGEN_INHERIT_ASSIGNMENT_OPERATORS(Block)
-
+
typedef typename internal::remove_all<XprType>::type NestedExpression;
-
+
/** Column or Row constructor
*/
- EIGEN_DEVICE_FUNC
- inline Block(XprType& xpr, Index i) : Impl(xpr,i)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ Block(XprType& xpr, Index i) : Impl(xpr,i)
{
eigen_assert( (i>=0) && (
((BlockRows==1) && (BlockCols==XprType::ColsAtCompileTime) && i<xpr.rows())
@@ -124,8 +124,8 @@
/** Fixed-size constructor
*/
- EIGEN_DEVICE_FUNC
- inline Block(XprType& xpr, Index startRow, Index startCol)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ Block(XprType& xpr, Index startRow, Index startCol)
: Impl(xpr, startRow, startCol)
{
EIGEN_STATIC_ASSERT(RowsAtCompileTime!=Dynamic && ColsAtCompileTime!=Dynamic,THIS_METHOD_IS_ONLY_FOR_FIXED_SIZE)
@@ -135,8 +135,8 @@
/** Dynamic-size constructor
*/
- EIGEN_DEVICE_FUNC
- inline Block(XprType& xpr,
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ Block(XprType& xpr,
Index startRow, Index startCol,
Index blockRows, Index blockCols)
: Impl(xpr, startRow, startCol, blockRows, blockCols)
@@ -147,7 +147,7 @@
&& startCol >= 0 && blockCols >= 0 && startCol <= xpr.cols() - blockCols);
}
};
-
+
// The generic default implementation for dense block simplu forward to the internal::BlockImpl_dense
// that must be specialized for direct and non-direct access...
template<typename XprType, int BlockRows, int BlockCols, bool InnerPanel>
@@ -159,10 +159,10 @@
public:
typedef Impl Base;
EIGEN_INHERIT_ASSIGNMENT_OPERATORS(BlockImpl)
- EIGEN_DEVICE_FUNC inline BlockImpl(XprType& xpr, Index i) : Impl(xpr,i) {}
- EIGEN_DEVICE_FUNC inline BlockImpl(XprType& xpr, Index startRow, Index startCol) : Impl(xpr, startRow, startCol) {}
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE BlockImpl(XprType& xpr, Index i) : Impl(xpr,i) {}
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE BlockImpl(XprType& xpr, Index startRow, Index startCol) : Impl(xpr, startRow, startCol) {}
EIGEN_DEVICE_FUNC
- inline BlockImpl(XprType& xpr, Index startRow, Index startCol, Index blockRows, Index blockCols)
+ EIGEN_STRONG_INLINE BlockImpl(XprType& xpr, Index startRow, Index startCol, Index blockRows, Index blockCols)
: Impl(xpr, startRow, startCol, blockRows, blockCols) {}
};
@@ -294,25 +294,25 @@
EIGEN_DEVICE_FUNC inline Index outerStride() const;
#endif
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const typename internal::remove_all<XprTypeNested>::type& nestedExpression() const
- {
- return m_xpr;
+ {
+ return m_xpr;
}
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
XprType& nestedExpression() { return m_xpr; }
-
- EIGEN_DEVICE_FUNC
- StorageIndex startRow() const
- {
- return m_startRow.value();
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_CONSTEXPR
+ StorageIndex startRow() const EIGEN_NOEXCEPT
+ {
+ return m_startRow.value();
}
-
- EIGEN_DEVICE_FUNC
- StorageIndex startCol() const
- {
- return m_startCol.value();
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_CONSTEXPR
+ StorageIndex startCol() const EIGEN_NOEXCEPT
+ {
+ return m_startCol.value();
}
protected:
@@ -342,9 +342,9 @@
/** Column or Row constructor
*/
- EIGEN_DEVICE_FUNC
- inline BlockImpl_dense(XprType& xpr, Index i)
- : Base(xpr.data() + i * ( ((BlockRows==1) && (BlockCols==XprType::ColsAtCompileTime) && (!XprTypeIsRowMajor))
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ BlockImpl_dense(XprType& xpr, Index i)
+ : Base(xpr.data() + i * ( ((BlockRows==1) && (BlockCols==XprType::ColsAtCompileTime) && (!XprTypeIsRowMajor))
|| ((BlockRows==XprType::RowsAtCompileTime) && (BlockCols==1) && ( XprTypeIsRowMajor)) ? xpr.innerStride() : xpr.outerStride()),
BlockRows==1 ? 1 : xpr.rows(),
BlockCols==1 ? 1 : xpr.cols()),
@@ -357,8 +357,8 @@
/** Fixed-size constructor
*/
- EIGEN_DEVICE_FUNC
- inline BlockImpl_dense(XprType& xpr, Index startRow, Index startCol)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ BlockImpl_dense(XprType& xpr, Index startRow, Index startCol)
: Base(xpr.data()+xpr.innerStride()*(XprTypeIsRowMajor?startCol:startRow) + xpr.outerStride()*(XprTypeIsRowMajor?startRow:startCol)),
m_xpr(xpr), m_startRow(startRow), m_startCol(startCol)
{
@@ -367,8 +367,8 @@
/** Dynamic-size constructor
*/
- EIGEN_DEVICE_FUNC
- inline BlockImpl_dense(XprType& xpr,
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ BlockImpl_dense(XprType& xpr,
Index startRow, Index startCol,
Index blockRows, Index blockCols)
: Base(xpr.data()+xpr.innerStride()*(XprTypeIsRowMajor?startCol:startRow) + xpr.outerStride()*(XprTypeIsRowMajor?startRow:startCol), blockRows, blockCols),
@@ -377,18 +377,18 @@
init();
}
- EIGEN_DEVICE_FUNC
- const typename internal::remove_all<XprTypeNested>::type& nestedExpression() const
- {
- return m_xpr;
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ const typename internal::remove_all<XprTypeNested>::type& nestedExpression() const EIGEN_NOEXCEPT
+ {
+ return m_xpr;
}
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
XprType& nestedExpression() { return m_xpr; }
-
+
/** \sa MapBase::innerStride() */
- EIGEN_DEVICE_FUNC
- inline Index innerStride() const
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_CONSTEXPR
+ Index innerStride() const EIGEN_NOEXCEPT
{
return internal::traits<BlockType>::HasSameStorageOrderAsXprType
? m_xpr.innerStride()
@@ -396,23 +396,19 @@
}
/** \sa MapBase::outerStride() */
- EIGEN_DEVICE_FUNC
- inline Index outerStride() const
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_CONSTEXPR
+ Index outerStride() const EIGEN_NOEXCEPT
{
- return m_outerStride;
+ return internal::traits<BlockType>::HasSameStorageOrderAsXprType
+ ? m_xpr.outerStride()
+ : m_xpr.innerStride();
}
- EIGEN_DEVICE_FUNC
- StorageIndex startRow() const
- {
- return m_startRow.value();
- }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_CONSTEXPR
+ StorageIndex startRow() const EIGEN_NOEXCEPT { return m_startRow.value(); }
- EIGEN_DEVICE_FUNC
- StorageIndex startCol() const
- {
- return m_startCol.value();
- }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_CONSTEXPR
+ StorageIndex startCol() const EIGEN_NOEXCEPT { return m_startCol.value(); }
#ifndef __SUNPRO_CC
// FIXME sunstudio is not friendly with the above friend...
@@ -422,8 +418,8 @@
#ifndef EIGEN_PARSED_BY_DOXYGEN
/** \internal used by allowAligned() */
- EIGEN_DEVICE_FUNC
- inline BlockImpl_dense(XprType& xpr, const Scalar* data, Index blockRows, Index blockCols)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ BlockImpl_dense(XprType& xpr, const Scalar* data, Index blockRows, Index blockCols)
: Base(data, blockRows, blockCols), m_xpr(xpr)
{
init();
@@ -431,7 +427,7 @@
#endif
protected:
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void init()
{
m_outerStride = internal::traits<BlockType>::HasSameStorageOrderAsXprType
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/BooleanRedux.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/BooleanRedux.h
index 8409d87..852de8b 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/BooleanRedux.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/BooleanRedux.h
@@ -14,58 +14,56 @@
namespace internal {
-template<typename Derived, int UnrollCount>
+template<typename Derived, int UnrollCount, int Rows>
struct all_unroller
{
- typedef typename Derived::ExpressionTraits Traits;
enum {
- col = (UnrollCount-1) / Traits::RowsAtCompileTime,
- row = (UnrollCount-1) % Traits::RowsAtCompileTime
+ col = (UnrollCount-1) / Rows,
+ row = (UnrollCount-1) % Rows
};
- static inline bool run(const Derived &mat)
+ EIGEN_DEVICE_FUNC static inline bool run(const Derived &mat)
{
- return all_unroller<Derived, UnrollCount-1>::run(mat) && mat.coeff(row, col);
+ return all_unroller<Derived, UnrollCount-1, Rows>::run(mat) && mat.coeff(row, col);
}
};
-template<typename Derived>
-struct all_unroller<Derived, 0>
+template<typename Derived, int Rows>
+struct all_unroller<Derived, 0, Rows>
{
- static inline bool run(const Derived &/*mat*/) { return true; }
+ EIGEN_DEVICE_FUNC static inline bool run(const Derived &/*mat*/) { return true; }
};
-template<typename Derived>
-struct all_unroller<Derived, Dynamic>
+template<typename Derived, int Rows>
+struct all_unroller<Derived, Dynamic, Rows>
{
- static inline bool run(const Derived &) { return false; }
+ EIGEN_DEVICE_FUNC static inline bool run(const Derived &) { return false; }
};
-template<typename Derived, int UnrollCount>
+template<typename Derived, int UnrollCount, int Rows>
struct any_unroller
{
- typedef typename Derived::ExpressionTraits Traits;
enum {
- col = (UnrollCount-1) / Traits::RowsAtCompileTime,
- row = (UnrollCount-1) % Traits::RowsAtCompileTime
+ col = (UnrollCount-1) / Rows,
+ row = (UnrollCount-1) % Rows
};
- static inline bool run(const Derived &mat)
+ EIGEN_DEVICE_FUNC static inline bool run(const Derived &mat)
{
- return any_unroller<Derived, UnrollCount-1>::run(mat) || mat.coeff(row, col);
+ return any_unroller<Derived, UnrollCount-1, Rows>::run(mat) || mat.coeff(row, col);
}
};
-template<typename Derived>
-struct any_unroller<Derived, 0>
+template<typename Derived, int Rows>
+struct any_unroller<Derived, 0, Rows>
{
- static inline bool run(const Derived & /*mat*/) { return false; }
+ EIGEN_DEVICE_FUNC static inline bool run(const Derived & /*mat*/) { return false; }
};
-template<typename Derived>
-struct any_unroller<Derived, Dynamic>
+template<typename Derived, int Rows>
+struct any_unroller<Derived, Dynamic, Rows>
{
- static inline bool run(const Derived &) { return false; }
+ EIGEN_DEVICE_FUNC static inline bool run(const Derived &) { return false; }
};
} // end namespace internal
@@ -78,16 +76,16 @@
* \sa any(), Cwise::operator<()
*/
template<typename Derived>
-inline bool DenseBase<Derived>::all() const
+EIGEN_DEVICE_FUNC inline bool DenseBase<Derived>::all() const
{
typedef internal::evaluator<Derived> Evaluator;
enum {
unroll = SizeAtCompileTime != Dynamic
- && SizeAtCompileTime * (Evaluator::CoeffReadCost + NumTraits<Scalar>::AddCost) <= EIGEN_UNROLLING_LIMIT
+ && SizeAtCompileTime * (int(Evaluator::CoeffReadCost) + int(NumTraits<Scalar>::AddCost)) <= EIGEN_UNROLLING_LIMIT
};
Evaluator evaluator(derived());
if(unroll)
- return internal::all_unroller<Evaluator, unroll ? int(SizeAtCompileTime) : Dynamic>::run(evaluator);
+ return internal::all_unroller<Evaluator, unroll ? int(SizeAtCompileTime) : Dynamic, internal::traits<Derived>::RowsAtCompileTime>::run(evaluator);
else
{
for(Index j = 0; j < cols(); ++j)
@@ -102,16 +100,16 @@
* \sa all()
*/
template<typename Derived>
-inline bool DenseBase<Derived>::any() const
+EIGEN_DEVICE_FUNC inline bool DenseBase<Derived>::any() const
{
typedef internal::evaluator<Derived> Evaluator;
enum {
unroll = SizeAtCompileTime != Dynamic
- && SizeAtCompileTime * (Evaluator::CoeffReadCost + NumTraits<Scalar>::AddCost) <= EIGEN_UNROLLING_LIMIT
+ && SizeAtCompileTime * (int(Evaluator::CoeffReadCost) + int(NumTraits<Scalar>::AddCost)) <= EIGEN_UNROLLING_LIMIT
};
Evaluator evaluator(derived());
if(unroll)
- return internal::any_unroller<Evaluator, unroll ? int(SizeAtCompileTime) : Dynamic>::run(evaluator);
+ return internal::any_unroller<Evaluator, unroll ? int(SizeAtCompileTime) : Dynamic, internal::traits<Derived>::RowsAtCompileTime>::run(evaluator);
else
{
for(Index j = 0; j < cols(); ++j)
@@ -126,7 +124,7 @@
* \sa all(), any()
*/
template<typename Derived>
-inline Eigen::Index DenseBase<Derived>::count() const
+EIGEN_DEVICE_FUNC inline Eigen::Index DenseBase<Derived>::count() const
{
return derived().template cast<bool>().template cast<Index>().sum();
}
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/CommaInitializer.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/CommaInitializer.h
index d218e98..c0e29c7 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/CommaInitializer.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/CommaInitializer.h
@@ -33,6 +33,8 @@
inline CommaInitializer(XprType& xpr, const Scalar& s)
: m_xpr(xpr), m_row(0), m_col(1), m_currentBlockRows(1)
{
+ eigen_assert(m_xpr.rows() > 0 && m_xpr.cols() > 0
+ && "Cannot comma-initialize a 0x0 matrix (operator<<)");
m_xpr.coeffRef(0,0) = s;
}
@@ -41,6 +43,8 @@
inline CommaInitializer(XprType& xpr, const DenseBase<OtherDerived>& other)
: m_xpr(xpr), m_row(0), m_col(other.cols()), m_currentBlockRows(other.rows())
{
+ eigen_assert(m_xpr.rows() >= other.rows() && m_xpr.cols() >= other.cols()
+ && "Cannot comma-initialize a 0x0 matrix (operator<<)");
m_xpr.block(0, 0, other.rows(), other.cols()) = other;
}
@@ -103,7 +107,7 @@
EIGEN_EXCEPTION_SPEC(Eigen::eigen_assert_exception)
#endif
{
- finished();
+ finished();
}
/** \returns the built matrix once all its coefficients have been set.
@@ -141,7 +145,7 @@
* \sa CommaInitializer::finished(), class CommaInitializer
*/
template<typename Derived>
-inline CommaInitializer<Derived> DenseBase<Derived>::operator<< (const Scalar& s)
+EIGEN_DEVICE_FUNC inline CommaInitializer<Derived> DenseBase<Derived>::operator<< (const Scalar& s)
{
return CommaInitializer<Derived>(*static_cast<Derived*>(this), s);
}
@@ -149,7 +153,7 @@
/** \sa operator<<(const Scalar&) */
template<typename Derived>
template<typename OtherDerived>
-inline CommaInitializer<Derived>
+EIGEN_DEVICE_FUNC inline CommaInitializer<Derived>
DenseBase<Derived>::operator<<(const DenseBase<OtherDerived>& other)
{
return CommaInitializer<Derived>(*static_cast<Derived *>(this), other);
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/CoreEvaluators.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/CoreEvaluators.h
index 910889e..0ff8c8d 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/CoreEvaluators.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/CoreEvaluators.h
@@ -14,7 +14,7 @@
#define EIGEN_COREEVALUATORS_H
namespace Eigen {
-
+
namespace internal {
// This class returns the evaluator kind from the expression storage kind.
@@ -63,8 +63,8 @@
template< typename T,
typename Kind = typename evaluator_traits<typename T::NestedExpression>::Kind,
typename Scalar = typename T::Scalar> struct unary_evaluator;
-
-// evaluator_traits<T> contains traits for evaluator<T>
+
+// evaluator_traits<T> contains traits for evaluator<T>
template<typename T>
struct evaluator_traits_base
@@ -90,7 +90,8 @@
struct evaluator : public unary_evaluator<T>
{
typedef unary_evaluator<T> Base;
- EIGEN_DEVICE_FUNC explicit evaluator(const T& xpr) : Base(xpr) {}
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ explicit evaluator(const T& xpr) : Base(xpr) {}
};
@@ -99,21 +100,29 @@
struct evaluator<const T>
: evaluator<T>
{
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
explicit evaluator(const T& xpr) : evaluator<T>(xpr) {}
};
// ---------- base class for all evaluators ----------
template<typename ExpressionType>
-struct evaluator_base : public noncopyable
+struct evaluator_base
{
// TODO that's not very nice to have to propagate all these traits. They are currently only needed to handle outer,inner indices.
typedef traits<ExpressionType> ExpressionTraits;
-
+
enum {
Alignment = 0
};
+ // noncopyable:
+ // Don't make this class inherit noncopyable as this kills EBO (Empty Base Optimization)
+ // and make complex evaluator much larger than then should do.
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE evaluator_base() {}
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ~evaluator_base() {}
+private:
+ EIGEN_DEVICE_FUNC evaluator_base(const evaluator_base&);
+ EIGEN_DEVICE_FUNC const evaluator_base& operator=(const evaluator_base&);
};
// -------------------- Matrix and Array --------------------
@@ -123,6 +132,33 @@
// Here we directly specialize evaluator. This is not really a unary expression, and it is, by definition, dense,
// so no need for more sophisticated dispatching.
+// this helper permits to completely eliminate m_outerStride if it is known at compiletime.
+template<typename Scalar,int OuterStride> class plainobjectbase_evaluator_data {
+public:
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ plainobjectbase_evaluator_data(const Scalar* ptr, Index outerStride) : data(ptr)
+ {
+#ifndef EIGEN_INTERNAL_DEBUGGING
+ EIGEN_UNUSED_VARIABLE(outerStride);
+#endif
+ eigen_internal_assert(outerStride==OuterStride);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_CONSTEXPR
+ Index outerStride() const EIGEN_NOEXCEPT { return OuterStride; }
+ const Scalar *data;
+};
+
+template<typename Scalar> class plainobjectbase_evaluator_data<Scalar,Dynamic> {
+public:
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ plainobjectbase_evaluator_data(const Scalar* ptr, Index outerStride) : data(ptr), m_outerStride(outerStride) {}
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ Index outerStride() const { return m_outerStride; }
+ const Scalar *data;
+protected:
+ Index m_outerStride;
+};
+
template<typename Derived>
struct evaluator<PlainObjectBase<Derived> >
: evaluator_base<Derived>
@@ -136,23 +172,28 @@
IsVectorAtCompileTime = PlainObjectType::IsVectorAtCompileTime,
RowsAtCompileTime = PlainObjectType::RowsAtCompileTime,
ColsAtCompileTime = PlainObjectType::ColsAtCompileTime,
-
+
CoeffReadCost = NumTraits<Scalar>::ReadCost,
Flags = traits<Derived>::EvaluatorFlags,
Alignment = traits<Derived>::Alignment
};
-
- EIGEN_DEVICE_FUNC evaluator()
- : m_data(0),
- m_outerStride(IsVectorAtCompileTime ? 0
- : int(IsRowMajor) ? ColsAtCompileTime
- : RowsAtCompileTime)
+ enum {
+ // We do not need to know the outer stride for vectors
+ OuterStrideAtCompileTime = IsVectorAtCompileTime ? 0
+ : int(IsRowMajor) ? ColsAtCompileTime
+ : RowsAtCompileTime
+ };
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ evaluator()
+ : m_d(0,OuterStrideAtCompileTime)
{
EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
}
-
- EIGEN_DEVICE_FUNC explicit evaluator(const PlainObjectType& m)
- : m_data(m.data()), m_outerStride(IsVectorAtCompileTime ? 0 : m.outerStride())
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ explicit evaluator(const PlainObjectType& m)
+ : m_d(m.data(),IsVectorAtCompileTime ? 0 : m.outerStride())
{
EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
}
@@ -161,30 +202,30 @@
CoeffReturnType coeff(Index row, Index col) const
{
if (IsRowMajor)
- return m_data[row * m_outerStride.value() + col];
+ return m_d.data[row * m_d.outerStride() + col];
else
- return m_data[row + col * m_outerStride.value()];
+ return m_d.data[row + col * m_d.outerStride()];
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
CoeffReturnType coeff(Index index) const
{
- return m_data[index];
+ return m_d.data[index];
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
Scalar& coeffRef(Index row, Index col)
{
if (IsRowMajor)
- return const_cast<Scalar*>(m_data)[row * m_outerStride.value() + col];
+ return const_cast<Scalar*>(m_d.data)[row * m_d.outerStride() + col];
else
- return const_cast<Scalar*>(m_data)[row + col * m_outerStride.value()];
+ return const_cast<Scalar*>(m_d.data)[row + col * m_d.outerStride()];
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
Scalar& coeffRef(Index index)
{
- return const_cast<Scalar*>(m_data)[index];
+ return const_cast<Scalar*>(m_d.data)[index];
}
template<int LoadMode, typename PacketType>
@@ -192,16 +233,16 @@
PacketType packet(Index row, Index col) const
{
if (IsRowMajor)
- return ploadt<PacketType, LoadMode>(m_data + row * m_outerStride.value() + col);
+ return ploadt<PacketType, LoadMode>(m_d.data + row * m_d.outerStride() + col);
else
- return ploadt<PacketType, LoadMode>(m_data + row + col * m_outerStride.value());
+ return ploadt<PacketType, LoadMode>(m_d.data + row + col * m_d.outerStride());
}
template<int LoadMode, typename PacketType>
EIGEN_STRONG_INLINE
PacketType packet(Index index) const
{
- return ploadt<PacketType, LoadMode>(m_data + index);
+ return ploadt<PacketType, LoadMode>(m_d.data + index);
}
template<int StoreMode,typename PacketType>
@@ -210,26 +251,22 @@
{
if (IsRowMajor)
return pstoret<Scalar, PacketType, StoreMode>
- (const_cast<Scalar*>(m_data) + row * m_outerStride.value() + col, x);
+ (const_cast<Scalar*>(m_d.data) + row * m_d.outerStride() + col, x);
else
return pstoret<Scalar, PacketType, StoreMode>
- (const_cast<Scalar*>(m_data) + row + col * m_outerStride.value(), x);
+ (const_cast<Scalar*>(m_d.data) + row + col * m_d.outerStride(), x);
}
template<int StoreMode, typename PacketType>
EIGEN_STRONG_INLINE
void writePacket(Index index, const PacketType& x)
{
- return pstoret<Scalar, PacketType, StoreMode>(const_cast<Scalar*>(m_data) + index, x);
+ return pstoret<Scalar, PacketType, StoreMode>(const_cast<Scalar*>(m_d.data) + index, x);
}
protected:
- const Scalar *m_data;
- // We do not need to know the outer stride for vectors
- variable_if_dynamic<Index, IsVectorAtCompileTime ? 0
- : int(IsRowMajor) ? ColsAtCompileTime
- : RowsAtCompileTime> m_outerStride;
+ plainobjectbase_evaluator_data<Scalar,OuterStrideAtCompileTime> m_d;
};
template<typename Scalar, int Rows, int Cols, int Options, int MaxRows, int MaxCols>
@@ -237,11 +274,13 @@
: evaluator<PlainObjectBase<Matrix<Scalar, Rows, Cols, Options, MaxRows, MaxCols> > >
{
typedef Matrix<Scalar, Rows, Cols, Options, MaxRows, MaxCols> XprType;
-
- EIGEN_DEVICE_FUNC evaluator() {}
- EIGEN_DEVICE_FUNC explicit evaluator(const XprType& m)
- : evaluator<PlainObjectBase<XprType> >(m)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ evaluator() {}
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ explicit evaluator(const XprType& m)
+ : evaluator<PlainObjectBase<XprType> >(m)
{ }
};
@@ -251,10 +290,12 @@
{
typedef Array<Scalar, Rows, Cols, Options, MaxRows, MaxCols> XprType;
- EIGEN_DEVICE_FUNC evaluator() {}
-
- EIGEN_DEVICE_FUNC explicit evaluator(const XprType& m)
- : evaluator<PlainObjectBase<XprType> >(m)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ evaluator() {}
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ explicit evaluator(const XprType& m)
+ : evaluator<PlainObjectBase<XprType> >(m)
{ }
};
@@ -265,14 +306,15 @@
: evaluator_base<Transpose<ArgType> >
{
typedef Transpose<ArgType> XprType;
-
+
enum {
- CoeffReadCost = evaluator<ArgType>::CoeffReadCost,
+ CoeffReadCost = evaluator<ArgType>::CoeffReadCost,
Flags = evaluator<ArgType>::Flags ^ RowMajorBit,
Alignment = evaluator<ArgType>::Alignment
};
- EIGEN_DEVICE_FUNC explicit unary_evaluator(const XprType& t) : m_argImpl(t.nestedExpression()) {}
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ explicit unary_evaluator(const XprType& t) : m_argImpl(t.nestedExpression()) {}
typedef typename XprType::Scalar Scalar;
typedef typename XprType::CoeffReturnType CoeffReturnType;
@@ -457,10 +499,10 @@
{
typedef CwiseNullaryOp<NullaryOp,PlainObjectType> XprType;
typedef typename internal::remove_all<PlainObjectType>::type PlainObjectTypeCleaned;
-
+
enum {
CoeffReadCost = internal::functor_traits<NullaryOp>::Cost,
-
+
Flags = (evaluator<PlainObjectTypeCleaned>::Flags
& ( HereditaryBits
| (functor_has_linear_access<NullaryOp>::ret ? LinearAccessBit : 0)
@@ -517,19 +559,17 @@
: evaluator_base<CwiseUnaryOp<UnaryOp, ArgType> >
{
typedef CwiseUnaryOp<UnaryOp, ArgType> XprType;
-
+
enum {
- CoeffReadCost = evaluator<ArgType>::CoeffReadCost + functor_traits<UnaryOp>::Cost,
-
+ CoeffReadCost = int(evaluator<ArgType>::CoeffReadCost) + int(functor_traits<UnaryOp>::Cost),
+
Flags = evaluator<ArgType>::Flags
& (HereditaryBits | LinearAccessBit | (functor_traits<UnaryOp>::PacketAccess ? PacketAccessBit : 0)),
Alignment = evaluator<ArgType>::Alignment
};
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
- explicit unary_evaluator(const XprType& op)
- : m_functor(op.functor()),
- m_argImpl(op.nestedExpression())
+ explicit unary_evaluator(const XprType& op) : m_d(op)
{
EIGEN_INTERNAL_CHECK_COST_VALUE(functor_traits<UnaryOp>::Cost);
EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
@@ -540,32 +580,43 @@
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
CoeffReturnType coeff(Index row, Index col) const
{
- return m_functor(m_argImpl.coeff(row, col));
+ return m_d.func()(m_d.argImpl.coeff(row, col));
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
CoeffReturnType coeff(Index index) const
{
- return m_functor(m_argImpl.coeff(index));
+ return m_d.func()(m_d.argImpl.coeff(index));
}
template<int LoadMode, typename PacketType>
EIGEN_STRONG_INLINE
PacketType packet(Index row, Index col) const
{
- return m_functor.packetOp(m_argImpl.template packet<LoadMode, PacketType>(row, col));
+ return m_d.func().packetOp(m_d.argImpl.template packet<LoadMode, PacketType>(row, col));
}
template<int LoadMode, typename PacketType>
EIGEN_STRONG_INLINE
PacketType packet(Index index) const
{
- return m_functor.packetOp(m_argImpl.template packet<LoadMode, PacketType>(index));
+ return m_d.func().packetOp(m_d.argImpl.template packet<LoadMode, PacketType>(index));
}
protected:
- const UnaryOp m_functor;
- evaluator<ArgType> m_argImpl;
+
+ // this helper permits to completely eliminate the functor if it is empty
+ struct Data
+ {
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ Data(const XprType& xpr) : op(xpr.functor()), argImpl(xpr.nestedExpression()) {}
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ const UnaryOp& func() const { return op; }
+ UnaryOp op;
+ evaluator<ArgType> argImpl;
+ };
+
+ Data m_d;
};
// -------------------- CwiseTernaryOp --------------------
@@ -577,7 +628,7 @@
{
typedef CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3> XprType;
typedef ternary_evaluator<CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3> > Base;
-
+
EIGEN_DEVICE_FUNC explicit evaluator(const XprType& xpr) : Base(xpr) {}
};
@@ -586,10 +637,10 @@
: evaluator_base<CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3> >
{
typedef CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3> XprType;
-
+
enum {
- CoeffReadCost = evaluator<Arg1>::CoeffReadCost + evaluator<Arg2>::CoeffReadCost + evaluator<Arg3>::CoeffReadCost + functor_traits<TernaryOp>::Cost,
-
+ CoeffReadCost = int(evaluator<Arg1>::CoeffReadCost) + int(evaluator<Arg2>::CoeffReadCost) + int(evaluator<Arg3>::CoeffReadCost) + int(functor_traits<TernaryOp>::Cost),
+
Arg1Flags = evaluator<Arg1>::Flags,
Arg2Flags = evaluator<Arg2>::Flags,
Arg3Flags = evaluator<Arg3>::Flags,
@@ -609,11 +660,7 @@
evaluator<Arg3>::Alignment)
};
- EIGEN_DEVICE_FUNC explicit ternary_evaluator(const XprType& xpr)
- : m_functor(xpr.functor()),
- m_arg1Impl(xpr.arg1()),
- m_arg2Impl(xpr.arg2()),
- m_arg3Impl(xpr.arg3())
+ EIGEN_DEVICE_FUNC explicit ternary_evaluator(const XprType& xpr) : m_d(xpr)
{
EIGEN_INTERNAL_CHECK_COST_VALUE(functor_traits<TernaryOp>::Cost);
EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
@@ -624,38 +671,48 @@
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
CoeffReturnType coeff(Index row, Index col) const
{
- return m_functor(m_arg1Impl.coeff(row, col), m_arg2Impl.coeff(row, col), m_arg3Impl.coeff(row, col));
+ return m_d.func()(m_d.arg1Impl.coeff(row, col), m_d.arg2Impl.coeff(row, col), m_d.arg3Impl.coeff(row, col));
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
CoeffReturnType coeff(Index index) const
{
- return m_functor(m_arg1Impl.coeff(index), m_arg2Impl.coeff(index), m_arg3Impl.coeff(index));
+ return m_d.func()(m_d.arg1Impl.coeff(index), m_d.arg2Impl.coeff(index), m_d.arg3Impl.coeff(index));
}
template<int LoadMode, typename PacketType>
EIGEN_STRONG_INLINE
PacketType packet(Index row, Index col) const
{
- return m_functor.packetOp(m_arg1Impl.template packet<LoadMode,PacketType>(row, col),
- m_arg2Impl.template packet<LoadMode,PacketType>(row, col),
- m_arg3Impl.template packet<LoadMode,PacketType>(row, col));
+ return m_d.func().packetOp(m_d.arg1Impl.template packet<LoadMode,PacketType>(row, col),
+ m_d.arg2Impl.template packet<LoadMode,PacketType>(row, col),
+ m_d.arg3Impl.template packet<LoadMode,PacketType>(row, col));
}
template<int LoadMode, typename PacketType>
EIGEN_STRONG_INLINE
PacketType packet(Index index) const
{
- return m_functor.packetOp(m_arg1Impl.template packet<LoadMode,PacketType>(index),
- m_arg2Impl.template packet<LoadMode,PacketType>(index),
- m_arg3Impl.template packet<LoadMode,PacketType>(index));
+ return m_d.func().packetOp(m_d.arg1Impl.template packet<LoadMode,PacketType>(index),
+ m_d.arg2Impl.template packet<LoadMode,PacketType>(index),
+ m_d.arg3Impl.template packet<LoadMode,PacketType>(index));
}
protected:
- const TernaryOp m_functor;
- evaluator<Arg1> m_arg1Impl;
- evaluator<Arg2> m_arg2Impl;
- evaluator<Arg3> m_arg3Impl;
+ // this helper permits to completely eliminate the functor if it is empty
+ struct Data
+ {
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ Data(const XprType& xpr) : op(xpr.functor()), arg1Impl(xpr.arg1()), arg2Impl(xpr.arg2()), arg3Impl(xpr.arg3()) {}
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ const TernaryOp& func() const { return op; }
+ TernaryOp op;
+ evaluator<Arg1> arg1Impl;
+ evaluator<Arg2> arg2Impl;
+ evaluator<Arg3> arg3Impl;
+ };
+
+ Data m_d;
};
// -------------------- CwiseBinaryOp --------------------
@@ -667,8 +724,9 @@
{
typedef CwiseBinaryOp<BinaryOp, Lhs, Rhs> XprType;
typedef binary_evaluator<CwiseBinaryOp<BinaryOp, Lhs, Rhs> > Base;
-
- EIGEN_DEVICE_FUNC explicit evaluator(const XprType& xpr) : Base(xpr) {}
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ explicit evaluator(const XprType& xpr) : Base(xpr) {}
};
template<typename BinaryOp, typename Lhs, typename Rhs>
@@ -676,10 +734,10 @@
: evaluator_base<CwiseBinaryOp<BinaryOp, Lhs, Rhs> >
{
typedef CwiseBinaryOp<BinaryOp, Lhs, Rhs> XprType;
-
+
enum {
- CoeffReadCost = evaluator<Lhs>::CoeffReadCost + evaluator<Rhs>::CoeffReadCost + functor_traits<BinaryOp>::Cost,
-
+ CoeffReadCost = int(evaluator<Lhs>::CoeffReadCost) + int(evaluator<Rhs>::CoeffReadCost) + int(functor_traits<BinaryOp>::Cost),
+
LhsFlags = evaluator<Lhs>::Flags,
RhsFlags = evaluator<Rhs>::Flags,
SameType = is_same<typename Lhs::Scalar,typename Rhs::Scalar>::value,
@@ -696,10 +754,8 @@
Alignment = EIGEN_PLAIN_ENUM_MIN(evaluator<Lhs>::Alignment,evaluator<Rhs>::Alignment)
};
- EIGEN_DEVICE_FUNC explicit binary_evaluator(const XprType& xpr)
- : m_functor(xpr.functor()),
- m_lhsImpl(xpr.lhs()),
- m_rhsImpl(xpr.rhs())
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ explicit binary_evaluator(const XprType& xpr) : m_d(xpr)
{
EIGEN_INTERNAL_CHECK_COST_VALUE(functor_traits<BinaryOp>::Cost);
EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
@@ -710,35 +766,46 @@
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
CoeffReturnType coeff(Index row, Index col) const
{
- return m_functor(m_lhsImpl.coeff(row, col), m_rhsImpl.coeff(row, col));
+ return m_d.func()(m_d.lhsImpl.coeff(row, col), m_d.rhsImpl.coeff(row, col));
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
CoeffReturnType coeff(Index index) const
{
- return m_functor(m_lhsImpl.coeff(index), m_rhsImpl.coeff(index));
+ return m_d.func()(m_d.lhsImpl.coeff(index), m_d.rhsImpl.coeff(index));
}
template<int LoadMode, typename PacketType>
EIGEN_STRONG_INLINE
PacketType packet(Index row, Index col) const
{
- return m_functor.packetOp(m_lhsImpl.template packet<LoadMode,PacketType>(row, col),
- m_rhsImpl.template packet<LoadMode,PacketType>(row, col));
+ return m_d.func().packetOp(m_d.lhsImpl.template packet<LoadMode,PacketType>(row, col),
+ m_d.rhsImpl.template packet<LoadMode,PacketType>(row, col));
}
template<int LoadMode, typename PacketType>
EIGEN_STRONG_INLINE
PacketType packet(Index index) const
{
- return m_functor.packetOp(m_lhsImpl.template packet<LoadMode,PacketType>(index),
- m_rhsImpl.template packet<LoadMode,PacketType>(index));
+ return m_d.func().packetOp(m_d.lhsImpl.template packet<LoadMode,PacketType>(index),
+ m_d.rhsImpl.template packet<LoadMode,PacketType>(index));
}
protected:
- const BinaryOp m_functor;
- evaluator<Lhs> m_lhsImpl;
- evaluator<Rhs> m_rhsImpl;
+
+ // this helper permits to completely eliminate the functor if it is empty
+ struct Data
+ {
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ Data(const XprType& xpr) : op(xpr.functor()), lhsImpl(xpr.lhs()), rhsImpl(xpr.rhs()) {}
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ const BinaryOp& func() const { return op; }
+ BinaryOp op;
+ evaluator<Lhs> lhsImpl;
+ evaluator<Rhs> rhsImpl;
+ };
+
+ Data m_d;
};
// -------------------- CwiseUnaryView --------------------
@@ -748,18 +815,16 @@
: evaluator_base<CwiseUnaryView<UnaryOp, ArgType> >
{
typedef CwiseUnaryView<UnaryOp, ArgType> XprType;
-
+
enum {
- CoeffReadCost = evaluator<ArgType>::CoeffReadCost + functor_traits<UnaryOp>::Cost,
-
+ CoeffReadCost = int(evaluator<ArgType>::CoeffReadCost) + int(functor_traits<UnaryOp>::Cost),
+
Flags = (evaluator<ArgType>::Flags & (HereditaryBits | LinearAccessBit | DirectAccessBit)),
-
+
Alignment = 0 // FIXME it is not very clear why alignment is necessarily lost...
};
- EIGEN_DEVICE_FUNC explicit unary_evaluator(const XprType& op)
- : m_unaryOp(op.functor()),
- m_argImpl(op.nestedExpression())
+ EIGEN_DEVICE_FUNC explicit unary_evaluator(const XprType& op) : m_d(op)
{
EIGEN_INTERNAL_CHECK_COST_VALUE(functor_traits<UnaryOp>::Cost);
EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
@@ -771,30 +836,41 @@
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
CoeffReturnType coeff(Index row, Index col) const
{
- return m_unaryOp(m_argImpl.coeff(row, col));
+ return m_d.func()(m_d.argImpl.coeff(row, col));
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
CoeffReturnType coeff(Index index) const
{
- return m_unaryOp(m_argImpl.coeff(index));
+ return m_d.func()(m_d.argImpl.coeff(index));
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
Scalar& coeffRef(Index row, Index col)
{
- return m_unaryOp(m_argImpl.coeffRef(row, col));
+ return m_d.func()(m_d.argImpl.coeffRef(row, col));
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
Scalar& coeffRef(Index index)
{
- return m_unaryOp(m_argImpl.coeffRef(index));
+ return m_d.func()(m_d.argImpl.coeffRef(index));
}
protected:
- const UnaryOp m_unaryOp;
- evaluator<ArgType> m_argImpl;
+
+ // this helper permits to completely eliminate the functor if it is empty
+ struct Data
+ {
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ Data(const XprType& xpr) : op(xpr.functor()), argImpl(xpr.nestedExpression()) {}
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ const UnaryOp& func() const { return op; }
+ UnaryOp op;
+ evaluator<ArgType> argImpl;
+ };
+
+ Data m_d;
};
// -------------------- Map --------------------
@@ -811,14 +887,15 @@
typedef typename XprType::PointerType PointerType;
typedef typename XprType::Scalar Scalar;
typedef typename XprType::CoeffReturnType CoeffReturnType;
-
+
enum {
IsRowMajor = XprType::RowsAtCompileTime,
ColsAtCompileTime = XprType::ColsAtCompileTime,
CoeffReadCost = NumTraits<Scalar>::ReadCost
};
- EIGEN_DEVICE_FUNC explicit mapbase_evaluator(const XprType& map)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ explicit mapbase_evaluator(const XprType& map)
: m_data(const_cast<PointerType>(map.data())),
m_innerStride(map.innerStride()),
m_outerStride(map.outerStride())
@@ -882,17 +959,21 @@
internal::pstoret<Scalar, PacketType, StoreMode>(m_data + index * m_innerStride.value(), x);
}
protected:
- EIGEN_DEVICE_FUNC
- inline Index rowStride() const { return XprType::IsRowMajor ? m_outerStride.value() : m_innerStride.value(); }
- EIGEN_DEVICE_FUNC
- inline Index colStride() const { return XprType::IsRowMajor ? m_innerStride.value() : m_outerStride.value(); }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_CONSTEXPR
+ Index rowStride() const EIGEN_NOEXCEPT {
+ return XprType::IsRowMajor ? m_outerStride.value() : m_innerStride.value();
+ }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_CONSTEXPR
+ Index colStride() const EIGEN_NOEXCEPT {
+ return XprType::IsRowMajor ? m_innerStride.value() : m_outerStride.value();
+ }
PointerType m_data;
const internal::variable_if_dynamic<Index, XprType::InnerStrideAtCompileTime> m_innerStride;
const internal::variable_if_dynamic<Index, XprType::OuterStrideAtCompileTime> m_outerStride;
};
-template<typename PlainObjectType, int MapOptions, typename StrideType>
+template<typename PlainObjectType, int MapOptions, typename StrideType>
struct evaluator<Map<PlainObjectType, MapOptions, StrideType> >
: public mapbase_evaluator<Map<PlainObjectType, MapOptions, StrideType>, PlainObjectType>
{
@@ -900,7 +981,7 @@
typedef typename XprType::Scalar Scalar;
// TODO: should check for smaller packet types once we can handle multi-sized packet types
typedef typename packet_traits<Scalar>::type PacketScalar;
-
+
enum {
InnerStrideAtCompileTime = StrideType::InnerStrideAtCompileTime == 0
? int(PlainObjectType::InnerStrideAtCompileTime)
@@ -912,34 +993,35 @@
HasNoOuterStride = StrideType::OuterStrideAtCompileTime == 0,
HasNoStride = HasNoInnerStride && HasNoOuterStride,
IsDynamicSize = PlainObjectType::SizeAtCompileTime==Dynamic,
-
+
PacketAccessMask = bool(HasNoInnerStride) ? ~int(0) : ~int(PacketAccessBit),
LinearAccessMask = bool(HasNoStride) || bool(PlainObjectType::IsVectorAtCompileTime) ? ~int(0) : ~int(LinearAccessBit),
Flags = int( evaluator<PlainObjectType>::Flags) & (LinearAccessMask&PacketAccessMask),
-
+
Alignment = int(MapOptions)&int(AlignedMask)
};
EIGEN_DEVICE_FUNC explicit evaluator(const XprType& map)
- : mapbase_evaluator<XprType, PlainObjectType>(map)
+ : mapbase_evaluator<XprType, PlainObjectType>(map)
{ }
};
// -------------------- Ref --------------------
-template<typename PlainObjectType, int RefOptions, typename StrideType>
+template<typename PlainObjectType, int RefOptions, typename StrideType>
struct evaluator<Ref<PlainObjectType, RefOptions, StrideType> >
: public mapbase_evaluator<Ref<PlainObjectType, RefOptions, StrideType>, PlainObjectType>
{
typedef Ref<PlainObjectType, RefOptions, StrideType> XprType;
-
+
enum {
Flags = evaluator<Map<PlainObjectType, RefOptions, StrideType> >::Flags,
Alignment = evaluator<Map<PlainObjectType, RefOptions, StrideType> >::Alignment
};
- EIGEN_DEVICE_FUNC explicit evaluator(const XprType& ref)
- : mapbase_evaluator<XprType, PlainObjectType>(ref)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ explicit evaluator(const XprType& ref)
+ : mapbase_evaluator<XprType, PlainObjectType>(ref)
{ }
};
@@ -947,8 +1029,8 @@
template<typename ArgType, int BlockRows, int BlockCols, bool InnerPanel,
bool HasDirectAccess = internal::has_direct_access<ArgType>::ret> struct block_evaluator;
-
-template<typename ArgType, int BlockRows, int BlockCols, bool InnerPanel>
+
+template<typename ArgType, int BlockRows, int BlockCols, bool InnerPanel>
struct evaluator<Block<ArgType, BlockRows, BlockCols, InnerPanel> >
: block_evaluator<ArgType, BlockRows, BlockCols, InnerPanel>
{
@@ -956,15 +1038,15 @@
typedef typename XprType::Scalar Scalar;
// TODO: should check for smaller packet types once we can handle multi-sized packet types
typedef typename packet_traits<Scalar>::type PacketScalar;
-
+
enum {
CoeffReadCost = evaluator<ArgType>::CoeffReadCost,
-
+
RowsAtCompileTime = traits<XprType>::RowsAtCompileTime,
ColsAtCompileTime = traits<XprType>::ColsAtCompileTime,
MaxRowsAtCompileTime = traits<XprType>::MaxRowsAtCompileTime,
MaxColsAtCompileTime = traits<XprType>::MaxColsAtCompileTime,
-
+
ArgTypeIsRowMajor = (int(evaluator<ArgType>::Flags)&RowMajorBit) != 0,
IsRowMajor = (MaxRowsAtCompileTime==1 && MaxColsAtCompileTime!=1) ? 1
: (MaxColsAtCompileTime==1 && MaxRowsAtCompileTime!=1) ? 0
@@ -978,14 +1060,14 @@
? int(outer_stride_at_compile_time<ArgType>::ret)
: int(inner_stride_at_compile_time<ArgType>::ret),
MaskPacketAccessBit = (InnerStrideAtCompileTime == 1 || HasSameStorageOrderAsArgType) ? PacketAccessBit : 0,
-
- FlagsLinearAccessBit = (RowsAtCompileTime == 1 || ColsAtCompileTime == 1 || (InnerPanel && (evaluator<ArgType>::Flags&LinearAccessBit))) ? LinearAccessBit : 0,
+
+ FlagsLinearAccessBit = (RowsAtCompileTime == 1 || ColsAtCompileTime == 1 || (InnerPanel && (evaluator<ArgType>::Flags&LinearAccessBit))) ? LinearAccessBit : 0,
FlagsRowMajorBit = XprType::Flags&RowMajorBit,
Flags0 = evaluator<ArgType>::Flags & ( (HereditaryBits & ~RowMajorBit) |
DirectAccessBit |
MaskPacketAccessBit),
Flags = Flags0 | FlagsLinearAccessBit | FlagsRowMajorBit,
-
+
PacketAlignment = unpacket_traits<PacketScalar>::alignment,
Alignment0 = (InnerPanel && (OuterStrideAtCompileTime!=Dynamic)
&& (OuterStrideAtCompileTime!=0)
@@ -993,7 +1075,8 @@
Alignment = EIGEN_PLAIN_ENUM_MIN(evaluator<ArgType>::Alignment, Alignment0)
};
typedef block_evaluator<ArgType, BlockRows, BlockCols, InnerPanel> block_evaluator_type;
- EIGEN_DEVICE_FUNC explicit evaluator(const XprType& block) : block_evaluator_type(block)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ explicit evaluator(const XprType& block) : block_evaluator_type(block)
{
EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
}
@@ -1006,8 +1089,9 @@
{
typedef Block<ArgType, BlockRows, BlockCols, InnerPanel> XprType;
- EIGEN_DEVICE_FUNC explicit block_evaluator(const XprType& block)
- : unary_evaluator<XprType>(block)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ explicit block_evaluator(const XprType& block)
+ : unary_evaluator<XprType>(block)
{}
};
@@ -1017,79 +1101,74 @@
{
typedef Block<ArgType, BlockRows, BlockCols, InnerPanel> XprType;
- EIGEN_DEVICE_FUNC explicit unary_evaluator(const XprType& block)
- : m_argImpl(block.nestedExpression()),
- m_startRow(block.startRow()),
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ explicit unary_evaluator(const XprType& block)
+ : m_argImpl(block.nestedExpression()),
+ m_startRow(block.startRow()),
m_startCol(block.startCol()),
- m_linear_offset(InnerPanel?(XprType::IsRowMajor ? block.startRow()*block.cols() : block.startCol()*block.rows()):0)
+ m_linear_offset(ForwardLinearAccess?(ArgType::IsRowMajor ? block.startRow()*block.nestedExpression().cols() + block.startCol() : block.startCol()*block.nestedExpression().rows() + block.startRow()):0)
{ }
-
+
typedef typename XprType::Scalar Scalar;
typedef typename XprType::CoeffReturnType CoeffReturnType;
enum {
RowsAtCompileTime = XprType::RowsAtCompileTime,
- ForwardLinearAccess = InnerPanel && bool(evaluator<ArgType>::Flags&LinearAccessBit)
+ ForwardLinearAccess = (InnerPanel || int(XprType::IsRowMajor)==int(ArgType::IsRowMajor)) && bool(evaluator<ArgType>::Flags&LinearAccessBit)
};
-
+
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
CoeffReturnType coeff(Index row, Index col) const
- {
- return m_argImpl.coeff(m_startRow.value() + row, m_startCol.value() + col);
+ {
+ return m_argImpl.coeff(m_startRow.value() + row, m_startCol.value() + col);
}
-
+
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
CoeffReturnType coeff(Index index) const
- {
- if (ForwardLinearAccess)
- return m_argImpl.coeff(m_linear_offset.value() + index);
- else
- return coeff(RowsAtCompileTime == 1 ? 0 : index, RowsAtCompileTime == 1 ? index : 0);
+ {
+ return linear_coeff_impl(index, bool_constant<ForwardLinearAccess>());
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
Scalar& coeffRef(Index row, Index col)
- {
- return m_argImpl.coeffRef(m_startRow.value() + row, m_startCol.value() + col);
+ {
+ return m_argImpl.coeffRef(m_startRow.value() + row, m_startCol.value() + col);
}
-
+
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
Scalar& coeffRef(Index index)
- {
- if (ForwardLinearAccess)
- return m_argImpl.coeffRef(m_linear_offset.value() + index);
- else
- return coeffRef(RowsAtCompileTime == 1 ? 0 : index, RowsAtCompileTime == 1 ? index : 0);
- }
-
- template<int LoadMode, typename PacketType>
- EIGEN_STRONG_INLINE
- PacketType packet(Index row, Index col) const
- {
- return m_argImpl.template packet<LoadMode,PacketType>(m_startRow.value() + row, m_startCol.value() + col);
+ {
+ return linear_coeffRef_impl(index, bool_constant<ForwardLinearAccess>());
}
template<int LoadMode, typename PacketType>
EIGEN_STRONG_INLINE
- PacketType packet(Index index) const
- {
+ PacketType packet(Index row, Index col) const
+ {
+ return m_argImpl.template packet<LoadMode,PacketType>(m_startRow.value() + row, m_startCol.value() + col);
+ }
+
+ template<int LoadMode, typename PacketType>
+ EIGEN_STRONG_INLINE
+ PacketType packet(Index index) const
+ {
if (ForwardLinearAccess)
return m_argImpl.template packet<LoadMode,PacketType>(m_linear_offset.value() + index);
else
return packet<LoadMode,PacketType>(RowsAtCompileTime == 1 ? 0 : index,
RowsAtCompileTime == 1 ? index : 0);
}
-
+
template<int StoreMode, typename PacketType>
EIGEN_STRONG_INLINE
- void writePacket(Index row, Index col, const PacketType& x)
+ void writePacket(Index row, Index col, const PacketType& x)
{
- return m_argImpl.template writePacket<StoreMode,PacketType>(m_startRow.value() + row, m_startCol.value() + col, x);
+ return m_argImpl.template writePacket<StoreMode,PacketType>(m_startRow.value() + row, m_startCol.value() + col, x);
}
-
+
template<int StoreMode, typename PacketType>
EIGEN_STRONG_INLINE
- void writePacket(Index index, const PacketType& x)
+ void writePacket(Index index, const PacketType& x)
{
if (ForwardLinearAccess)
return m_argImpl.template writePacket<StoreMode,PacketType>(m_linear_offset.value() + index, x);
@@ -1098,18 +1177,40 @@
RowsAtCompileTime == 1 ? index : 0,
x);
}
-
+
protected:
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ CoeffReturnType linear_coeff_impl(Index index, internal::true_type /* ForwardLinearAccess */) const
+ {
+ return m_argImpl.coeff(m_linear_offset.value() + index);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ CoeffReturnType linear_coeff_impl(Index index, internal::false_type /* not ForwardLinearAccess */) const
+ {
+ return coeff(RowsAtCompileTime == 1 ? 0 : index, RowsAtCompileTime == 1 ? index : 0);
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ Scalar& linear_coeffRef_impl(Index index, internal::true_type /* ForwardLinearAccess */)
+ {
+ return m_argImpl.coeffRef(m_linear_offset.value() + index);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ Scalar& linear_coeffRef_impl(Index index, internal::false_type /* not ForwardLinearAccess */)
+ {
+ return coeffRef(RowsAtCompileTime == 1 ? 0 : index, RowsAtCompileTime == 1 ? index : 0);
+ }
+
evaluator<ArgType> m_argImpl;
const variable_if_dynamic<Index, (ArgType::RowsAtCompileTime == 1 && BlockRows==1) ? 0 : Dynamic> m_startRow;
const variable_if_dynamic<Index, (ArgType::ColsAtCompileTime == 1 && BlockCols==1) ? 0 : Dynamic> m_startCol;
- const variable_if_dynamic<Index, InnerPanel ? Dynamic : 0> m_linear_offset;
+ const variable_if_dynamic<Index, ForwardLinearAccess ? Dynamic : 0> m_linear_offset;
};
-// TODO: This evaluator does not actually use the child evaluator;
+// TODO: This evaluator does not actually use the child evaluator;
// all action is via the data() as returned by the Block expression.
-template<typename ArgType, int BlockRows, int BlockCols, bool InnerPanel>
+template<typename ArgType, int BlockRows, int BlockCols, bool InnerPanel>
struct block_evaluator<ArgType, BlockRows, BlockCols, InnerPanel, /* HasDirectAccess */ true>
: mapbase_evaluator<Block<ArgType, BlockRows, BlockCols, InnerPanel>,
typename Block<ArgType, BlockRows, BlockCols, InnerPanel>::PlainObject>
@@ -1117,8 +1218,9 @@
typedef Block<ArgType, BlockRows, BlockCols, InnerPanel> XprType;
typedef typename XprType::Scalar Scalar;
- EIGEN_DEVICE_FUNC explicit block_evaluator(const XprType& block)
- : mapbase_evaluator<XprType, typename XprType::PlainObject>(block)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ explicit block_evaluator(const XprType& block)
+ : mapbase_evaluator<XprType, typename XprType::PlainObject>(block)
{
// TODO: for the 3.3 release, this should be turned to an internal assertion, but let's keep it as is for the beta lifetime
eigen_assert(((internal::UIntPtr(block.data()) % EIGEN_PLAIN_ENUM_MAX(1,evaluator<XprType>::Alignment)) == 0) && "data is not aligned");
@@ -1141,18 +1243,19 @@
evaluator<ElseMatrixType>::CoeffReadCost),
Flags = (unsigned int)evaluator<ThenMatrixType>::Flags & evaluator<ElseMatrixType>::Flags & HereditaryBits,
-
+
Alignment = EIGEN_PLAIN_ENUM_MIN(evaluator<ThenMatrixType>::Alignment, evaluator<ElseMatrixType>::Alignment)
};
- EIGEN_DEVICE_FUNC explicit evaluator(const XprType& select)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ explicit evaluator(const XprType& select)
: m_conditionImpl(select.conditionMatrix()),
m_thenImpl(select.thenMatrix()),
m_elseImpl(select.elseMatrix())
{
EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
}
-
+
typedef typename XprType::CoeffReturnType CoeffReturnType;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
@@ -1172,7 +1275,7 @@
else
return m_elseImpl.coeff(index);
}
-
+
protected:
evaluator<ConditionMatrixType> m_conditionImpl;
evaluator<ThenMatrixType> m_thenImpl;
@@ -1182,7 +1285,7 @@
// -------------------- Replicate --------------------
-template<typename ArgType, int RowFactor, int ColFactor>
+template<typename ArgType, int RowFactor, int ColFactor>
struct unary_evaluator<Replicate<ArgType, RowFactor, ColFactor> >
: evaluator_base<Replicate<ArgType, RowFactor, ColFactor> >
{
@@ -1193,22 +1296,23 @@
};
typedef typename internal::nested_eval<ArgType,Factor>::type ArgTypeNested;
typedef typename internal::remove_all<ArgTypeNested>::type ArgTypeNestedCleaned;
-
+
enum {
CoeffReadCost = evaluator<ArgTypeNestedCleaned>::CoeffReadCost,
LinearAccessMask = XprType::IsVectorAtCompileTime ? LinearAccessBit : 0,
Flags = (evaluator<ArgTypeNestedCleaned>::Flags & (HereditaryBits|LinearAccessMask) & ~RowMajorBit) | (traits<XprType>::Flags & RowMajorBit),
-
+
Alignment = evaluator<ArgTypeNestedCleaned>::Alignment
};
- EIGEN_DEVICE_FUNC explicit unary_evaluator(const XprType& replicate)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ explicit unary_evaluator(const XprType& replicate)
: m_arg(replicate.nestedExpression()),
m_argImpl(m_arg),
m_rows(replicate.nestedExpression().rows()),
m_cols(replicate.nestedExpression().cols())
{}
-
+
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
CoeffReturnType coeff(Index row, Index col) const
{
@@ -1219,10 +1323,10 @@
const Index actual_col = internal::traits<XprType>::ColsAtCompileTime==1 ? 0
: ColFactor==1 ? col
: col % m_cols.value();
-
+
return m_argImpl.coeff(actual_row, actual_col);
}
-
+
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
CoeffReturnType coeff(Index index) const
{
@@ -1230,7 +1334,7 @@
const Index actual_index = internal::traits<XprType>::RowsAtCompileTime==1
? (ColFactor==1 ? index : index%m_cols.value())
: (RowFactor==1 ? index : index%m_rows.value());
-
+
return m_argImpl.coeff(actual_index);
}
@@ -1247,7 +1351,7 @@
return m_argImpl.template packet<LoadMode,PacketType>(actual_row, actual_col);
}
-
+
template<int LoadMode, typename PacketType>
EIGEN_STRONG_INLINE
PacketType packet(Index index) const
@@ -1258,7 +1362,7 @@
return m_argImpl.template packet<LoadMode,PacketType>(actual_index);
}
-
+
protected:
const ArgTypeNested m_arg;
evaluator<ArgTypeNestedCleaned> m_argImpl;
@@ -1266,64 +1370,6 @@
const variable_if_dynamic<Index, ArgType::ColsAtCompileTime> m_cols;
};
-
-// -------------------- PartialReduxExpr --------------------
-
-template< typename ArgType, typename MemberOp, int Direction>
-struct evaluator<PartialReduxExpr<ArgType, MemberOp, Direction> >
- : evaluator_base<PartialReduxExpr<ArgType, MemberOp, Direction> >
-{
- typedef PartialReduxExpr<ArgType, MemberOp, Direction> XprType;
- typedef typename internal::nested_eval<ArgType,1>::type ArgTypeNested;
- typedef typename internal::remove_all<ArgTypeNested>::type ArgTypeNestedCleaned;
- typedef typename ArgType::Scalar InputScalar;
- typedef typename XprType::Scalar Scalar;
- enum {
- TraversalSize = Direction==int(Vertical) ? int(ArgType::RowsAtCompileTime) : int(ArgType::ColsAtCompileTime)
- };
- typedef typename MemberOp::template Cost<InputScalar,int(TraversalSize)> CostOpType;
- enum {
- CoeffReadCost = TraversalSize==Dynamic ? HugeCost
- : TraversalSize * evaluator<ArgType>::CoeffReadCost + int(CostOpType::value),
-
- Flags = (traits<XprType>::Flags&RowMajorBit) | (evaluator<ArgType>::Flags&(HereditaryBits&(~RowMajorBit))) | LinearAccessBit,
-
- Alignment = 0 // FIXME this will need to be improved once PartialReduxExpr is vectorized
- };
-
- EIGEN_DEVICE_FUNC explicit evaluator(const XprType xpr)
- : m_arg(xpr.nestedExpression()), m_functor(xpr.functor())
- {
- EIGEN_INTERNAL_CHECK_COST_VALUE(TraversalSize==Dynamic ? HugeCost : int(CostOpType::value));
- EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
- }
-
- typedef typename XprType::CoeffReturnType CoeffReturnType;
-
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
- const Scalar coeff(Index i, Index j) const
- {
- if (Direction==Vertical)
- return m_functor(m_arg.col(j));
- else
- return m_functor(m_arg.row(i));
- }
-
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
- const Scalar coeff(Index index) const
- {
- if (Direction==Vertical)
- return m_functor(m_arg.col(index));
- else
- return m_functor(m_arg.row(index));
- }
-
-protected:
- typename internal::add_const_on_value_type<ArgTypeNested>::type m_arg;
- const MemberOp m_functor;
-};
-
-
// -------------------- MatrixWrapper and ArrayWrapper --------------------
//
// evaluator_wrapper_base<T> is a common base class for the
@@ -1340,7 +1386,8 @@
Alignment = evaluator<ArgType>::Alignment
};
- EIGEN_DEVICE_FUNC explicit evaluator_wrapper_base(const ArgType& arg) : m_argImpl(arg) {}
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ explicit evaluator_wrapper_base(const ArgType& arg) : m_argImpl(arg) {}
typedef typename ArgType::Scalar Scalar;
typedef typename ArgType::CoeffReturnType CoeffReturnType;
@@ -1407,7 +1454,8 @@
{
typedef MatrixWrapper<TArgType> XprType;
- EIGEN_DEVICE_FUNC explicit unary_evaluator(const XprType& wrapper)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ explicit unary_evaluator(const XprType& wrapper)
: evaluator_wrapper_base<MatrixWrapper<TArgType> >(wrapper.nestedExpression())
{ }
};
@@ -1418,7 +1466,8 @@
{
typedef ArrayWrapper<TArgType> XprType;
- EIGEN_DEVICE_FUNC explicit unary_evaluator(const XprType& wrapper)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ explicit unary_evaluator(const XprType& wrapper)
: evaluator_wrapper_base<ArrayWrapper<TArgType> >(wrapper.nestedExpression())
{ }
};
@@ -1445,9 +1494,9 @@
ReversePacket = (Direction == BothDirections)
|| ((Direction == Vertical) && IsColMajor)
|| ((Direction == Horizontal) && IsRowMajor),
-
+
CoeffReadCost = evaluator<ArgType>::CoeffReadCost,
-
+
// let's enable LinearAccess only with vectorization because of the product overhead
// FIXME enable DirectAccess with negative strides?
Flags0 = evaluator<ArgType>::Flags,
@@ -1456,16 +1505,17 @@
? LinearAccessBit : 0,
Flags = int(Flags0) & (HereditaryBits | PacketAccessBit | LinearAccess),
-
+
Alignment = 0 // FIXME in some rare cases, Alignment could be preserved, like a Vector4f.
};
- EIGEN_DEVICE_FUNC explicit unary_evaluator(const XprType& reverse)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ explicit unary_evaluator(const XprType& reverse)
: m_argImpl(reverse.nestedExpression()),
m_rows(ReverseRow ? reverse.nestedExpression().rows() : 1),
m_cols(ReverseCol ? reverse.nestedExpression().cols() : 1)
{ }
-
+
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
CoeffReturnType coeff(Index row, Index col) const
{
@@ -1540,7 +1590,7 @@
m_argImpl.template writePacket<LoadMode>
(m_rows.value() * m_cols.value() - index - PacketSize, preverse(x));
}
-
+
protected:
evaluator<ArgType> m_argImpl;
@@ -1558,20 +1608,21 @@
: evaluator_base<Diagonal<ArgType, DiagIndex> >
{
typedef Diagonal<ArgType, DiagIndex> XprType;
-
+
enum {
CoeffReadCost = evaluator<ArgType>::CoeffReadCost,
-
+
Flags = (unsigned int)(evaluator<ArgType>::Flags & (HereditaryBits | DirectAccessBit) & ~RowMajorBit) | LinearAccessBit,
-
+
Alignment = 0
};
- EIGEN_DEVICE_FUNC explicit evaluator(const XprType& diagonal)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ explicit evaluator(const XprType& diagonal)
: m_argImpl(diagonal.nestedExpression()),
m_index(diagonal.index())
{ }
-
+
typedef typename XprType::Scalar Scalar;
typedef typename XprType::CoeffReturnType CoeffReturnType;
@@ -1604,8 +1655,10 @@
const internal::variable_if_dynamicindex<Index, XprType::DiagIndex> m_index;
private:
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index rowOffset() const { return m_index.value() > 0 ? 0 : -m_index.value(); }
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index colOffset() const { return m_index.value() > 0 ? m_index.value() : 0; }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_CONSTEXPR
+ Index rowOffset() const { return m_index.value() > 0 ? 0 : -m_index.value(); }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_CONSTEXPR
+ Index colOffset() const { return m_index.value() > 0 ? m_index.value() : 0; }
};
@@ -1629,25 +1682,25 @@
: public dense_xpr_base<EvalToTemp<ArgType> >::type
{
public:
-
+
typedef typename dense_xpr_base<EvalToTemp>::type Base;
EIGEN_GENERIC_PUBLIC_INTERFACE(EvalToTemp)
-
+
explicit EvalToTemp(const ArgType& arg)
: m_arg(arg)
{ }
-
+
const ArgType& arg() const
{
return m_arg;
}
- Index rows() const
+ EIGEN_CONSTEXPR Index rows() const EIGEN_NOEXCEPT
{
return m_arg.rows();
}
- Index cols() const
+ EIGEN_CONSTEXPR Index cols() const EIGEN_NOEXCEPT
{
return m_arg.cols();
}
@@ -1655,7 +1708,7 @@
private:
const ArgType& m_arg;
};
-
+
template<typename ArgType>
struct evaluator<EvalToTemp<ArgType> >
: public evaluator<typename ArgType::PlainObject>
@@ -1663,7 +1716,7 @@
typedef EvalToTemp<ArgType> XprType;
typedef typename ArgType::PlainObject PlainObject;
typedef evaluator<PlainObject> Base;
-
+
EIGEN_DEVICE_FUNC explicit evaluator(const XprType& xpr)
: m_result(xpr.arg())
{
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/CoreIterators.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/CoreIterators.h
index 4eb42b9..b967196 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/CoreIterators.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/CoreIterators.h
@@ -48,6 +48,11 @@
* Explicit zeros are not skipped over. To skip explicit zeros, see class SparseView
*/
EIGEN_STRONG_INLINE InnerIterator& operator++() { m_iter.operator++(); return *this; }
+ EIGEN_STRONG_INLINE InnerIterator& operator+=(Index i) { m_iter.operator+=(i); return *this; }
+ EIGEN_STRONG_INLINE InnerIterator operator+(Index i)
+ { InnerIterator result(*this); result+=i; return result; }
+
+
/// \returns the column or row index of the current coefficient.
EIGEN_STRONG_INLINE Index index() const { return m_iter.index(); }
/// \returns the row index of the current coefficient.
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/CwiseBinaryOp.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/CwiseBinaryOp.h
index a36765e..2202b1c 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/CwiseBinaryOp.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/CwiseBinaryOp.h
@@ -74,7 +74,7 @@
* \sa MatrixBase::binaryExpr(const MatrixBase<OtherDerived> &,const CustomBinaryOp &) const, class CwiseUnaryOp, class CwiseNullaryOp
*/
template<typename BinaryOp, typename LhsType, typename RhsType>
-class CwiseBinaryOp :
+class CwiseBinaryOp :
public CwiseBinaryOpImpl<
BinaryOp, LhsType, RhsType,
typename internal::cwise_promote_storage_type<typename internal::traits<LhsType>::StorageKind,
@@ -83,7 +83,7 @@
internal::no_assignment_operator
{
public:
-
+
typedef typename internal::remove_all<BinaryOp>::type Functor;
typedef typename internal::remove_all<LhsType>::type Lhs;
typedef typename internal::remove_all<RhsType>::type Rhs;
@@ -100,8 +100,14 @@
typedef typename internal::remove_reference<LhsNested>::type _LhsNested;
typedef typename internal::remove_reference<RhsNested>::type _RhsNested;
- EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE CwiseBinaryOp(const Lhs& aLhs, const Rhs& aRhs, const BinaryOp& func = BinaryOp())
+#if EIGEN_COMP_MSVC && EIGEN_HAS_CXX11
+ //Required for Visual Studio or the Copy constructor will probably not get inlined!
+ EIGEN_STRONG_INLINE
+ CwiseBinaryOp(const CwiseBinaryOp<BinaryOp,LhsType,RhsType>&) = default;
+#endif
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ CwiseBinaryOp(const Lhs& aLhs, const Rhs& aRhs, const BinaryOp& func = BinaryOp())
: m_lhs(aLhs), m_rhs(aRhs), m_functor(func)
{
EIGEN_CHECK_BINARY_COMPATIBILIY(BinaryOp,typename Lhs::Scalar,typename Rhs::Scalar);
@@ -110,31 +116,25 @@
eigen_assert(aLhs.rows() == aRhs.rows() && aLhs.cols() == aRhs.cols());
}
- EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE Index rows() const {
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_CONSTEXPR
+ Index rows() const EIGEN_NOEXCEPT {
// return the fixed size type if available to enable compile time optimizations
- if (internal::traits<typename internal::remove_all<LhsNested>::type>::RowsAtCompileTime==Dynamic)
- return m_rhs.rows();
- else
- return m_lhs.rows();
+ return internal::traits<typename internal::remove_all<LhsNested>::type>::RowsAtCompileTime==Dynamic ? m_rhs.rows() : m_lhs.rows();
}
- EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE Index cols() const {
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_CONSTEXPR
+ Index cols() const EIGEN_NOEXCEPT {
// return the fixed size type if available to enable compile time optimizations
- if (internal::traits<typename internal::remove_all<LhsNested>::type>::ColsAtCompileTime==Dynamic)
- return m_rhs.cols();
- else
- return m_lhs.cols();
+ return internal::traits<typename internal::remove_all<LhsNested>::type>::ColsAtCompileTime==Dynamic ? m_rhs.cols() : m_lhs.cols();
}
/** \returns the left hand side nested expression */
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const _LhsNested& lhs() const { return m_lhs; }
/** \returns the right hand side nested expression */
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const _RhsNested& rhs() const { return m_rhs; }
/** \returns the functor representing the binary operation */
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const BinaryOp& functor() const { return m_functor; }
protected:
@@ -158,7 +158,7 @@
*/
template<typename Derived>
template<typename OtherDerived>
-EIGEN_STRONG_INLINE Derived &
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived &
MatrixBase<Derived>::operator-=(const MatrixBase<OtherDerived> &other)
{
call_assignment(derived(), other.derived(), internal::sub_assign_op<Scalar,typename OtherDerived::Scalar>());
@@ -171,7 +171,7 @@
*/
template<typename Derived>
template<typename OtherDerived>
-EIGEN_STRONG_INLINE Derived &
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived &
MatrixBase<Derived>::operator+=(const MatrixBase<OtherDerived>& other)
{
call_assignment(derived(), other.derived(), internal::add_assign_op<Scalar,typename OtherDerived::Scalar>());
@@ -181,4 +181,3 @@
} // end namespace Eigen
#endif // EIGEN_CWISE_BINARY_OP_H
-
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/CwiseNullaryOp.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/CwiseNullaryOp.h
index ddd607e..289ec51 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/CwiseNullaryOp.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/CwiseNullaryOp.h
@@ -74,10 +74,10 @@
&& (ColsAtCompileTime == Dynamic || ColsAtCompileTime == cols));
}
- EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE Index rows() const { return m_rows.value(); }
- EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE Index cols() const { return m_cols.value(); }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_CONSTEXPR
+ Index rows() const { return m_rows.value(); }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_CONSTEXPR
+ Index cols() const { return m_cols.value(); }
/** \returns the functor representing the nullary operation */
EIGEN_DEVICE_FUNC
@@ -105,7 +105,12 @@
*/
template<typename Derived>
template<typename CustomNullaryOp>
-EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const CwiseNullaryOp<CustomNullaryOp, typename DenseBase<Derived>::PlainObject>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+#ifndef EIGEN_PARSED_BY_DOXYGEN
+const CwiseNullaryOp<CustomNullaryOp,typename DenseBase<Derived>::PlainObject>
+#else
+const CwiseNullaryOp<CustomNullaryOp,PlainObject>
+#endif
DenseBase<Derived>::NullaryExpr(Index rows, Index cols, const CustomNullaryOp& func)
{
return CwiseNullaryOp<CustomNullaryOp, PlainObject>(rows, cols, func);
@@ -126,12 +131,17 @@
*
* Here is an example with C++11 random generators: \include random_cpp11.cpp
* Output: \verbinclude random_cpp11.out
- *
+ *
* \sa class CwiseNullaryOp
*/
template<typename Derived>
template<typename CustomNullaryOp>
-EIGEN_STRONG_INLINE const CwiseNullaryOp<CustomNullaryOp, typename DenseBase<Derived>::PlainObject>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+#ifndef EIGEN_PARSED_BY_DOXYGEN
+const CwiseNullaryOp<CustomNullaryOp, typename DenseBase<Derived>::PlainObject>
+#else
+const CwiseNullaryOp<CustomNullaryOp, PlainObject>
+#endif
DenseBase<Derived>::NullaryExpr(Index size, const CustomNullaryOp& func)
{
EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
@@ -150,7 +160,12 @@
*/
template<typename Derived>
template<typename CustomNullaryOp>
-EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const CwiseNullaryOp<CustomNullaryOp, typename DenseBase<Derived>::PlainObject>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+#ifndef EIGEN_PARSED_BY_DOXYGEN
+const CwiseNullaryOp<CustomNullaryOp, typename DenseBase<Derived>::PlainObject>
+#else
+const CwiseNullaryOp<CustomNullaryOp, PlainObject>
+#endif
DenseBase<Derived>::NullaryExpr(const CustomNullaryOp& func)
{
return CwiseNullaryOp<CustomNullaryOp, PlainObject>(RowsAtCompileTime, ColsAtCompileTime, func);
@@ -170,7 +185,7 @@
* \sa class CwiseNullaryOp
*/
template<typename Derived>
-EIGEN_STRONG_INLINE const typename DenseBase<Derived>::ConstantReturnType
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const typename DenseBase<Derived>::ConstantReturnType
DenseBase<Derived>::Constant(Index rows, Index cols, const Scalar& value)
{
return DenseBase<Derived>::NullaryExpr(rows, cols, internal::scalar_constant_op<Scalar>(value));
@@ -217,27 +232,32 @@
/** \deprecated because of accuracy loss. In Eigen 3.3, it is an alias for LinSpaced(Index,const Scalar&,const Scalar&)
*
- * \sa LinSpaced(Index,Scalar,Scalar), setLinSpaced(Index,const Scalar&,const Scalar&)
+ * \only_for_vectors
+ *
+ * Example: \include DenseBase_LinSpaced_seq_deprecated.cpp
+ * Output: \verbinclude DenseBase_LinSpaced_seq_deprecated.out
+ *
+ * \sa LinSpaced(Index,const Scalar&, const Scalar&), setLinSpaced(Index,const Scalar&,const Scalar&)
*/
template<typename Derived>
-EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const typename DenseBase<Derived>::RandomAccessLinSpacedReturnType
+EIGEN_DEPRECATED EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const typename DenseBase<Derived>::RandomAccessLinSpacedReturnType
DenseBase<Derived>::LinSpaced(Sequential_t, Index size, const Scalar& low, const Scalar& high)
{
EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
- return DenseBase<Derived>::NullaryExpr(size, internal::linspaced_op<Scalar,PacketScalar>(low,high,size));
+ return DenseBase<Derived>::NullaryExpr(size, internal::linspaced_op<Scalar>(low,high,size));
}
/** \deprecated because of accuracy loss. In Eigen 3.3, it is an alias for LinSpaced(const Scalar&,const Scalar&)
*
- * \sa LinSpaced(Scalar,Scalar)
+ * \sa LinSpaced(const Scalar&, const Scalar&)
*/
template<typename Derived>
-EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const typename DenseBase<Derived>::RandomAccessLinSpacedReturnType
+EIGEN_DEPRECATED EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const typename DenseBase<Derived>::RandomAccessLinSpacedReturnType
DenseBase<Derived>::LinSpaced(Sequential_t, const Scalar& low, const Scalar& high)
{
EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
EIGEN_STATIC_ASSERT_FIXED_SIZE(Derived)
- return DenseBase<Derived>::NullaryExpr(Derived::SizeAtCompileTime, internal::linspaced_op<Scalar,PacketScalar>(low,high,Derived::SizeAtCompileTime));
+ return DenseBase<Derived>::NullaryExpr(Derived::SizeAtCompileTime, internal::linspaced_op<Scalar>(low,high,Derived::SizeAtCompileTime));
}
/**
@@ -268,7 +288,7 @@
DenseBase<Derived>::LinSpaced(Index size, const Scalar& low, const Scalar& high)
{
EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
- return DenseBase<Derived>::NullaryExpr(size, internal::linspaced_op<Scalar,PacketScalar>(low,high,size));
+ return DenseBase<Derived>::NullaryExpr(size, internal::linspaced_op<Scalar>(low,high,size));
}
/**
@@ -281,7 +301,7 @@
{
EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
EIGEN_STATIC_ASSERT_FIXED_SIZE(Derived)
- return DenseBase<Derived>::NullaryExpr(Derived::SizeAtCompileTime, internal::linspaced_op<Scalar,PacketScalar>(low,high,Derived::SizeAtCompileTime));
+ return DenseBase<Derived>::NullaryExpr(Derived::SizeAtCompileTime, internal::linspaced_op<Scalar>(low,high,Derived::SizeAtCompileTime));
}
/** \returns true if all coefficients in this matrix are approximately equal to \a val, to within precision \a prec */
@@ -363,6 +383,33 @@
return setConstant(val);
}
+/** Resizes to the given size, changing only the number of columns, and sets all
+ * coefficients in this expression to the given value \a val. For the parameter
+ * of type NoChange_t, just pass the special value \c NoChange.
+ *
+ * \sa MatrixBase::setConstant(const Scalar&), setConstant(Index,const Scalar&), class CwiseNullaryOp, MatrixBase::Constant(const Scalar&)
+ */
+template<typename Derived>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived&
+PlainObjectBase<Derived>::setConstant(NoChange_t, Index cols, const Scalar& val)
+{
+ return setConstant(rows(), cols, val);
+}
+
+/** Resizes to the given size, changing only the number of rows, and sets all
+ * coefficients in this expression to the given value \a val. For the parameter
+ * of type NoChange_t, just pass the special value \c NoChange.
+ *
+ * \sa MatrixBase::setConstant(const Scalar&), setConstant(Index,const Scalar&), class CwiseNullaryOp, MatrixBase::Constant(const Scalar&)
+ */
+template<typename Derived>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived&
+PlainObjectBase<Derived>::setConstant(Index rows, NoChange_t, const Scalar& val)
+{
+ return setConstant(rows, cols(), val);
+}
+
+
/**
* \brief Sets a linearly spaced vector.
*
@@ -383,7 +430,7 @@
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& DenseBase<Derived>::setLinSpaced(Index newSize, const Scalar& low, const Scalar& high)
{
EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
- return derived() = Derived::NullaryExpr(newSize, internal::linspaced_op<Scalar,PacketScalar>(low,high,newSize));
+ return derived() = Derived::NullaryExpr(newSize, internal::linspaced_op<Scalar>(low,high,newSize));
}
/**
@@ -536,6 +583,32 @@
return setConstant(Scalar(0));
}
+/** Resizes to the given size, changing only the number of columns, and sets all
+ * coefficients in this expression to zero. For the parameter of type NoChange_t,
+ * just pass the special value \c NoChange.
+ *
+ * \sa DenseBase::setZero(), setZero(Index), setZero(Index, Index), setZero(Index, NoChange_t), class CwiseNullaryOp, DenseBase::Zero()
+ */
+template<typename Derived>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived&
+PlainObjectBase<Derived>::setZero(NoChange_t, Index cols)
+{
+ return setZero(rows(), cols);
+}
+
+/** Resizes to the given size, changing only the number of rows, and sets all
+ * coefficients in this expression to zero. For the parameter of type NoChange_t,
+ * just pass the special value \c NoChange.
+ *
+ * \sa DenseBase::setZero(), setZero(Index), setZero(Index, Index), setZero(NoChange_t, Index), class CwiseNullaryOp, DenseBase::Zero()
+ */
+template<typename Derived>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived&
+PlainObjectBase<Derived>::setZero(Index rows, NoChange_t)
+{
+ return setZero(rows, cols());
+}
+
// ones:
/** \returns an expression of a matrix where all coefficients equal one.
@@ -662,6 +735,32 @@
return setConstant(Scalar(1));
}
+/** Resizes to the given size, changing only the number of rows, and sets all
+ * coefficients in this expression to one. For the parameter of type NoChange_t,
+ * just pass the special value \c NoChange.
+ *
+ * \sa MatrixBase::setOnes(), setOnes(Index), setOnes(Index, Index), setOnes(NoChange_t, Index), class CwiseNullaryOp, MatrixBase::Ones()
+ */
+template<typename Derived>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived&
+PlainObjectBase<Derived>::setOnes(Index rows, NoChange_t)
+{
+ return setOnes(rows, cols());
+}
+
+/** Resizes to the given size, changing only the number of columns, and sets all
+ * coefficients in this expression to one. For the parameter of type NoChange_t,
+ * just pass the special value \c NoChange.
+ *
+ * \sa MatrixBase::setOnes(), setOnes(Index), setOnes(Index, Index), setOnes(Index, NoChange_t) class CwiseNullaryOp, MatrixBase::Ones()
+ */
+template<typename Derived>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived&
+PlainObjectBase<Derived>::setOnes(NoChange_t, Index cols)
+{
+ return setOnes(rows(), cols);
+}
+
// Identity:
/** \returns an expression of the identity matrix (not necessarily square).
@@ -861,6 +960,42 @@
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const typename MatrixBase<Derived>::BasisReturnType MatrixBase<Derived>::UnitW()
{ return Derived::Unit(3); }
+/** \brief Set the coefficients of \c *this to the i-th unit (basis) vector
+ *
+ * \param i index of the unique coefficient to be set to 1
+ *
+ * \only_for_vectors
+ *
+ * \sa MatrixBase::setIdentity(), class CwiseNullaryOp, MatrixBase::Unit(Index,Index)
+ */
+template<typename Derived>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& MatrixBase<Derived>::setUnit(Index i)
+{
+ EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived);
+ eigen_assert(i<size());
+ derived().setZero();
+ derived().coeffRef(i) = Scalar(1);
+ return derived();
+}
+
+/** \brief Resizes to the given \a newSize, and writes the i-th unit (basis) vector into *this.
+ *
+ * \param newSize the new size of the vector
+ * \param i index of the unique coefficient to be set to 1
+ *
+ * \only_for_vectors
+ *
+ * \sa MatrixBase::setIdentity(), class CwiseNullaryOp, MatrixBase::Unit(Index,Index)
+ */
+template<typename Derived>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& MatrixBase<Derived>::setUnit(Index newSize, Index i)
+{
+ EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived);
+ eigen_assert(i<newSize);
+ derived().resize(newSize);
+ return setUnit(i);
+}
+
} // end namespace Eigen
#endif // EIGEN_CWISE_NULLARY_OP_H
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/CwiseUnaryOp.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/CwiseUnaryOp.h
index 1d2dd19..e68c4f7 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/CwiseUnaryOp.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/CwiseUnaryOp.h
@@ -11,7 +11,7 @@
#ifndef EIGEN_CWISE_UNARY_OP_H
#define EIGEN_CWISE_UNARY_OP_H
-namespace Eigen {
+namespace Eigen {
namespace internal {
template<typename UnaryOp, typename XprType>
@@ -24,7 +24,7 @@
typedef typename XprType::Nested XprTypeNested;
typedef typename remove_reference<XprTypeNested>::type _XprTypeNested;
enum {
- Flags = _XprTypeNested::Flags & RowMajorBit
+ Flags = _XprTypeNested::Flags & RowMajorBit
};
};
}
@@ -65,10 +65,10 @@
explicit CwiseUnaryOp(const XprType& xpr, const UnaryOp& func = UnaryOp())
: m_xpr(xpr), m_functor(func) {}
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
- Index rows() const { return m_xpr.rows(); }
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
- Index cols() const { return m_xpr.cols(); }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_CONSTEXPR
+ Index rows() const EIGEN_NOEXCEPT { return m_xpr.rows(); }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_CONSTEXPR
+ Index cols() const EIGEN_NOEXCEPT { return m_xpr.cols(); }
/** \returns the functor representing the unary operation */
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/CwiseUnaryView.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/CwiseUnaryView.h
index 5a30fa8..a06d762 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/CwiseUnaryView.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/CwiseUnaryView.h
@@ -64,24 +64,26 @@
typedef typename internal::ref_selector<MatrixType>::non_const_type MatrixTypeNested;
typedef typename internal::remove_all<MatrixType>::type NestedExpression;
- explicit inline CwiseUnaryView(MatrixType& mat, const ViewOp& func = ViewOp())
+ explicit EIGEN_DEVICE_FUNC inline CwiseUnaryView(MatrixType& mat, const ViewOp& func = ViewOp())
: m_matrix(mat), m_functor(func) {}
EIGEN_INHERIT_ASSIGNMENT_OPERATORS(CwiseUnaryView)
- EIGEN_STRONG_INLINE Index rows() const { return m_matrix.rows(); }
- EIGEN_STRONG_INLINE Index cols() const { return m_matrix.cols(); }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_CONSTEXPR
+ Index rows() const EIGEN_NOEXCEPT { return m_matrix.rows(); }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_CONSTEXPR
+ Index cols() const EIGEN_NOEXCEPT { return m_matrix.cols(); }
/** \returns the functor representing unary operation */
- const ViewOp& functor() const { return m_functor; }
+ EIGEN_DEVICE_FUNC const ViewOp& functor() const { return m_functor; }
/** \returns the nested expression */
- const typename internal::remove_all<MatrixTypeNested>::type&
+ EIGEN_DEVICE_FUNC const typename internal::remove_all<MatrixTypeNested>::type&
nestedExpression() const { return m_matrix; }
/** \returns the nested expression */
- typename internal::remove_reference<MatrixTypeNested>::type&
- nestedExpression() { return m_matrix.const_cast_derived(); }
+ EIGEN_DEVICE_FUNC typename internal::remove_reference<MatrixTypeNested>::type&
+ nestedExpression() { return m_matrix; }
protected:
MatrixTypeNested m_matrix;
@@ -108,16 +110,16 @@
EIGEN_DENSE_PUBLIC_INTERFACE(Derived)
EIGEN_INHERIT_ASSIGNMENT_OPERATORS(CwiseUnaryViewImpl)
-
+
EIGEN_DEVICE_FUNC inline Scalar* data() { return &(this->coeffRef(0)); }
EIGEN_DEVICE_FUNC inline const Scalar* data() const { return &(this->coeff(0)); }
- EIGEN_DEVICE_FUNC inline Index innerStride() const
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR inline Index innerStride() const
{
return derived().nestedExpression().innerStride() * sizeof(typename internal::traits<MatrixType>::Scalar) / sizeof(Scalar);
}
- EIGEN_DEVICE_FUNC inline Index outerStride() const
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR inline Index outerStride() const
{
return derived().nestedExpression().outerStride() * sizeof(typename internal::traits<MatrixType>::Scalar) / sizeof(Scalar);
}
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/DenseBase.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/DenseBase.h
index c27a8a8..9b16db6 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/DenseBase.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/DenseBase.h
@@ -14,15 +14,15 @@
namespace Eigen {
namespace internal {
-
+
// The index type defined by EIGEN_DEFAULT_DENSE_INDEX_TYPE must be a signed type.
// This dummy function simply aims at checking that at compile time.
static inline void check_DenseIndex_is_signed() {
- EIGEN_STATIC_ASSERT(NumTraits<DenseIndex>::IsSigned,THE_INDEX_TYPE_MUST_BE_A_SIGNED_TYPE);
+ EIGEN_STATIC_ASSERT(NumTraits<DenseIndex>::IsSigned,THE_INDEX_TYPE_MUST_BE_A_SIGNED_TYPE)
}
} // end namespace internal
-
+
/** \class DenseBase
* \ingroup Core_Module
*
@@ -40,7 +40,7 @@
*/
template<typename Derived> class DenseBase
#ifndef EIGEN_PARSED_BY_DOXYGEN
- : public DenseCoeffsBase<Derived>
+ : public DenseCoeffsBase<Derived, internal::accessors_level<Derived>::value>
#else
: public DenseCoeffsBase<Derived,DirectWriteAccessors>
#endif // not EIGEN_PARSED_BY_DOXYGEN
@@ -64,14 +64,14 @@
/** The numeric type of the expression' coefficients, e.g. float, double, int or std::complex<float>, etc. */
typedef typename internal::traits<Derived>::Scalar Scalar;
-
+
/** The numeric type of the expression' coefficients, e.g. float, double, int or std::complex<float>, etc.
*
* It is an alias for the Scalar type */
typedef Scalar value_type;
-
+
typedef typename NumTraits<Scalar>::Real RealScalar;
- typedef DenseCoeffsBase<Derived> Base;
+ typedef DenseCoeffsBase<Derived, internal::accessors_level<Derived>::value> Base;
using Base::derived;
using Base::const_cast_derived;
@@ -150,13 +150,18 @@
* \sa SizeAtCompileTime, MaxRowsAtCompileTime, MaxColsAtCompileTime
*/
- IsVectorAtCompileTime = internal::traits<Derived>::MaxRowsAtCompileTime == 1
- || internal::traits<Derived>::MaxColsAtCompileTime == 1,
+ IsVectorAtCompileTime = internal::traits<Derived>::RowsAtCompileTime == 1
+ || internal::traits<Derived>::ColsAtCompileTime == 1,
/**< This is set to true if either the number of rows or the number of
* columns is known at compile-time to be equal to 1. Indeed, in that case,
* we are dealing with a column-vector (if there is only one column) or with
* a row-vector (if there is only one row). */
+ NumDimensions = int(MaxSizeAtCompileTime) == 1 ? 0 : bool(IsVectorAtCompileTime) ? 1 : 2,
+ /**< This value is equal to Tensor::NumDimensions, i.e. 0 for scalars, 1 for vectors,
+ * and 2 for matrices.
+ */
+
Flags = internal::traits<Derived>::Flags,
/**< This stores expression \ref flags flags which may or may not be inherited by new expressions
* constructed from this one. See the \ref flags "list of flags".
@@ -170,11 +175,11 @@
InnerStrideAtCompileTime = internal::inner_stride_at_compile_time<Derived>::ret,
OuterStrideAtCompileTime = internal::outer_stride_at_compile_time<Derived>::ret
};
-
+
typedef typename internal::find_best_packet<Scalar,SizeAtCompileTime>::type PacketScalar;
enum { IsPlainObjectBase = 0 };
-
+
/** The plain matrix type corresponding to this expression.
* \sa PlainObject */
typedef Matrix<typename internal::traits<Derived>::Scalar,
@@ -184,7 +189,7 @@
internal::traits<Derived>::MaxRowsAtCompileTime,
internal::traits<Derived>::MaxColsAtCompileTime
> PlainMatrix;
-
+
/** The plain array type corresponding to this expression.
* \sa PlainObject */
typedef Array<typename internal::traits<Derived>::Scalar,
@@ -206,7 +211,7 @@
/** \returns the number of nonzero coefficients which is in practice the number
* of stored coefficients. */
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
inline Index nonZeros() const { return size(); }
/** \returns the outer size.
@@ -214,7 +219,7 @@
* \note For a vector, this returns just 1. For a matrix (non-vector), this is the major dimension
* with respect to the \ref TopicStorageOrders "storage order", i.e., the number of columns for a
* column-major matrix, and the number of rows for a row-major matrix. */
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
Index outerSize() const
{
return IsVectorAtCompileTime ? 1
@@ -224,9 +229,9 @@
/** \returns the inner size.
*
* \note For a vector, this is just the size. For a matrix (non-vector), this is the minor dimension
- * with respect to the \ref TopicStorageOrders "storage order", i.e., the number of rows for a
+ * with respect to the \ref TopicStorageOrders "storage order", i.e., the number of rows for a
* column-major matrix, and the number of columns for a row-major matrix. */
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
Index innerSize() const
{
return IsVectorAtCompileTime ? this->size()
@@ -261,9 +266,9 @@
/** \internal Represents a matrix with all coefficients equal to one another*/
typedef CwiseNullaryOp<internal::scalar_constant_op<Scalar>,PlainObject> ConstantReturnType;
/** \internal \deprecated Represents a vector with linearly spaced coefficients that allows sequential access only. */
- typedef CwiseNullaryOp<internal::linspaced_op<Scalar,PacketScalar>,PlainObject> SequentialLinSpacedReturnType;
+ EIGEN_DEPRECATED typedef CwiseNullaryOp<internal::linspaced_op<Scalar>,PlainObject> SequentialLinSpacedReturnType;
/** \internal Represents a vector with linearly spaced coefficients that allows random access. */
- typedef CwiseNullaryOp<internal::linspaced_op<Scalar,PacketScalar>,PlainObject> RandomAccessLinSpacedReturnType;
+ typedef CwiseNullaryOp<internal::linspaced_op<Scalar>,PlainObject> RandomAccessLinSpacedReturnType;
/** \internal the return type of MatrixBase::eigenvalues() */
typedef Matrix<typename NumTraits<typename internal::traits<Derived>::Scalar>::Real, internal::traits<Derived>::ColsAtCompileTime, 1> EigenvaluesReturnType;
@@ -297,17 +302,17 @@
Derived& operator=(const ReturnByValue<OtherDerived>& func);
/** \internal
- * Copies \a other into *this without evaluating other. \returns a reference to *this.
- * \deprecated */
+ * Copies \a other into *this without evaluating other. \returns a reference to *this. */
template<typename OtherDerived>
- EIGEN_DEVICE_FUNC
+ /** \deprecated */
+ EIGEN_DEPRECATED EIGEN_DEVICE_FUNC
Derived& lazyAssign(const DenseBase<OtherDerived>& other);
EIGEN_DEVICE_FUNC
CommaInitializer<Derived> operator<< (const Scalar& s);
- /** \deprecated it now returns \c *this */
template<unsigned int Added,unsigned int Removed>
+ /** \deprecated it now returns \c *this */
EIGEN_DEPRECATED
const Derived& flagged() const
{ return derived(); }
@@ -332,12 +337,13 @@
EIGEN_DEVICE_FUNC static const ConstantReturnType
Constant(const Scalar& value);
- EIGEN_DEVICE_FUNC static const SequentialLinSpacedReturnType
+ EIGEN_DEPRECATED EIGEN_DEVICE_FUNC static const RandomAccessLinSpacedReturnType
LinSpaced(Sequential_t, Index size, const Scalar& low, const Scalar& high);
+ EIGEN_DEPRECATED EIGEN_DEVICE_FUNC static const RandomAccessLinSpacedReturnType
+ LinSpaced(Sequential_t, const Scalar& low, const Scalar& high);
+
EIGEN_DEVICE_FUNC static const RandomAccessLinSpacedReturnType
LinSpaced(Index size, const Scalar& low, const Scalar& high);
- EIGEN_DEVICE_FUNC static const SequentialLinSpacedReturnType
- LinSpaced(Sequential_t, const Scalar& low, const Scalar& high);
EIGEN_DEVICE_FUNC static const RandomAccessLinSpacedReturnType
LinSpaced(const Scalar& low, const Scalar& high);
@@ -369,7 +375,7 @@
template<typename OtherDerived> EIGEN_DEVICE_FUNC
bool isApprox(const DenseBase<OtherDerived>& other,
const RealScalar& prec = NumTraits<Scalar>::dummy_precision()) const;
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC
bool isMuchSmallerThan(const RealScalar& other,
const RealScalar& prec = NumTraits<Scalar>::dummy_precision()) const;
template<typename OtherDerived> EIGEN_DEVICE_FUNC
@@ -380,7 +386,7 @@
EIGEN_DEVICE_FUNC bool isConstant(const Scalar& value, const RealScalar& prec = NumTraits<Scalar>::dummy_precision()) const;
EIGEN_DEVICE_FUNC bool isZero(const RealScalar& prec = NumTraits<Scalar>::dummy_precision()) const;
EIGEN_DEVICE_FUNC bool isOnes(const RealScalar& prec = NumTraits<Scalar>::dummy_precision()) const;
-
+
inline bool hasNaN() const;
inline bool allFinite() const;
@@ -394,8 +400,8 @@
*
* Notice that in the case of a plain matrix or vector (not an expression) this function just returns
* a const reference, in order to avoid a useless copy.
- *
- * \warning Be carefull with eval() and the auto C++ keyword, as detailed in this \link TopicPitfalls_auto_keyword page \endlink.
+ *
+ * \warning Be careful with eval() and the auto C++ keyword, as detailed in this \link TopicPitfalls_auto_keyword page \endlink.
*/
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE EvalReturnType eval() const
@@ -405,12 +411,12 @@
// size types on MSVC.
return typename internal::eval<Derived>::type(derived());
}
-
+
/** swaps *this with the expression \a other.
*
*/
template<typename OtherDerived>
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void swap(const DenseBase<OtherDerived>& other)
{
EIGEN_STATIC_ASSERT(!OtherDerived::IsPlainObjectBase,THIS_EXPRESSION_IS_NOT_A_LVALUE__IT_IS_READ_ONLY);
@@ -422,7 +428,7 @@
*
*/
template<typename OtherDerived>
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void swap(PlainObjectBase<OtherDerived>& other)
{
eigen_assert(rows()==other.rows() && cols()==other.cols());
@@ -443,18 +449,58 @@
EIGEN_DEVICE_FUNC Scalar prod() const;
+ template<int NaNPropagation>
EIGEN_DEVICE_FUNC typename internal::traits<Derived>::Scalar minCoeff() const;
+ template<int NaNPropagation>
EIGEN_DEVICE_FUNC typename internal::traits<Derived>::Scalar maxCoeff() const;
- template<typename IndexType> EIGEN_DEVICE_FUNC
+
+ // By default, the fastest version with undefined NaN propagation semantics is
+ // used.
+ // TODO(rmlarsen): Replace with default template argument when we move to
+ // c++11 or beyond.
+ EIGEN_DEVICE_FUNC inline typename internal::traits<Derived>::Scalar minCoeff() const {
+ return minCoeff<PropagateFast>();
+ }
+ EIGEN_DEVICE_FUNC inline typename internal::traits<Derived>::Scalar maxCoeff() const {
+ return maxCoeff<PropagateFast>();
+ }
+
+ template<int NaNPropagation, typename IndexType>
+ EIGEN_DEVICE_FUNC
typename internal::traits<Derived>::Scalar minCoeff(IndexType* row, IndexType* col) const;
- template<typename IndexType> EIGEN_DEVICE_FUNC
+ template<int NaNPropagation, typename IndexType>
+ EIGEN_DEVICE_FUNC
typename internal::traits<Derived>::Scalar maxCoeff(IndexType* row, IndexType* col) const;
- template<typename IndexType> EIGEN_DEVICE_FUNC
+ template<int NaNPropagation, typename IndexType>
+ EIGEN_DEVICE_FUNC
typename internal::traits<Derived>::Scalar minCoeff(IndexType* index) const;
- template<typename IndexType> EIGEN_DEVICE_FUNC
+ template<int NaNPropagation, typename IndexType>
+ EIGEN_DEVICE_FUNC
typename internal::traits<Derived>::Scalar maxCoeff(IndexType* index) const;
+ // TODO(rmlarsen): Replace these methods with a default template argument.
+ template<typename IndexType>
+ EIGEN_DEVICE_FUNC inline
+ typename internal::traits<Derived>::Scalar minCoeff(IndexType* row, IndexType* col) const {
+ return minCoeff<PropagateFast>(row, col);
+ }
+ template<typename IndexType>
+ EIGEN_DEVICE_FUNC inline
+ typename internal::traits<Derived>::Scalar maxCoeff(IndexType* row, IndexType* col) const {
+ return maxCoeff<PropagateFast>(row, col);
+ }
+ template<typename IndexType>
+ EIGEN_DEVICE_FUNC inline
+ typename internal::traits<Derived>::Scalar minCoeff(IndexType* index) const {
+ return minCoeff<PropagateFast>(index);
+ }
+ template<typename IndexType>
+ EIGEN_DEVICE_FUNC inline
+ typename internal::traits<Derived>::Scalar maxCoeff(IndexType* index) const {
+ return maxCoeff<PropagateFast>(index);
+ }
+
template<typename BinaryOp>
EIGEN_DEVICE_FUNC
Scalar redux(const BinaryOp& func) const;
@@ -493,7 +539,7 @@
typedef VectorwiseOp<Derived, Vertical> ColwiseReturnType;
typedef const VectorwiseOp<const Derived, Vertical> ConstColwiseReturnType;
- /** \returns a VectorwiseOp wrapper of *this providing additional partial reduction operations
+ /** \returns a VectorwiseOp wrapper of *this for broadcasting and partial reductions
*
* Example: \include MatrixBase_rowwise.cpp
* Output: \verbinclude MatrixBase_rowwise.out
@@ -506,7 +552,7 @@
}
EIGEN_DEVICE_FUNC RowwiseReturnType rowwise();
- /** \returns a VectorwiseOp wrapper of *this providing additional partial reduction operations
+ /** \returns a VectorwiseOp wrapper of *this broadcasting and partial reductions
*
* Example: \include MatrixBase_colwise.cpp
* Output: \verbinclude MatrixBase_colwise.out
@@ -524,16 +570,16 @@
static const RandomReturnType Random();
template<typename ThenDerived,typename ElseDerived>
- const Select<Derived,ThenDerived,ElseDerived>
+ inline EIGEN_DEVICE_FUNC const Select<Derived,ThenDerived,ElseDerived>
select(const DenseBase<ThenDerived>& thenMatrix,
const DenseBase<ElseDerived>& elseMatrix) const;
template<typename ThenDerived>
- inline const Select<Derived,ThenDerived, typename ThenDerived::ConstantReturnType>
+ inline EIGEN_DEVICE_FUNC const Select<Derived,ThenDerived, typename ThenDerived::ConstantReturnType>
select(const DenseBase<ThenDerived>& thenMatrix, const typename ThenDerived::Scalar& elseScalar) const;
template<typename ElseDerived>
- inline const Select<Derived, typename ElseDerived::ConstantReturnType, ElseDerived >
+ inline EIGEN_DEVICE_FUNC const Select<Derived, typename ElseDerived::ConstantReturnType, ElseDerived >
select(const typename ElseDerived::Scalar& thenScalar, const DenseBase<ElseDerived>& elseMatrix) const;
template<int p> RealScalar lpNorm() const;
@@ -567,16 +613,59 @@
}
EIGEN_DEVICE_FUNC void reverseInPlace();
+ #ifdef EIGEN_PARSED_BY_DOXYGEN
+ /** STL-like <a href="https://en.cppreference.com/w/cpp/named_req/RandomAccessIterator">RandomAccessIterator</a>
+ * iterator type as returned by the begin() and end() methods.
+ */
+ typedef random_access_iterator_type iterator;
+ /** This is the const version of iterator (aka read-only) */
+ typedef random_access_iterator_type const_iterator;
+ #else
+ typedef typename internal::conditional< (Flags&DirectAccessBit)==DirectAccessBit,
+ internal::pointer_based_stl_iterator<Derived>,
+ internal::generic_randaccess_stl_iterator<Derived>
+ >::type iterator_type;
+
+ typedef typename internal::conditional< (Flags&DirectAccessBit)==DirectAccessBit,
+ internal::pointer_based_stl_iterator<const Derived>,
+ internal::generic_randaccess_stl_iterator<const Derived>
+ >::type const_iterator_type;
+
+ // Stl-style iterators are supported only for vectors.
+
+ typedef typename internal::conditional< IsVectorAtCompileTime,
+ iterator_type,
+ void
+ >::type iterator;
+
+ typedef typename internal::conditional< IsVectorAtCompileTime,
+ const_iterator_type,
+ void
+ >::type const_iterator;
+ #endif
+
+ inline iterator begin();
+ inline const_iterator begin() const;
+ inline const_iterator cbegin() const;
+ inline iterator end();
+ inline const_iterator end() const;
+ inline const_iterator cend() const;
+
#define EIGEN_CURRENT_STORAGE_BASE_CLASS Eigen::DenseBase
#define EIGEN_DOC_BLOCK_ADDONS_NOT_INNER_PANEL
#define EIGEN_DOC_BLOCK_ADDONS_INNER_PANEL_IF(COND)
+#define EIGEN_DOC_UNARY_ADDONS(X,Y)
+# include "../plugins/CommonCwiseUnaryOps.h"
# include "../plugins/BlockMethods.h"
+# include "../plugins/IndexedViewMethods.h"
+# include "../plugins/ReshapedMethods.h"
# ifdef EIGEN_DENSEBASE_PLUGIN
# include EIGEN_DENSEBASE_PLUGIN
# endif
#undef EIGEN_CURRENT_STORAGE_BASE_CLASS
#undef EIGEN_DOC_BLOCK_ADDONS_NOT_INNER_PANEL
#undef EIGEN_DOC_BLOCK_ADDONS_INNER_PANEL_IF
+#undef EIGEN_DOC_UNARY_ADDONS
// disable the use of evalTo for dense objects with a nice compilation error
template<typename Dest>
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/DenseCoeffsBase.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/DenseCoeffsBase.h
index c4af48a..37fcdb5 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/DenseCoeffsBase.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/DenseCoeffsBase.h
@@ -22,11 +22,12 @@
/** \brief Base class providing read-only coefficient access to matrices and arrays.
* \ingroup Core_Module
* \tparam Derived Type of the derived class
- * \tparam #ReadOnlyAccessors Constant indicating read-only access
+ *
+ * \note #ReadOnlyAccessors Constant indicating read-only access
*
* This class defines the \c operator() \c const function and friends, which can be used to read specific
* entries of a matrix or array.
- *
+ *
* \sa DenseCoeffsBase<Derived, WriteAccessors>, DenseCoeffsBase<Derived, DirectAccessors>,
* \ref TopicClassHierarchy
*/
@@ -288,12 +289,13 @@
/** \brief Base class providing read/write coefficient access to matrices and arrays.
* \ingroup Core_Module
* \tparam Derived Type of the derived class
- * \tparam #WriteAccessors Constant indicating read/write access
+ *
+ * \note #WriteAccessors Constant indicating read/write access
*
* This class defines the non-const \c operator() function and friends, which can be used to write specific
* entries of a matrix or array. This class inherits DenseCoeffsBase<Derived, ReadOnlyAccessors> which
* defines the const variant for reading specific entries.
- *
+ *
* \sa DenseCoeffsBase<Derived, DirectAccessors>, \ref TopicClassHierarchy
*/
template<typename Derived>
@@ -466,7 +468,8 @@
/** \brief Base class providing direct read-only coefficient access to matrices and arrays.
* \ingroup Core_Module
* \tparam Derived Type of the derived class
- * \tparam #DirectAccessors Constant indicating direct access
+ *
+ * \note #DirectAccessors Constant indicating direct access
*
* This class defines functions to work with strides which can be used to access entries directly. This class
* inherits DenseCoeffsBase<Derived, ReadOnlyAccessors> which defines functions to access entries read-only using
@@ -492,7 +495,7 @@
*
* \sa outerStride(), rowStride(), colStride()
*/
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
inline Index innerStride() const
{
return derived().innerStride();
@@ -503,14 +506,14 @@
*
* \sa innerStride(), rowStride(), colStride()
*/
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
inline Index outerStride() const
{
return derived().outerStride();
}
// FIXME shall we remove it ?
- inline Index stride() const
+ EIGEN_CONSTEXPR inline Index stride() const
{
return Derived::IsVectorAtCompileTime ? innerStride() : outerStride();
}
@@ -519,7 +522,7 @@
*
* \sa innerStride(), outerStride(), colStride()
*/
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
inline Index rowStride() const
{
return Derived::IsRowMajor ? outerStride() : innerStride();
@@ -529,7 +532,7 @@
*
* \sa innerStride(), outerStride(), rowStride()
*/
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
inline Index colStride() const
{
return Derived::IsRowMajor ? innerStride() : outerStride();
@@ -539,7 +542,8 @@
/** \brief Base class providing direct read/write coefficient access to matrices and arrays.
* \ingroup Core_Module
* \tparam Derived Type of the derived class
- * \tparam #DirectWriteAccessors Constant indicating direct access
+ *
+ * \note #DirectWriteAccessors Constant indicating direct access
*
* This class defines functions to work with strides which can be used to access entries directly. This class
* inherits DenseCoeffsBase<Derived, WriteAccessors> which defines functions to access entries read/write using
@@ -566,8 +570,8 @@
*
* \sa outerStride(), rowStride(), colStride()
*/
- EIGEN_DEVICE_FUNC
- inline Index innerStride() const
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ inline Index innerStride() const EIGEN_NOEXCEPT
{
return derived().innerStride();
}
@@ -577,14 +581,14 @@
*
* \sa innerStride(), rowStride(), colStride()
*/
- EIGEN_DEVICE_FUNC
- inline Index outerStride() const
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ inline Index outerStride() const EIGEN_NOEXCEPT
{
return derived().outerStride();
}
// FIXME shall we remove it ?
- inline Index stride() const
+ EIGEN_CONSTEXPR inline Index stride() const EIGEN_NOEXCEPT
{
return Derived::IsVectorAtCompileTime ? innerStride() : outerStride();
}
@@ -593,8 +597,8 @@
*
* \sa innerStride(), outerStride(), colStride()
*/
- EIGEN_DEVICE_FUNC
- inline Index rowStride() const
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ inline Index rowStride() const EIGEN_NOEXCEPT
{
return Derived::IsRowMajor ? outerStride() : innerStride();
}
@@ -603,8 +607,8 @@
*
* \sa innerStride(), outerStride(), rowStride()
*/
- EIGEN_DEVICE_FUNC
- inline Index colStride() const
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ inline Index colStride() const EIGEN_NOEXCEPT
{
return Derived::IsRowMajor ? innerStride() : outerStride();
}
@@ -615,7 +619,7 @@
template<int Alignment, typename Derived, bool JustReturnZero>
struct first_aligned_impl
{
- static inline Index run(const Derived&)
+ static EIGEN_CONSTEXPR inline Index run(const Derived&) EIGEN_NOEXCEPT
{ return 0; }
};
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/DenseStorage.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/DenseStorage.h
index 7958fee..08ef6c5 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/DenseStorage.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/DenseStorage.h
@@ -47,21 +47,21 @@
EIGEN_DEVICE_FUNC
plain_array()
- {
+ {
check_static_allocation_size<T,Size>();
}
EIGEN_DEVICE_FUNC
plain_array(constructor_without_unaligned_array_assert)
- {
+ {
check_static_allocation_size<T,Size>();
}
};
#if defined(EIGEN_DISABLE_UNALIGNED_ARRAY_ASSERT)
#define EIGEN_MAKE_UNALIGNED_ARRAY_ASSERT(sizemask)
-#elif EIGEN_GNUC_AT_LEAST(4,7)
- // GCC 4.7 is too aggressive in its optimizations and remove the alignement test based on the fact the array is declared to be aligned.
+#elif EIGEN_GNUC_AT_LEAST(4,7)
+ // GCC 4.7 is too aggressive in its optimizations and remove the alignment test based on the fact the array is declared to be aligned.
// See this bug report: http://gcc.gnu.org/bugzilla/show_bug.cgi?id=53900
// Hiding the origin of the array pointer behind a function argument seems to do the trick even if the function is inlined:
template<typename PtrType>
@@ -85,15 +85,15 @@
EIGEN_ALIGN_TO_BOUNDARY(8) T array[Size];
EIGEN_DEVICE_FUNC
- plain_array()
+ plain_array()
{
EIGEN_MAKE_UNALIGNED_ARRAY_ASSERT(7);
check_static_allocation_size<T,Size>();
}
EIGEN_DEVICE_FUNC
- plain_array(constructor_without_unaligned_array_assert)
- {
+ plain_array(constructor_without_unaligned_array_assert)
+ {
check_static_allocation_size<T,Size>();
}
};
@@ -104,15 +104,15 @@
EIGEN_ALIGN_TO_BOUNDARY(16) T array[Size];
EIGEN_DEVICE_FUNC
- plain_array()
- {
+ plain_array()
+ {
EIGEN_MAKE_UNALIGNED_ARRAY_ASSERT(15);
check_static_allocation_size<T,Size>();
}
EIGEN_DEVICE_FUNC
- plain_array(constructor_without_unaligned_array_assert)
- {
+ plain_array(constructor_without_unaligned_array_assert)
+ {
check_static_allocation_size<T,Size>();
}
};
@@ -123,15 +123,15 @@
EIGEN_ALIGN_TO_BOUNDARY(32) T array[Size];
EIGEN_DEVICE_FUNC
- plain_array()
+ plain_array()
{
EIGEN_MAKE_UNALIGNED_ARRAY_ASSERT(31);
check_static_allocation_size<T,Size>();
}
EIGEN_DEVICE_FUNC
- plain_array(constructor_without_unaligned_array_assert)
- {
+ plain_array(constructor_without_unaligned_array_assert)
+ {
check_static_allocation_size<T,Size>();
}
};
@@ -142,15 +142,15 @@
EIGEN_ALIGN_TO_BOUNDARY(64) T array[Size];
EIGEN_DEVICE_FUNC
- plain_array()
- {
+ plain_array()
+ {
EIGEN_MAKE_UNALIGNED_ARRAY_ASSERT(63);
check_static_allocation_size<T,Size>();
}
EIGEN_DEVICE_FUNC
- plain_array(constructor_without_unaligned_array_assert)
- {
+ plain_array(constructor_without_unaligned_array_assert)
+ {
check_static_allocation_size<T,Size>();
}
};
@@ -163,6 +163,30 @@
EIGEN_DEVICE_FUNC plain_array(constructor_without_unaligned_array_assert) {}
};
+struct plain_array_helper {
+ template<typename T, int Size, int MatrixOrArrayOptions, int Alignment>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ static void copy(const plain_array<T, Size, MatrixOrArrayOptions, Alignment>& src, const Eigen::Index size,
+ plain_array<T, Size, MatrixOrArrayOptions, Alignment>& dst) {
+ smart_copy(src.array, src.array + size, dst.array);
+ }
+
+ template<typename T, int Size, int MatrixOrArrayOptions, int Alignment>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ static void swap(plain_array<T, Size, MatrixOrArrayOptions, Alignment>& a, const Eigen::Index a_size,
+ plain_array<T, Size, MatrixOrArrayOptions, Alignment>& b, const Eigen::Index b_size) {
+ if (a_size < b_size) {
+ std::swap_ranges(b.array, b.array + a_size, a.array);
+ smart_move(b.array + a_size, b.array + b_size, a.array + a_size);
+ } else if (a_size > b_size) {
+ std::swap_ranges(a.array, a.array + b_size, b.array);
+ smart_move(a.array + b_size, a.array + a_size, b.array + b_size);
+ } else {
+ std::swap_ranges(a.array, a.array + a_size, b.array);
+ }
+ }
+};
+
} // end namespace internal
/** \internal
@@ -190,16 +214,41 @@
EIGEN_DEVICE_FUNC
explicit DenseStorage(internal::constructor_without_unaligned_array_assert)
: m_data(internal::constructor_without_unaligned_array_assert()) {}
- EIGEN_DEVICE_FUNC
+#if !EIGEN_HAS_CXX11 || defined(EIGEN_DENSE_STORAGE_CTOR_PLUGIN)
+ EIGEN_DEVICE_FUNC
DenseStorage(const DenseStorage& other) : m_data(other.m_data) {
EIGEN_INTERNAL_DENSE_STORAGE_CTOR_PLUGIN(Index size = Size)
}
- EIGEN_DEVICE_FUNC
+#else
+ EIGEN_DEVICE_FUNC DenseStorage(const DenseStorage&) = default;
+#endif
+#if !EIGEN_HAS_CXX11
+ EIGEN_DEVICE_FUNC
DenseStorage& operator=(const DenseStorage& other)
- {
+ {
if (this != &other) m_data = other.m_data;
- return *this;
+ return *this;
}
+#else
+ EIGEN_DEVICE_FUNC DenseStorage& operator=(const DenseStorage&) = default;
+#endif
+#if EIGEN_HAS_RVALUE_REFERENCES
+#if !EIGEN_HAS_CXX11
+ EIGEN_DEVICE_FUNC DenseStorage(DenseStorage&& other) EIGEN_NOEXCEPT
+ : m_data(std::move(other.m_data))
+ {
+ }
+ EIGEN_DEVICE_FUNC DenseStorage& operator=(DenseStorage&& other) EIGEN_NOEXCEPT
+ {
+ if (this != &other)
+ m_data = std::move(other.m_data);
+ return *this;
+ }
+#else
+ EIGEN_DEVICE_FUNC DenseStorage(DenseStorage&&) = default;
+ EIGEN_DEVICE_FUNC DenseStorage& operator=(DenseStorage&&) = default;
+#endif
+#endif
EIGEN_DEVICE_FUNC DenseStorage(Index size, Index rows, Index cols) {
EIGEN_INTERNAL_DENSE_STORAGE_CTOR_PLUGIN({})
eigen_internal_assert(size==rows*cols && rows==_Rows && cols==_Cols);
@@ -207,9 +256,11 @@
EIGEN_UNUSED_VARIABLE(rows);
EIGEN_UNUSED_VARIABLE(cols);
}
- EIGEN_DEVICE_FUNC void swap(DenseStorage& other) { std::swap(m_data,other.m_data); }
- EIGEN_DEVICE_FUNC static Index rows(void) {return _Rows;}
- EIGEN_DEVICE_FUNC static Index cols(void) {return _Cols;}
+ EIGEN_DEVICE_FUNC void swap(DenseStorage& other) {
+ numext::swap(m_data, other.m_data);
+ }
+ EIGEN_DEVICE_FUNC static EIGEN_CONSTEXPR Index rows(void) EIGEN_NOEXCEPT {return _Rows;}
+ EIGEN_DEVICE_FUNC static EIGEN_CONSTEXPR Index cols(void) EIGEN_NOEXCEPT {return _Cols;}
EIGEN_DEVICE_FUNC void conservativeResize(Index,Index,Index) {}
EIGEN_DEVICE_FUNC void resize(Index,Index,Index) {}
EIGEN_DEVICE_FUNC const T *data() const { return m_data.array; }
@@ -226,8 +277,8 @@
EIGEN_DEVICE_FUNC DenseStorage& operator=(const DenseStorage&) { return *this; }
EIGEN_DEVICE_FUNC DenseStorage(Index,Index,Index) {}
EIGEN_DEVICE_FUNC void swap(DenseStorage& ) {}
- EIGEN_DEVICE_FUNC static Index rows(void) {return _Rows;}
- EIGEN_DEVICE_FUNC static Index cols(void) {return _Cols;}
+ EIGEN_DEVICE_FUNC static EIGEN_CONSTEXPR Index rows(void) EIGEN_NOEXCEPT {return _Rows;}
+ EIGEN_DEVICE_FUNC static EIGEN_CONSTEXPR Index cols(void) EIGEN_NOEXCEPT {return _Cols;}
EIGEN_DEVICE_FUNC void conservativeResize(Index,Index,Index) {}
EIGEN_DEVICE_FUNC void resize(Index,Index,Index) {}
EIGEN_DEVICE_FUNC const T *data() const { return 0; }
@@ -254,20 +305,28 @@
EIGEN_DEVICE_FUNC DenseStorage() : m_rows(0), m_cols(0) {}
EIGEN_DEVICE_FUNC explicit DenseStorage(internal::constructor_without_unaligned_array_assert)
: m_data(internal::constructor_without_unaligned_array_assert()), m_rows(0), m_cols(0) {}
- EIGEN_DEVICE_FUNC DenseStorage(const DenseStorage& other) : m_data(other.m_data), m_rows(other.m_rows), m_cols(other.m_cols) {}
- EIGEN_DEVICE_FUNC DenseStorage& operator=(const DenseStorage& other)
- {
+ EIGEN_DEVICE_FUNC DenseStorage(const DenseStorage& other)
+ : m_data(internal::constructor_without_unaligned_array_assert()), m_rows(other.m_rows), m_cols(other.m_cols)
+ {
+ internal::plain_array_helper::copy(other.m_data, m_rows * m_cols, m_data);
+ }
+ EIGEN_DEVICE_FUNC DenseStorage& operator=(const DenseStorage& other)
+ {
if (this != &other)
{
- m_data = other.m_data;
m_rows = other.m_rows;
m_cols = other.m_cols;
+ internal::plain_array_helper::copy(other.m_data, m_rows * m_cols, m_data);
}
- return *this;
+ return *this;
}
EIGEN_DEVICE_FUNC DenseStorage(Index, Index rows, Index cols) : m_rows(rows), m_cols(cols) {}
EIGEN_DEVICE_FUNC void swap(DenseStorage& other)
- { std::swap(m_data,other.m_data); std::swap(m_rows,other.m_rows); std::swap(m_cols,other.m_cols); }
+ {
+ internal::plain_array_helper::swap(m_data, m_rows * m_cols, other.m_data, other.m_rows * other.m_cols);
+ numext::swap(m_rows,other.m_rows);
+ numext::swap(m_cols,other.m_cols);
+ }
EIGEN_DEVICE_FUNC Index rows() const {return m_rows;}
EIGEN_DEVICE_FUNC Index cols() const {return m_cols;}
EIGEN_DEVICE_FUNC void conservativeResize(Index, Index rows, Index cols) { m_rows = rows; m_cols = cols; }
@@ -285,20 +344,29 @@
EIGEN_DEVICE_FUNC DenseStorage() : m_rows(0) {}
EIGEN_DEVICE_FUNC explicit DenseStorage(internal::constructor_without_unaligned_array_assert)
: m_data(internal::constructor_without_unaligned_array_assert()), m_rows(0) {}
- EIGEN_DEVICE_FUNC DenseStorage(const DenseStorage& other) : m_data(other.m_data), m_rows(other.m_rows) {}
- EIGEN_DEVICE_FUNC DenseStorage& operator=(const DenseStorage& other)
+ EIGEN_DEVICE_FUNC DenseStorage(const DenseStorage& other)
+ : m_data(internal::constructor_without_unaligned_array_assert()), m_rows(other.m_rows)
+ {
+ internal::plain_array_helper::copy(other.m_data, m_rows * _Cols, m_data);
+ }
+
+ EIGEN_DEVICE_FUNC DenseStorage& operator=(const DenseStorage& other)
{
if (this != &other)
{
- m_data = other.m_data;
m_rows = other.m_rows;
+ internal::plain_array_helper::copy(other.m_data, m_rows * _Cols, m_data);
}
- return *this;
+ return *this;
}
EIGEN_DEVICE_FUNC DenseStorage(Index, Index rows, Index) : m_rows(rows) {}
- EIGEN_DEVICE_FUNC void swap(DenseStorage& other) { std::swap(m_data,other.m_data); std::swap(m_rows,other.m_rows); }
- EIGEN_DEVICE_FUNC Index rows(void) const {return m_rows;}
- EIGEN_DEVICE_FUNC Index cols(void) const {return _Cols;}
+ EIGEN_DEVICE_FUNC void swap(DenseStorage& other)
+ {
+ internal::plain_array_helper::swap(m_data, m_rows * _Cols, other.m_data, other.m_rows * _Cols);
+ numext::swap(m_rows, other.m_rows);
+ }
+ EIGEN_DEVICE_FUNC Index rows(void) const EIGEN_NOEXCEPT {return m_rows;}
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR Index cols(void) const EIGEN_NOEXCEPT {return _Cols;}
EIGEN_DEVICE_FUNC void conservativeResize(Index, Index rows, Index) { m_rows = rows; }
EIGEN_DEVICE_FUNC void resize(Index, Index rows, Index) { m_rows = rows; }
EIGEN_DEVICE_FUNC const T *data() const { return m_data.array; }
@@ -314,22 +382,29 @@
EIGEN_DEVICE_FUNC DenseStorage() : m_cols(0) {}
EIGEN_DEVICE_FUNC explicit DenseStorage(internal::constructor_without_unaligned_array_assert)
: m_data(internal::constructor_without_unaligned_array_assert()), m_cols(0) {}
- EIGEN_DEVICE_FUNC DenseStorage(const DenseStorage& other) : m_data(other.m_data), m_cols(other.m_cols) {}
+ EIGEN_DEVICE_FUNC DenseStorage(const DenseStorage& other)
+ : m_data(internal::constructor_without_unaligned_array_assert()), m_cols(other.m_cols)
+ {
+ internal::plain_array_helper::copy(other.m_data, _Rows * m_cols, m_data);
+ }
EIGEN_DEVICE_FUNC DenseStorage& operator=(const DenseStorage& other)
{
if (this != &other)
{
- m_data = other.m_data;
m_cols = other.m_cols;
+ internal::plain_array_helper::copy(other.m_data, _Rows * m_cols, m_data);
}
return *this;
}
EIGEN_DEVICE_FUNC DenseStorage(Index, Index, Index cols) : m_cols(cols) {}
- EIGEN_DEVICE_FUNC void swap(DenseStorage& other) { std::swap(m_data,other.m_data); std::swap(m_cols,other.m_cols); }
- EIGEN_DEVICE_FUNC Index rows(void) const {return _Rows;}
- EIGEN_DEVICE_FUNC Index cols(void) const {return m_cols;}
- void conservativeResize(Index, Index, Index cols) { m_cols = cols; }
- void resize(Index, Index, Index cols) { m_cols = cols; }
+ EIGEN_DEVICE_FUNC void swap(DenseStorage& other) {
+ internal::plain_array_helper::swap(m_data, _Rows * m_cols, other.m_data, _Rows * other.m_cols);
+ numext::swap(m_cols, other.m_cols);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR Index rows(void) const EIGEN_NOEXCEPT {return _Rows;}
+ EIGEN_DEVICE_FUNC Index cols(void) const EIGEN_NOEXCEPT {return m_cols;}
+ EIGEN_DEVICE_FUNC void conservativeResize(Index, Index, Index cols) { m_cols = cols; }
+ EIGEN_DEVICE_FUNC void resize(Index, Index, Index cols) { m_cols = cols; }
EIGEN_DEVICE_FUNC const T *data() const { return m_data.array; }
EIGEN_DEVICE_FUNC T *data() { return m_data.array; }
};
@@ -381,18 +456,21 @@
EIGEN_DEVICE_FUNC
DenseStorage& operator=(DenseStorage&& other) EIGEN_NOEXCEPT
{
- using std::swap;
- swap(m_data, other.m_data);
- swap(m_rows, other.m_rows);
- swap(m_cols, other.m_cols);
+ numext::swap(m_data, other.m_data);
+ numext::swap(m_rows, other.m_rows);
+ numext::swap(m_cols, other.m_cols);
return *this;
}
#endif
EIGEN_DEVICE_FUNC ~DenseStorage() { internal::conditional_aligned_delete_auto<T,(_Options&DontAlign)==0>(m_data, m_rows*m_cols); }
EIGEN_DEVICE_FUNC void swap(DenseStorage& other)
- { std::swap(m_data,other.m_data); std::swap(m_rows,other.m_rows); std::swap(m_cols,other.m_cols); }
- EIGEN_DEVICE_FUNC Index rows(void) const {return m_rows;}
- EIGEN_DEVICE_FUNC Index cols(void) const {return m_cols;}
+ {
+ numext::swap(m_data,other.m_data);
+ numext::swap(m_rows,other.m_rows);
+ numext::swap(m_cols,other.m_cols);
+ }
+ EIGEN_DEVICE_FUNC Index rows(void) const EIGEN_NOEXCEPT {return m_rows;}
+ EIGEN_DEVICE_FUNC Index cols(void) const EIGEN_NOEXCEPT {return m_cols;}
void conservativeResize(Index size, Index rows, Index cols)
{
m_data = internal::conditional_aligned_realloc_new_auto<T,(_Options&DontAlign)==0>(m_data, size, m_rows*m_cols);
@@ -404,7 +482,7 @@
if(size != m_rows*m_cols)
{
internal::conditional_aligned_delete_auto<T,(_Options&DontAlign)==0>(m_data, m_rows*m_cols);
- if (size)
+ if (size>0) // >0 and not simply !=0 to let the compiler knows that size cannot be negative
m_data = internal::conditional_aligned_new_auto<T,(_Options&DontAlign)==0>(size);
else
m_data = 0;
@@ -446,7 +524,7 @@
this->swap(tmp);
}
return *this;
- }
+ }
#if EIGEN_HAS_RVALUE_REFERENCES
EIGEN_DEVICE_FUNC
DenseStorage(DenseStorage&& other) EIGEN_NOEXCEPT
@@ -459,16 +537,18 @@
EIGEN_DEVICE_FUNC
DenseStorage& operator=(DenseStorage&& other) EIGEN_NOEXCEPT
{
- using std::swap;
- swap(m_data, other.m_data);
- swap(m_cols, other.m_cols);
+ numext::swap(m_data, other.m_data);
+ numext::swap(m_cols, other.m_cols);
return *this;
}
#endif
EIGEN_DEVICE_FUNC ~DenseStorage() { internal::conditional_aligned_delete_auto<T,(_Options&DontAlign)==0>(m_data, _Rows*m_cols); }
- EIGEN_DEVICE_FUNC void swap(DenseStorage& other) { std::swap(m_data,other.m_data); std::swap(m_cols,other.m_cols); }
- EIGEN_DEVICE_FUNC static Index rows(void) {return _Rows;}
- EIGEN_DEVICE_FUNC Index cols(void) const {return m_cols;}
+ EIGEN_DEVICE_FUNC void swap(DenseStorage& other) {
+ numext::swap(m_data,other.m_data);
+ numext::swap(m_cols,other.m_cols);
+ }
+ EIGEN_DEVICE_FUNC static EIGEN_CONSTEXPR Index rows(void) EIGEN_NOEXCEPT {return _Rows;}
+ EIGEN_DEVICE_FUNC Index cols(void) const EIGEN_NOEXCEPT {return m_cols;}
EIGEN_DEVICE_FUNC void conservativeResize(Index size, Index, Index cols)
{
m_data = internal::conditional_aligned_realloc_new_auto<T,(_Options&DontAlign)==0>(m_data, size, _Rows*m_cols);
@@ -479,7 +559,7 @@
if(size != _Rows*m_cols)
{
internal::conditional_aligned_delete_auto<T,(_Options&DontAlign)==0>(m_data, _Rows*m_cols);
- if (size)
+ if (size>0) // >0 and not simply !=0 to let the compiler knows that size cannot be negative
m_data = internal::conditional_aligned_new_auto<T,(_Options&DontAlign)==0>(size);
else
m_data = 0;
@@ -520,7 +600,7 @@
this->swap(tmp);
}
return *this;
- }
+ }
#if EIGEN_HAS_RVALUE_REFERENCES
EIGEN_DEVICE_FUNC
DenseStorage(DenseStorage&& other) EIGEN_NOEXCEPT
@@ -533,16 +613,18 @@
EIGEN_DEVICE_FUNC
DenseStorage& operator=(DenseStorage&& other) EIGEN_NOEXCEPT
{
- using std::swap;
- swap(m_data, other.m_data);
- swap(m_rows, other.m_rows);
+ numext::swap(m_data, other.m_data);
+ numext::swap(m_rows, other.m_rows);
return *this;
}
#endif
EIGEN_DEVICE_FUNC ~DenseStorage() { internal::conditional_aligned_delete_auto<T,(_Options&DontAlign)==0>(m_data, _Cols*m_rows); }
- EIGEN_DEVICE_FUNC void swap(DenseStorage& other) { std::swap(m_data,other.m_data); std::swap(m_rows,other.m_rows); }
- EIGEN_DEVICE_FUNC Index rows(void) const {return m_rows;}
- EIGEN_DEVICE_FUNC static Index cols(void) {return _Cols;}
+ EIGEN_DEVICE_FUNC void swap(DenseStorage& other) {
+ numext::swap(m_data,other.m_data);
+ numext::swap(m_rows,other.m_rows);
+ }
+ EIGEN_DEVICE_FUNC Index rows(void) const EIGEN_NOEXCEPT {return m_rows;}
+ EIGEN_DEVICE_FUNC static EIGEN_CONSTEXPR Index cols(void) {return _Cols;}
void conservativeResize(Index size, Index rows, Index)
{
m_data = internal::conditional_aligned_realloc_new_auto<T,(_Options&DontAlign)==0>(m_data, size, m_rows*_Cols);
@@ -553,7 +635,7 @@
if(size != m_rows*_Cols)
{
internal::conditional_aligned_delete_auto<T,(_Options&DontAlign)==0>(m_data, _Cols*m_rows);
- if (size)
+ if (size>0) // >0 and not simply !=0 to let the compiler knows that size cannot be negative
m_data = internal::conditional_aligned_new_auto<T,(_Options&DontAlign)==0>(size);
else
m_data = 0;
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Diagonal.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Diagonal.h
index afcaf35..3112d2c 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Diagonal.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Diagonal.h
@@ -11,7 +11,7 @@
#ifndef EIGEN_DIAGONAL_H
#define EIGEN_DIAGONAL_H
-namespace Eigen {
+namespace Eigen {
/** \class Diagonal
* \ingroup Core_Module
@@ -84,20 +84,16 @@
: numext::mini<Index>(m_matrix.rows(),m_matrix.cols()-m_index.value());
}
- EIGEN_DEVICE_FUNC
- inline Index cols() const { return 1; }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ inline Index cols() const EIGEN_NOEXCEPT { return 1; }
- EIGEN_DEVICE_FUNC
- inline Index innerStride() const
- {
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ inline Index innerStride() const EIGEN_NOEXCEPT {
return m_matrix.outerStride() + 1;
}
- EIGEN_DEVICE_FUNC
- inline Index outerStride() const
- {
- return 0;
- }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ inline Index outerStride() const EIGEN_NOEXCEPT { return 0; }
typedef typename internal::conditional<
internal::is_lvalue<MatrixType>::value,
@@ -149,8 +145,8 @@
}
EIGEN_DEVICE_FUNC
- inline const typename internal::remove_all<typename MatrixType::Nested>::type&
- nestedExpression() const
+ inline const typename internal::remove_all<typename MatrixType::Nested>::type&
+ nestedExpression() const
{
return m_matrix;
}
@@ -167,12 +163,12 @@
private:
// some compilers may fail to optimize std::max etc in case of compile-time constants...
- EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE Index absDiagIndex() const { return m_index.value()>0 ? m_index.value() : -m_index.value(); }
- EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE Index rowOffset() const { return m_index.value()>0 ? 0 : -m_index.value(); }
- EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE Index colOffset() const { return m_index.value()>0 ? m_index.value() : 0; }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_CONSTEXPR
+ Index absDiagIndex() const EIGEN_NOEXCEPT { return m_index.value()>0 ? m_index.value() : -m_index.value(); }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_CONSTEXPR
+ Index rowOffset() const EIGEN_NOEXCEPT { return m_index.value()>0 ? 0 : -m_index.value(); }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_CONSTEXPR
+ Index colOffset() const EIGEN_NOEXCEPT { return m_index.value()>0 ? m_index.value() : 0; }
// trigger a compile-time error if someone try to call packet
template<int LoadMode> typename MatrixType::PacketReturnType packet(Index) const;
template<int LoadMode> typename MatrixType::PacketReturnType packet(Index,Index) const;
@@ -187,7 +183,7 @@
*
* \sa class Diagonal */
template<typename Derived>
-inline typename MatrixBase<Derived>::DiagonalReturnType
+EIGEN_DEVICE_FUNC inline typename MatrixBase<Derived>::DiagonalReturnType
MatrixBase<Derived>::diagonal()
{
return DiagonalReturnType(derived());
@@ -195,7 +191,7 @@
/** This is the const version of diagonal(). */
template<typename Derived>
-inline typename MatrixBase<Derived>::ConstDiagonalReturnType
+EIGEN_DEVICE_FUNC inline typename MatrixBase<Derived>::ConstDiagonalReturnType
MatrixBase<Derived>::diagonal() const
{
return ConstDiagonalReturnType(derived());
@@ -213,7 +209,7 @@
*
* \sa MatrixBase::diagonal(), class Diagonal */
template<typename Derived>
-inline typename MatrixBase<Derived>::DiagonalDynamicIndexReturnType
+EIGEN_DEVICE_FUNC inline typename MatrixBase<Derived>::DiagonalDynamicIndexReturnType
MatrixBase<Derived>::diagonal(Index index)
{
return DiagonalDynamicIndexReturnType(derived(), index);
@@ -221,7 +217,7 @@
/** This is the const version of diagonal(Index). */
template<typename Derived>
-inline typename MatrixBase<Derived>::ConstDiagonalDynamicIndexReturnType
+EIGEN_DEVICE_FUNC inline typename MatrixBase<Derived>::ConstDiagonalDynamicIndexReturnType
MatrixBase<Derived>::diagonal(Index index) const
{
return ConstDiagonalDynamicIndexReturnType(derived(), index);
@@ -240,6 +236,7 @@
* \sa MatrixBase::diagonal(), class Diagonal */
template<typename Derived>
template<int Index_>
+EIGEN_DEVICE_FUNC
inline typename MatrixBase<Derived>::template DiagonalIndexReturnType<Index_>::Type
MatrixBase<Derived>::diagonal()
{
@@ -249,6 +246,7 @@
/** This is the const version of diagonal<int>(). */
template<typename Derived>
template<int Index_>
+EIGEN_DEVICE_FUNC
inline typename MatrixBase<Derived>::template ConstDiagonalIndexReturnType<Index_>::Type
MatrixBase<Derived>::diagonal() const
{
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/DiagonalMatrix.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/DiagonalMatrix.h
index ecfdce8..542685c 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/DiagonalMatrix.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/DiagonalMatrix.h
@@ -44,7 +44,7 @@
EIGEN_DEVICE_FUNC
DenseMatrixType toDenseMatrix() const { return derived(); }
-
+
EIGEN_DEVICE_FUNC
inline const DiagonalVectorType& diagonal() const { return derived().diagonal(); }
EIGEN_DEVICE_FUNC
@@ -83,6 +83,30 @@
{
return DiagonalWrapper<const EIGEN_SCALAR_BINARYOP_EXPR_RETURN_TYPE(Scalar,DiagonalVectorType,product) >(scalar * other.diagonal());
}
+
+ template<typename OtherDerived>
+ EIGEN_DEVICE_FUNC
+ #ifdef EIGEN_PARSED_BY_DOXYGEN
+ inline unspecified_expression_type
+ #else
+ inline const DiagonalWrapper<const EIGEN_CWISE_BINARY_RETURN_TYPE(DiagonalVectorType,typename OtherDerived::DiagonalVectorType,sum) >
+ #endif
+ operator+(const DiagonalBase<OtherDerived>& other) const
+ {
+ return (diagonal() + other.diagonal()).asDiagonal();
+ }
+
+ template<typename OtherDerived>
+ EIGEN_DEVICE_FUNC
+ #ifdef EIGEN_PARSED_BY_DOXYGEN
+ inline unspecified_expression_type
+ #else
+ inline const DiagonalWrapper<const EIGEN_CWISE_BINARY_RETURN_TYPE(DiagonalVectorType,typename OtherDerived::DiagonalVectorType,difference) >
+ #endif
+ operator-(const DiagonalBase<OtherDerived>& other) const
+ {
+ return (diagonal() - other.diagonal()).asDiagonal();
+ }
};
#endif
@@ -154,6 +178,30 @@
EIGEN_DEVICE_FUNC
inline DiagonalMatrix(const Scalar& x, const Scalar& y, const Scalar& z) : m_diagonal(x,y,z) {}
+ #if EIGEN_HAS_CXX11
+ /** \brief Construct a diagonal matrix with fixed size from an arbitrary number of coefficients. \cpp11
+ *
+ * There exists C++98 anologue constructors for fixed-size diagonal matrices having 2 or 3 coefficients.
+ *
+ * \warning To construct a diagonal matrix of fixed size, the number of values passed to this
+ * constructor must match the fixed dimension of \c *this.
+ *
+ * \sa DiagonalMatrix(const Scalar&, const Scalar&)
+ * \sa DiagonalMatrix(const Scalar&, const Scalar&, const Scalar&)
+ */
+ template <typename... ArgTypes>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ DiagonalMatrix(const Scalar& a0, const Scalar& a1, const Scalar& a2, const ArgTypes&... args)
+ : m_diagonal(a0, a1, a2, args...) {}
+
+ /** \brief Constructs a DiagonalMatrix and initializes it by elements given by an initializer list of initializer
+ * lists \cpp11
+ */
+ EIGEN_DEVICE_FUNC
+ explicit EIGEN_STRONG_INLINE DiagonalMatrix(const std::initializer_list<std::initializer_list<Scalar>>& list)
+ : m_diagonal(list) {}
+ #endif // EIGEN_HAS_CXX11
+
/** Copy constructor. */
template<typename OtherDerived>
EIGEN_DEVICE_FUNC
@@ -273,7 +321,7 @@
* \sa class DiagonalWrapper, class DiagonalMatrix, diagonal(), isDiagonal()
**/
template<typename Derived>
-inline const DiagonalWrapper<const Derived>
+EIGEN_DEVICE_FUNC inline const DiagonalWrapper<const Derived>
MatrixBase<Derived>::asDiagonal() const
{
return DiagonalWrapper<const Derived>(derived());
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/DiagonalProduct.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/DiagonalProduct.h
index d372b93..7911d1c 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/DiagonalProduct.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/DiagonalProduct.h
@@ -17,7 +17,7 @@
*/
template<typename Derived>
template<typename DiagonalDerived>
-inline const Product<Derived, DiagonalDerived, LazyProduct>
+EIGEN_DEVICE_FUNC inline const Product<Derived, DiagonalDerived, LazyProduct>
MatrixBase<Derived>::operator*(const DiagonalBase<DiagonalDerived> &a_diagonal) const
{
return Product<Derived, DiagonalDerived, LazyProduct>(derived(),a_diagonal.derived());
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Dot.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Dot.h
index 1fe7a84..5c3441b 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Dot.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Dot.h
@@ -86,14 +86,14 @@
//---------- implementation of L2 norm and related functions ----------
-/** \returns, for vectors, the squared \em l2 norm of \c *this, and for matrices the Frobenius norm.
+/** \returns, for vectors, the squared \em l2 norm of \c *this, and for matrices the squared Frobenius norm.
* In both cases, it consists in the sum of the square of all the matrix entries.
* For vectors, this is also equals to the dot product of \c *this with itself.
*
* \sa dot(), norm(), lpNorm()
*/
template<typename Derived>
-EIGEN_STRONG_INLINE typename NumTraits<typename internal::traits<Derived>::Scalar>::Real MatrixBase<Derived>::squaredNorm() const
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename NumTraits<typename internal::traits<Derived>::Scalar>::Real MatrixBase<Derived>::squaredNorm() const
{
return numext::real((*this).cwiseAbs2().sum());
}
@@ -105,7 +105,7 @@
* \sa lpNorm(), dot(), squaredNorm()
*/
template<typename Derived>
-EIGEN_STRONG_INLINE typename NumTraits<typename internal::traits<Derived>::Scalar>::Real MatrixBase<Derived>::norm() const
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename NumTraits<typename internal::traits<Derived>::Scalar>::Real MatrixBase<Derived>::norm() const
{
return numext::sqrt(squaredNorm());
}
@@ -120,7 +120,7 @@
* \sa norm(), normalize()
*/
template<typename Derived>
-EIGEN_STRONG_INLINE const typename MatrixBase<Derived>::PlainObject
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const typename MatrixBase<Derived>::PlainObject
MatrixBase<Derived>::normalized() const
{
typedef typename internal::nested_eval<Derived,2>::type _Nested;
@@ -142,7 +142,7 @@
* \sa norm(), normalized()
*/
template<typename Derived>
-EIGEN_STRONG_INLINE void MatrixBase<Derived>::normalize()
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void MatrixBase<Derived>::normalize()
{
RealScalar z = squaredNorm();
// NOTE: after extensive benchmarking, this conditional does not impact performance, at least on recent x86 CPU
@@ -163,7 +163,7 @@
* \sa stableNorm(), stableNormalize(), normalized()
*/
template<typename Derived>
-EIGEN_STRONG_INLINE const typename MatrixBase<Derived>::PlainObject
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const typename MatrixBase<Derived>::PlainObject
MatrixBase<Derived>::stableNormalized() const
{
typedef typename internal::nested_eval<Derived,3>::type _Nested;
@@ -188,7 +188,7 @@
* \sa stableNorm(), stableNormalized(), normalize()
*/
template<typename Derived>
-EIGEN_STRONG_INLINE void MatrixBase<Derived>::stableNormalize()
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void MatrixBase<Derived>::stableNormalize()
{
RealScalar w = cwiseAbs().maxCoeff();
RealScalar z = (derived()/w).squaredNorm();
@@ -207,7 +207,7 @@
EIGEN_DEVICE_FUNC
static inline RealScalar run(const MatrixBase<Derived>& m)
{
- EIGEN_USING_STD_MATH(pow)
+ EIGEN_USING_STD(pow)
return pow(m.cwiseAbs().array().pow(p).sum(), RealScalar(1)/p);
}
};
@@ -260,9 +260,9 @@
template<typename Derived>
template<int p>
#ifndef EIGEN_PARSED_BY_DOXYGEN
-inline typename NumTraits<typename internal::traits<Derived>::Scalar>::Real
+EIGEN_DEVICE_FUNC inline typename NumTraits<typename internal::traits<Derived>::Scalar>::Real
#else
-MatrixBase<Derived>::RealScalar
+EIGEN_DEVICE_FUNC MatrixBase<Derived>::RealScalar
#endif
MatrixBase<Derived>::lpNorm() const
{
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/EigenBase.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/EigenBase.h
index b195506..6b3c7d3 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/EigenBase.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/EigenBase.h
@@ -15,7 +15,7 @@
/** \class EigenBase
* \ingroup Core_Module
- *
+ *
* Common base class for all classes T such that MatrixBase has an operator=(T) and a constructor MatrixBase(T).
*
* In other words, an EigenBase object is an object that can be copied into a MatrixBase.
@@ -29,11 +29,12 @@
template<typename Derived> struct EigenBase
{
// typedef typename internal::plain_matrix_type<Derived>::type PlainObject;
-
+
/** \brief The interface type of indices
* \details To change this, \c \#define the preprocessor symbol \c EIGEN_DEFAULT_DENSE_INDEX_TYPE.
- * \deprecated Since Eigen 3.3, its usage is deprecated. Use Eigen::Index instead.
* \sa StorageIndex, \ref TopicPreprocessorDirectives.
+ * DEPRECATED: Since Eigen 3.3, its usage is deprecated. Use Eigen::Index instead.
+ * Deprecation is not marked with a doxygen comment because there are too many existing usages to add the deprecation attribute.
*/
typedef Eigen::Index Index;
@@ -55,15 +56,15 @@
{ return *static_cast<const Derived*>(this); }
/** \returns the number of rows. \sa cols(), RowsAtCompileTime */
- EIGEN_DEVICE_FUNC
- inline Index rows() const { return derived().rows(); }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ inline Index rows() const EIGEN_NOEXCEPT { return derived().rows(); }
/** \returns the number of columns. \sa rows(), ColsAtCompileTime*/
- EIGEN_DEVICE_FUNC
- inline Index cols() const { return derived().cols(); }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ inline Index cols() const EIGEN_NOEXCEPT { return derived().cols(); }
/** \returns the number of coefficients, which is rows()*cols().
* \sa rows(), cols(), SizeAtCompileTime. */
- EIGEN_DEVICE_FUNC
- inline Index size() const { return rows() * cols(); }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ inline Index size() const EIGEN_NOEXCEPT { return rows() * cols(); }
/** \internal Don't use it, but do the equivalent: \code dst = *this; \endcode */
template<typename Dest>
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/ForceAlignedAccess.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/ForceAlignedAccess.h
index 7b08b45..817a43a 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/ForceAlignedAccess.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/ForceAlignedAccess.h
@@ -41,10 +41,14 @@
EIGEN_DEVICE_FUNC explicit inline ForceAlignedAccess(const ExpressionType& matrix) : m_expression(matrix) {}
- EIGEN_DEVICE_FUNC inline Index rows() const { return m_expression.rows(); }
- EIGEN_DEVICE_FUNC inline Index cols() const { return m_expression.cols(); }
- EIGEN_DEVICE_FUNC inline Index outerStride() const { return m_expression.outerStride(); }
- EIGEN_DEVICE_FUNC inline Index innerStride() const { return m_expression.innerStride(); }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ inline Index rows() const EIGEN_NOEXCEPT { return m_expression.rows(); }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ inline Index cols() const EIGEN_NOEXCEPT { return m_expression.cols(); }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ inline Index outerStride() const EIGEN_NOEXCEPT { return m_expression.outerStride(); }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ inline Index innerStride() const EIGEN_NOEXCEPT { return m_expression.innerStride(); }
EIGEN_DEVICE_FUNC inline const CoeffReturnType coeff(Index row, Index col) const
{
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Fuzzy.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Fuzzy.h
index 3e403a0..43aa49b 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Fuzzy.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Fuzzy.h
@@ -100,7 +100,7 @@
*/
template<typename Derived>
template<typename OtherDerived>
-bool DenseBase<Derived>::isApprox(
+EIGEN_DEVICE_FUNC bool DenseBase<Derived>::isApprox(
const DenseBase<OtherDerived>& other,
const RealScalar& prec
) const
@@ -122,7 +122,7 @@
* \sa isApprox(), isMuchSmallerThan(const DenseBase<OtherDerived>&, RealScalar) const
*/
template<typename Derived>
-bool DenseBase<Derived>::isMuchSmallerThan(
+EIGEN_DEVICE_FUNC bool DenseBase<Derived>::isMuchSmallerThan(
const typename NumTraits<Scalar>::Real& other,
const RealScalar& prec
) const
@@ -142,7 +142,7 @@
*/
template<typename Derived>
template<typename OtherDerived>
-bool DenseBase<Derived>::isMuchSmallerThan(
+EIGEN_DEVICE_FUNC bool DenseBase<Derived>::isMuchSmallerThan(
const DenseBase<OtherDerived>& other,
const RealScalar& prec
) const
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/GeneralProduct.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/GeneralProduct.h
index 6f0cc80..6906aa7 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/GeneralProduct.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/GeneralProduct.h
@@ -18,6 +18,16 @@
Small = 3
};
+// Define the threshold value to fall back from the generic matrix-matrix product
+// implementation (heavy) to the lightweight coeff-based product one.
+// See generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct>
+// in products/GeneralMatrixMatrix.h for more details.
+// TODO This threshold should also be used in the compile-time selector below.
+#ifndef EIGEN_GEMM_TO_COEFFBASED_THRESHOLD
+// This default value has been obtained on a Haswell architecture.
+#define EIGEN_GEMM_TO_COEFFBASED_THRESHOLD 20
+#endif
+
namespace internal {
template<int Rows, int Cols, int Depth> struct product_type_selector;
@@ -25,7 +35,7 @@
template<int Size, int MaxSize> struct product_size_category
{
enum {
- #ifndef EIGEN_CUDA_ARCH
+ #ifndef EIGEN_GPU_COMPILE_PHASE
is_large = MaxSize == Dynamic ||
Size >= EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD ||
(Size==Dynamic && MaxSize>=EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD),
@@ -153,13 +163,13 @@
template<typename Scalar,int Size,int MaxSize>
struct gemv_static_vector_if<Scalar,Size,MaxSize,false>
{
- EIGEN_STRONG_INLINE Scalar* data() { eigen_internal_assert(false && "should never be called"); return 0; }
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Scalar* data() { eigen_internal_assert(false && "should never be called"); return 0; }
};
template<typename Scalar,int Size>
struct gemv_static_vector_if<Scalar,Size,Dynamic,true>
{
- EIGEN_STRONG_INLINE Scalar* data() { return 0; }
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Scalar* data() { return 0; }
};
template<typename Scalar,int Size,int MaxSize>
@@ -218,8 +228,7 @@
ActualLhsType actualLhs = LhsBlasTraits::extract(lhs);
ActualRhsType actualRhs = RhsBlasTraits::extract(rhs);
- ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(lhs)
- * RhsBlasTraits::extractScalarFactor(rhs);
+ ResScalar actualAlpha = combine_scalar_factors(alpha, lhs, rhs);
// make sure Dest is a compile-time vector type (bug 1166)
typedef typename conditional<Dest::IsVectorAtCompileTime, Dest, typename Dest::ColXpr>::type ActualDest;
@@ -229,7 +238,7 @@
// on, the other hand it is good for the cache to pack the vector anyways...
EvalToDestAtCompileTime = (ActualDest::InnerStrideAtCompileTime==1),
ComplexByReal = (NumTraits<LhsScalar>::IsComplex) && (!NumTraits<RhsScalar>::IsComplex),
- MightCannotUseDest = (!EvalToDestAtCompileTime) || ComplexByReal
+ MightCannotUseDest = ((!EvalToDestAtCompileTime) || ComplexByReal) && (ActualDest::MaxSizeAtCompileTime!=0)
};
typedef const_blas_data_mapper<LhsScalar,Index,ColMajor> LhsMapper;
@@ -310,13 +319,12 @@
typename add_const<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(lhs);
typename add_const<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(rhs);
- ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(lhs)
- * RhsBlasTraits::extractScalarFactor(rhs);
+ ResScalar actualAlpha = combine_scalar_factors(alpha, lhs, rhs);
enum {
// FIXME find a way to allow an inner stride on the result if packet_traits<Scalar>::size==1
// on, the other hand it is good for the cache to pack the vector anyways...
- DirectlyUseRhs = ActualRhsTypeCleaned::InnerStrideAtCompileTime==1
+ DirectlyUseRhs = ActualRhsTypeCleaned::InnerStrideAtCompileTime==1 || ActualRhsTypeCleaned::MaxSizeAtCompileTime==0
};
gemv_static_vector_if<RhsScalar,ActualRhsTypeCleaned::SizeAtCompileTime,ActualRhsTypeCleaned::MaxSizeAtCompileTime,!DirectlyUseRhs> static_rhs;
@@ -386,7 +394,8 @@
*/
template<typename Derived>
template<typename OtherDerived>
-inline const Product<Derived, OtherDerived>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+const Product<Derived, OtherDerived>
MatrixBase<Derived>::operator*(const MatrixBase<OtherDerived> &other) const
{
// A note regarding the function declaration: In MSVC, this function will sometimes
@@ -428,6 +437,7 @@
*/
template<typename Derived>
template<typename OtherDerived>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const Product<Derived,OtherDerived,LazyProduct>
MatrixBase<Derived>::lazyProduct(const MatrixBase<OtherDerived> &other) const
{
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/GenericPacketMath.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/GenericPacketMath.h
index 029f8ac..cf677a1 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/GenericPacketMath.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/GenericPacketMath.h
@@ -44,23 +44,29 @@
enum {
HasHalfPacket = 0,
- HasAdd = 1,
- HasSub = 1,
- HasMul = 1,
- HasNegate = 1,
- HasAbs = 1,
- HasArg = 0,
- HasAbs2 = 1,
- HasMin = 1,
- HasMax = 1,
- HasConj = 1,
+ HasAdd = 1,
+ HasSub = 1,
+ HasShift = 1,
+ HasMul = 1,
+ HasNegate = 1,
+ HasAbs = 1,
+ HasArg = 0,
+ HasAbs2 = 1,
+ HasAbsDiff = 0,
+ HasMin = 1,
+ HasMax = 1,
+ HasConj = 1,
HasSetLinear = 1,
- HasBlend = 0,
+ HasBlend = 0,
+ // This flag is used to indicate whether packet comparison is supported.
+ // pcmp_eq, pcmp_lt and pcmp_le should be defined for it to be true.
+ HasCmp = 0,
HasDiv = 0,
HasSqrt = 0,
HasRsqrt = 0,
HasExp = 0,
+ HasExpm1 = 0,
HasLog = 0,
HasLog1p = 0,
HasLog10 = 0,
@@ -81,14 +87,18 @@
HasPolygamma = 0,
HasErf = 0,
HasErfc = 0,
+ HasNdtri = 0,
+ HasBessel = 0,
HasIGamma = 0,
+ HasIGammaDerA = 0,
+ HasGammaSampleDerAlpha = 0,
HasIGammac = 0,
HasBetaInc = 0,
HasRound = 0,
+ HasRint = 0,
HasFloor = 0,
HasCeil = 0,
-
HasSign = 0
};
};
@@ -119,6 +129,22 @@
template<typename T> struct packet_traits<const T> : packet_traits<T> { };
+template<typename T> struct unpacket_traits
+{
+ typedef T type;
+ typedef T half;
+ enum
+ {
+ size = 1,
+ alignment = 1,
+ vectorizable = false,
+ masked_load_available=false,
+ masked_store_available=false
+ };
+};
+
+template<typename T> struct unpacket_traits<const T> : unpacket_traits<T> { };
+
template <typename Src, typename Tgt> struct type_casting_traits {
enum {
VectorizedCast = 0,
@@ -127,6 +153,34 @@
};
};
+/** \internal Wrapper to ensure that multiple packet types can map to the same
+ underlying vector type. */
+template<typename T, int unique_id = 0>
+struct eigen_packet_wrapper
+{
+ EIGEN_ALWAYS_INLINE operator T&() { return m_val; }
+ EIGEN_ALWAYS_INLINE operator const T&() const { return m_val; }
+ EIGEN_ALWAYS_INLINE eigen_packet_wrapper() {}
+ EIGEN_ALWAYS_INLINE eigen_packet_wrapper(const T &v) : m_val(v) {}
+ EIGEN_ALWAYS_INLINE eigen_packet_wrapper& operator=(const T &v) {
+ m_val = v;
+ return *this;
+ }
+
+ T m_val;
+};
+
+
+/** \internal A convenience utility for determining if the type is a scalar.
+ * This is used to enable some generic packet implementations.
+ */
+template<typename Packet>
+struct is_scalar {
+ typedef typename unpacket_traits<Packet>::type Scalar;
+ enum {
+ value = internal::is_same<Packet, Scalar>::value
+ };
+};
/** \internal \returns static_cast<TgtType>(a) (coeff-wise) */
template <typename SrcPacket, typename TgtPacket>
@@ -139,75 +193,406 @@
pcast(const SrcPacket& a, const SrcPacket& /*b*/) {
return static_cast<TgtPacket>(a);
}
-
template <typename SrcPacket, typename TgtPacket>
EIGEN_DEVICE_FUNC inline TgtPacket
pcast(const SrcPacket& a, const SrcPacket& /*b*/, const SrcPacket& /*c*/, const SrcPacket& /*d*/) {
return static_cast<TgtPacket>(a);
}
+template <typename SrcPacket, typename TgtPacket>
+EIGEN_DEVICE_FUNC inline TgtPacket
+pcast(const SrcPacket& a, const SrcPacket& /*b*/, const SrcPacket& /*c*/, const SrcPacket& /*d*/,
+ const SrcPacket& /*e*/, const SrcPacket& /*f*/, const SrcPacket& /*g*/, const SrcPacket& /*h*/) {
+ return static_cast<TgtPacket>(a);
+}
+
+/** \internal \returns reinterpret_cast<Target>(a) */
+template <typename Target, typename Packet>
+EIGEN_DEVICE_FUNC inline Target
+preinterpret(const Packet& a); /* { return reinterpret_cast<const Target&>(a); } */
/** \internal \returns a + b (coeff-wise) */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
-padd(const Packet& a,
- const Packet& b) { return a+b; }
+padd(const Packet& a, const Packet& b) { return a+b; }
+// Avoid compiler warning for boolean algebra.
+template<> EIGEN_DEVICE_FUNC inline bool
+padd(const bool& a, const bool& b) { return a || b; }
/** \internal \returns a - b (coeff-wise) */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
-psub(const Packet& a,
- const Packet& b) { return a-b; }
+psub(const Packet& a, const Packet& b) { return a-b; }
/** \internal \returns -a (coeff-wise) */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
pnegate(const Packet& a) { return -a; }
-/** \internal \returns conj(a) (coeff-wise) */
+template<> EIGEN_DEVICE_FUNC inline bool
+pnegate(const bool& a) { return !a; }
+/** \internal \returns conj(a) (coeff-wise) */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
pconj(const Packet& a) { return numext::conj(a); }
/** \internal \returns a * b (coeff-wise) */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
-pmul(const Packet& a,
- const Packet& b) { return a*b; }
+pmul(const Packet& a, const Packet& b) { return a*b; }
+// Avoid compiler warning for boolean algebra.
+template<> EIGEN_DEVICE_FUNC inline bool
+pmul(const bool& a, const bool& b) { return a && b; }
/** \internal \returns a / b (coeff-wise) */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
-pdiv(const Packet& a,
- const Packet& b) { return a/b; }
+pdiv(const Packet& a, const Packet& b) { return a/b; }
-/** \internal \returns the min of \a a and \a b (coeff-wise) */
-template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
-pmin(const Packet& a,
- const Packet& b) { return numext::mini(a, b); }
+// In the generic case, memset to all one bits.
+template<typename Packet, typename EnableIf = void>
+struct ptrue_impl {
+ static EIGEN_DEVICE_FUNC inline Packet run(const Packet& /*a*/){
+ Packet b;
+ memset(static_cast<void*>(&b), 0xff, sizeof(Packet));
+ return b;
+ }
+};
-/** \internal \returns the max of \a a and \a b (coeff-wise) */
+// For non-trivial scalars, set to Scalar(1) (i.e. a non-zero value).
+// Although this is technically not a valid bitmask, the scalar path for pselect
+// uses a comparison to zero, so this should still work in most cases. We don't
+// have another option, since the scalar type requires initialization.
+template<typename T>
+struct ptrue_impl<T,
+ typename internal::enable_if<is_scalar<T>::value && NumTraits<T>::RequireInitialization>::type > {
+ static EIGEN_DEVICE_FUNC inline T run(const T& /*a*/){
+ return T(1);
+ }
+};
+
+/** \internal \returns a packet with all bits set to one (Scalar(1) for non-trivial scalar types). */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
-pmax(const Packet& a,
- const Packet& b) { return numext::maxi(a, b); }
+ptrue(const Packet& a) {
+ return ptrue_impl<Packet>::run(a);
+}
+
+// In the general case, memset to zero.
+template<typename Packet, typename EnableIf = void>
+struct pzero_impl {
+ static EIGEN_DEVICE_FUNC inline Packet run(const Packet& /*a*/) {
+ Packet b;
+ memset(static_cast<void*>(&b), 0x00, sizeof(Packet));
+ return b;
+ }
+};
+
+// For scalars, explicitly set to Scalar(0), since the underlying representation
+// for zero may not consist of all-zero bits.
+template<typename T>
+struct pzero_impl<T,
+ typename internal::enable_if<is_scalar<T>::value>::type> {
+ static EIGEN_DEVICE_FUNC inline T run(const T& /*a*/) {
+ return T(0);
+ }
+};
+
+/** \internal \returns packet of zeros */
+template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
+pzero(const Packet& a) {
+ return pzero_impl<Packet>::run(a);
+}
+
+/** \internal \returns a <= b as a bit mask */
+template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
+pcmp_le(const Packet& a, const Packet& b) { return a<=b ? ptrue(a) : pzero(a); }
+
+/** \internal \returns a < b as a bit mask */
+template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
+pcmp_lt(const Packet& a, const Packet& b) { return a<b ? ptrue(a) : pzero(a); }
+
+/** \internal \returns a == b as a bit mask */
+template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
+pcmp_eq(const Packet& a, const Packet& b) { return a==b ? ptrue(a) : pzero(a); }
+
+/** \internal \returns a < b or a==NaN or b==NaN as a bit mask */
+template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
+pcmp_lt_or_nan(const Packet& a, const Packet& b) { return a>=b ? pzero(a) : ptrue(a); }
+
+template<typename T>
+struct bit_and {
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR EIGEN_ALWAYS_INLINE T operator()(const T& a, const T& b) const {
+ return a & b;
+ }
+};
+
+template<typename T>
+struct bit_or {
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR EIGEN_ALWAYS_INLINE T operator()(const T& a, const T& b) const {
+ return a | b;
+ }
+};
+
+template<typename T>
+struct bit_xor {
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR EIGEN_ALWAYS_INLINE T operator()(const T& a, const T& b) const {
+ return a ^ b;
+ }
+};
+
+template<typename T>
+struct bit_not {
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR EIGEN_ALWAYS_INLINE T operator()(const T& a) const {
+ return ~a;
+ }
+};
+
+// Use operators &, |, ^, ~.
+template<typename T>
+struct operator_bitwise_helper {
+ EIGEN_DEVICE_FUNC static inline T bitwise_and(const T& a, const T& b) { return bit_and<T>()(a, b); }
+ EIGEN_DEVICE_FUNC static inline T bitwise_or(const T& a, const T& b) { return bit_or<T>()(a, b); }
+ EIGEN_DEVICE_FUNC static inline T bitwise_xor(const T& a, const T& b) { return bit_xor<T>()(a, b); }
+ EIGEN_DEVICE_FUNC static inline T bitwise_not(const T& a) { return bit_not<T>()(a); }
+};
+
+// Apply binary operations byte-by-byte
+template<typename T>
+struct bytewise_bitwise_helper {
+ EIGEN_DEVICE_FUNC static inline T bitwise_and(const T& a, const T& b) {
+ return binary(a, b, bit_and<unsigned char>());
+ }
+ EIGEN_DEVICE_FUNC static inline T bitwise_or(const T& a, const T& b) {
+ return binary(a, b, bit_or<unsigned char>());
+ }
+ EIGEN_DEVICE_FUNC static inline T bitwise_xor(const T& a, const T& b) {
+ return binary(a, b, bit_xor<unsigned char>());
+ }
+ EIGEN_DEVICE_FUNC static inline T bitwise_not(const T& a) {
+ return unary(a,bit_not<unsigned char>());
+ }
+
+ private:
+ template<typename Op>
+ EIGEN_DEVICE_FUNC static inline T unary(const T& a, Op op) {
+ const unsigned char* a_ptr = reinterpret_cast<const unsigned char*>(&a);
+ T c;
+ unsigned char* c_ptr = reinterpret_cast<unsigned char*>(&c);
+ for (size_t i = 0; i < sizeof(T); ++i) {
+ *c_ptr++ = op(*a_ptr++);
+ }
+ return c;
+ }
+
+ template<typename Op>
+ EIGEN_DEVICE_FUNC static inline T binary(const T& a, const T& b, Op op) {
+ const unsigned char* a_ptr = reinterpret_cast<const unsigned char*>(&a);
+ const unsigned char* b_ptr = reinterpret_cast<const unsigned char*>(&b);
+ T c;
+ unsigned char* c_ptr = reinterpret_cast<unsigned char*>(&c);
+ for (size_t i = 0; i < sizeof(T); ++i) {
+ *c_ptr++ = op(*a_ptr++, *b_ptr++);
+ }
+ return c;
+ }
+};
+
+// In the general case, use byte-by-byte manipulation.
+template<typename T, typename EnableIf = void>
+struct bitwise_helper : public bytewise_bitwise_helper<T> {};
+
+// For integers or non-trivial scalars, use binary operators.
+template<typename T>
+struct bitwise_helper<T,
+ typename internal::enable_if<
+ is_scalar<T>::value && (NumTraits<T>::IsInteger || NumTraits<T>::RequireInitialization)>::type
+ > : public operator_bitwise_helper<T> {};
+
+/** \internal \returns the bitwise and of \a a and \a b */
+template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
+pand(const Packet& a, const Packet& b) {
+ return bitwise_helper<Packet>::bitwise_and(a, b);
+}
+
+/** \internal \returns the bitwise or of \a a and \a b */
+template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
+por(const Packet& a, const Packet& b) {
+ return bitwise_helper<Packet>::bitwise_or(a, b);
+}
+
+/** \internal \returns the bitwise xor of \a a and \a b */
+template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
+pxor(const Packet& a, const Packet& b) {
+ return bitwise_helper<Packet>::bitwise_xor(a, b);
+}
+
+/** \internal \returns the bitwise not of \a a */
+template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
+pnot(const Packet& a) {
+ return bitwise_helper<Packet>::bitwise_not(a);
+}
+
+/** \internal \returns the bitwise and of \a a and not \a b */
+template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
+pandnot(const Packet& a, const Packet& b) { return pand(a, pnot(b)); }
+
+// In the general case, use bitwise select.
+template<typename Packet, typename EnableIf = void>
+struct pselect_impl {
+ static EIGEN_DEVICE_FUNC inline Packet run(const Packet& mask, const Packet& a, const Packet& b) {
+ return por(pand(a,mask),pandnot(b,mask));
+ }
+};
+
+// For scalars, use ternary select.
+template<typename Packet>
+struct pselect_impl<Packet,
+ typename internal::enable_if<is_scalar<Packet>::value>::type > {
+ static EIGEN_DEVICE_FUNC inline Packet run(const Packet& mask, const Packet& a, const Packet& b) {
+ return numext::equal_strict(mask, Packet(0)) ? b : a;
+ }
+};
+
+/** \internal \returns \a a where \a mask is set and \a b elsewhere, for each field in the packet */
+template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
+pselect(const Packet& mask, const Packet& a, const Packet& b) {
+ return pselect_impl<Packet>::run(mask, a, b);
+}
+
+template<> EIGEN_DEVICE_FUNC inline bool pselect<bool>(
+ const bool& cond, const bool& a, const bool& b) {
+ return cond ? a : b;
+}
+
+/** \internal \returns the min or max of \a a and \a b (coeff-wise)
+ If either \a a or \a b is NaN, the result is implementation defined. */
+template<int NaNPropagation>
+struct pminmax_impl {
+ template <typename Packet, typename Op>
+ static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a, const Packet& b, Op op) {
+ return op(a,b);
+ }
+};
+
+/** \internal \returns the min or max of \a a and \a b (coeff-wise)
+ If either \a a or \a b are NaN, NaN is returned. */
+template<>
+struct pminmax_impl<PropagateNaN> {
+ template <typename Packet, typename Op>
+ static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a, const Packet& b, Op op) {
+ Packet not_nan_mask_a = pcmp_eq(a, a);
+ Packet not_nan_mask_b = pcmp_eq(b, b);
+ return pselect(not_nan_mask_a,
+ pselect(not_nan_mask_b, op(a, b), b),
+ a);
+ }
+};
+
+/** \internal \returns the min or max of \a a and \a b (coeff-wise)
+ If both \a a and \a b are NaN, NaN is returned.
+ Equivalent to std::fmin(a, b). */
+template<>
+struct pminmax_impl<PropagateNumbers> {
+ template <typename Packet, typename Op>
+ static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a, const Packet& b, Op op) {
+ Packet not_nan_mask_a = pcmp_eq(a, a);
+ Packet not_nan_mask_b = pcmp_eq(b, b);
+ return pselect(not_nan_mask_a,
+ pselect(not_nan_mask_b, op(a, b), a),
+ b);
+ }
+};
+
+
+#ifndef SYCL_DEVICE_ONLY
+#define EIGEN_BINARY_OP_NAN_PROPAGATION(Type, Func) Func
+#else
+#define EIGEN_BINARY_OP_NAN_PROPAGATION(Type, Func) \
+[](const Type& a, const Type& b) { \
+ return Func(a, b);}
+#endif
+
+/** \internal \returns the min of \a a and \a b (coeff-wise).
+ If \a a or \a b is NaN, the return value is implementation defined. */
+template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
+pmin(const Packet& a, const Packet& b) { return numext::mini(a,b); }
+
+/** \internal \returns the min of \a a and \a b (coeff-wise).
+ NaNPropagation determines the NaN propagation semantics. */
+template <int NaNPropagation, typename Packet>
+EIGEN_DEVICE_FUNC inline Packet pmin(const Packet& a, const Packet& b) {
+ return pminmax_impl<NaNPropagation>::run(a, b, EIGEN_BINARY_OP_NAN_PROPAGATION(Packet, (pmin<Packet>)));
+}
+
+/** \internal \returns the max of \a a and \a b (coeff-wise)
+ If \a a or \a b is NaN, the return value is implementation defined. */
+template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
+pmax(const Packet& a, const Packet& b) { return numext::maxi(a, b); }
+
+/** \internal \returns the max of \a a and \a b (coeff-wise).
+ NaNPropagation determines the NaN propagation semantics. */
+template <int NaNPropagation, typename Packet>
+EIGEN_DEVICE_FUNC inline Packet pmax(const Packet& a, const Packet& b) {
+ return pminmax_impl<NaNPropagation>::run(a, b, EIGEN_BINARY_OP_NAN_PROPAGATION(Packet,(pmax<Packet>)));
+}
/** \internal \returns the absolute value of \a a */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
-pabs(const Packet& a) { using std::abs; return abs(a); }
+pabs(const Packet& a) { return numext::abs(a); }
+template<> EIGEN_DEVICE_FUNC inline unsigned int
+pabs(const unsigned int& a) { return a; }
+template<> EIGEN_DEVICE_FUNC inline unsigned long
+pabs(const unsigned long& a) { return a; }
+template<> EIGEN_DEVICE_FUNC inline unsigned long long
+pabs(const unsigned long long& a) { return a; }
+
+/** \internal \returns \a a + \a b in even-indexed elements and \a a - \a b in odd-indexed ones */
+template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
+paddsub(const Packet& a, const Packet& b) {
+ return pselect(peven_mask(a), padd(a, b), psub(a, b));
+ }
/** \internal \returns the phase angle of \a a */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
parg(const Packet& a) { using numext::arg; return arg(a); }
-/** \internal \returns the bitwise and of \a a and \a b */
-template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
-pand(const Packet& a, const Packet& b) { return a & b; }
-/** \internal \returns the bitwise or of \a a and \a b */
-template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
-por(const Packet& a, const Packet& b) { return a | b; }
+/** \internal \returns \a a arithmetically shifted by N bits to the right */
+template<int N> EIGEN_DEVICE_FUNC inline int
+parithmetic_shift_right(const int& a) { return a >> N; }
+template<int N> EIGEN_DEVICE_FUNC inline long int
+parithmetic_shift_right(const long int& a) { return a >> N; }
-/** \internal \returns the bitwise xor of \a a and \a b */
-template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
-pxor(const Packet& a, const Packet& b) { return a ^ b; }
+/** \internal \returns \a a logically shifted by N bits to the right */
+template<int N> EIGEN_DEVICE_FUNC inline int
+plogical_shift_right(const int& a) { return static_cast<int>(static_cast<unsigned int>(a) >> N); }
+template<int N> EIGEN_DEVICE_FUNC inline long int
+plogical_shift_right(const long int& a) { return static_cast<long>(static_cast<unsigned long>(a) >> N); }
-/** \internal \returns the bitwise andnot of \a a and \a b */
+/** \internal \returns \a a shifted by N bits to the left */
+template<int N> EIGEN_DEVICE_FUNC inline int
+plogical_shift_left(const int& a) { return a << N; }
+template<int N> EIGEN_DEVICE_FUNC inline long int
+plogical_shift_left(const long int& a) { return a << N; }
+
+/** \internal \returns the significand and exponent of the underlying floating point numbers
+ * See https://en.cppreference.com/w/cpp/numeric/math/frexp
+ */
+template <typename Packet>
+EIGEN_DEVICE_FUNC inline Packet pfrexp(const Packet& a, Packet& exponent) {
+ int exp;
+ EIGEN_USING_STD(frexp);
+ Packet result = static_cast<Packet>(frexp(a, &exp));
+ exponent = static_cast<Packet>(exp);
+ return result;
+}
+
+/** \internal \returns a * 2^((int)exponent)
+ * See https://en.cppreference.com/w/cpp/numeric/math/ldexp
+ */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
-pandnot(const Packet& a, const Packet& b) { return a & (!b); }
+pldexp(const Packet &a, const Packet &exponent) {
+ EIGEN_USING_STD(ldexp)
+ return static_cast<Packet>(ldexp(a, static_cast<int>(exponent)));
+}
+
+/** \internal \returns the absolute difference |a - b| of \a a and \a b (coeff-wise) */
+template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
+pabsdiff(const Packet& a, const Packet& b) { return pselect(pcmp_lt(a, b), psub(b, a), psub(a, b)); }
/** \internal \returns a packet version of \a *from, from must be 16 bytes aligned */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
@@ -217,10 +602,22 @@
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
ploadu(const typename unpacket_traits<Packet>::type* from) { return *from; }
+/** \internal \returns a packet version of \a *from, (un-aligned masked load)
+ * There is no generic implementation. We only have implementations for specialized
+ * cases. Generic case should not be called.
+ */
+template<typename Packet> EIGEN_DEVICE_FUNC inline
+typename enable_if<unpacket_traits<Packet>::masked_load_available, Packet>::type
+ploadu(const typename unpacket_traits<Packet>::type* from, typename unpacket_traits<Packet>::mask_t umask);
+
/** \internal \returns a packet with constant coefficients \a a, e.g.: (a,a,a,a) */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
pset1(const typename unpacket_traits<Packet>::type& a) { return a; }
+/** \internal \returns a packet with constant coefficients set from bits */
+template<typename Packet,typename BitsType> EIGEN_DEVICE_FUNC inline Packet
+pset1frombits(BitsType a);
+
/** \internal \returns a packet with constant coefficients \a a[0], e.g.: (a[0],a[0],a[0],a[0]) */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
pload1(const typename unpacket_traits<Packet>::type *a) { return pset1<Packet>(*a); }
@@ -237,7 +634,7 @@
* For instance, for a packet of 8 elements, 2 scalars will be read from \a *from and
* replicated to form: {from[0],from[0],from[0],from[0],from[1],from[1],from[1],from[1]}
* Currently, this function is only used in matrix products.
- * For packet-size smaller or equal to 4, this function is equivalent to pload1
+ * For packet-size smaller or equal to 4, this function is equivalent to pload1
*/
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
ploadquad(const typename unpacket_traits<Packet>::type* from)
@@ -281,6 +678,20 @@
template<typename Packet> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet
plset(const typename unpacket_traits<Packet>::type& a) { return a; }
+/** \internal \returns a mask packet with all bits set in even-indexed elements and zero
+ in odd-indexed ones, e.g.: (x, 0, x, 0), where x is the value of all 1-bits. */
+template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
+peven_mask(const Packet& /*a*/) {
+ typedef typename unpacket_traits<Packet>::type Scalar;
+ const size_t n = unpacket_traits<Packet>::size;
+ EIGEN_ALIGN_TO_BOUNDARY(sizeof(Packet)) Scalar elements[n];
+ for(size_t i = 0; i < n; ++i) {
+ memset(elements+i, ((i & 1) == 0 ? 0xff : 0), sizeof(Scalar));
+ }
+ return ploadu<Packet>(elements);
+}
+
+
/** \internal copy the packet \a from to \a *to, \a to must be 16 bytes aligned */
template<typename Scalar, typename Packet> EIGEN_DEVICE_FUNC inline void pstore(Scalar* to, const Packet& from)
{ (*to) = from; }
@@ -289,6 +700,15 @@
template<typename Scalar, typename Packet> EIGEN_DEVICE_FUNC inline void pstoreu(Scalar* to, const Packet& from)
{ (*to) = from; }
+/** \internal copy the packet \a from to \a *to, (un-aligned store with a mask)
+ * There is no generic implementation. We only have implementations for specialized
+ * cases. Generic case should not be called.
+ */
+template<typename Scalar, typename Packet>
+EIGEN_DEVICE_FUNC inline
+typename enable_if<unpacket_traits<Packet>::masked_store_available, void>::type
+pstoreu(Scalar* to, const Packet& from, typename unpacket_traits<Packet>::mask_t umask);
+
template<typename Scalar, typename Packet> EIGEN_DEVICE_FUNC inline Packet pgather(const Scalar* from, Index /*stride*/)
{ return ploadu<Packet>(from); }
@@ -298,8 +718,10 @@
/** \internal tries to do cache prefetching of \a addr */
template<typename Scalar> EIGEN_DEVICE_FUNC inline void prefetch(const Scalar* addr)
{
-#ifdef __CUDA_ARCH__
-#if defined(__LP64__)
+#if defined(EIGEN_HIP_DEVICE_COMPILE)
+ // do nothing
+#elif defined(EIGEN_CUDA_ARCH)
+#if defined(__LP64__) || EIGEN_OS_WIN64
// 64-bit pointer operand constraint for inlined asm
asm(" prefetch.L1 [ %1 ];" : "=l"(addr) : "l"(addr));
#else
@@ -311,39 +733,6 @@
#endif
}
-/** \internal \returns the first element of a packet */
-template<typename Packet> EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type pfirst(const Packet& a)
-{ return a; }
-
-/** \internal \returns a packet where the element i contains the sum of the packet of \a vec[i] */
-template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
-preduxp(const Packet* vecs) { return vecs[0]; }
-
-/** \internal \returns the sum of the elements of \a a*/
-template<typename Packet> EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type predux(const Packet& a)
-{ return a; }
-
-/** \internal \returns the sum of the elements of \a a by block of 4 elements.
- * For a packet {a0, a1, a2, a3, a4, a5, a6, a7}, it returns a half packet {a0+a4, a1+a5, a2+a6, a3+a7}
- * For packet-size smaller or equal to 4, this boils down to a noop.
- */
-template<typename Packet> EIGEN_DEVICE_FUNC inline
-typename conditional<(unpacket_traits<Packet>::size%8)==0,typename unpacket_traits<Packet>::half,Packet>::type
-predux_downto4(const Packet& a)
-{ return a; }
-
-/** \internal \returns the product of the elements of \a a*/
-template<typename Packet> EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type predux_mul(const Packet& a)
-{ return a; }
-
-/** \internal \returns the min of the elements of \a a*/
-template<typename Packet> EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type predux_min(const Packet& a)
-{ return a; }
-
-/** \internal \returns the max of the elements of \a a*/
-template<typename Packet> EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type predux_max(const Packet& a)
-{ return a; }
-
/** \internal \returns the reversed elements of \a a*/
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet preverse(const Packet& a)
{ return a; }
@@ -351,10 +740,7 @@
/** \internal \returns \a a with real and imaginary part flipped (for complex type only) */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet pcplxflip(const Packet& a)
{
- // FIXME: uncomment the following in case we drop the internal imag and real functions.
-// using std::imag;
-// using std::real;
- return Packet(imag(a),real(a));
+ return Packet(numext::imag(a),numext::real(a));
}
/**************************
@@ -363,47 +749,51 @@
/** \internal \returns the sine of \a a (coeff-wise) */
template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
-Packet psin(const Packet& a) { using std::sin; return sin(a); }
+Packet psin(const Packet& a) { EIGEN_USING_STD(sin); return sin(a); }
/** \internal \returns the cosine of \a a (coeff-wise) */
template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
-Packet pcos(const Packet& a) { using std::cos; return cos(a); }
+Packet pcos(const Packet& a) { EIGEN_USING_STD(cos); return cos(a); }
/** \internal \returns the tan of \a a (coeff-wise) */
template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
-Packet ptan(const Packet& a) { using std::tan; return tan(a); }
+Packet ptan(const Packet& a) { EIGEN_USING_STD(tan); return tan(a); }
/** \internal \returns the arc sine of \a a (coeff-wise) */
template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
-Packet pasin(const Packet& a) { using std::asin; return asin(a); }
+Packet pasin(const Packet& a) { EIGEN_USING_STD(asin); return asin(a); }
/** \internal \returns the arc cosine of \a a (coeff-wise) */
template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
-Packet pacos(const Packet& a) { using std::acos; return acos(a); }
+Packet pacos(const Packet& a) { EIGEN_USING_STD(acos); return acos(a); }
/** \internal \returns the arc tangent of \a a (coeff-wise) */
template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
-Packet patan(const Packet& a) { using std::atan; return atan(a); }
+Packet patan(const Packet& a) { EIGEN_USING_STD(atan); return atan(a); }
/** \internal \returns the hyperbolic sine of \a a (coeff-wise) */
template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
-Packet psinh(const Packet& a) { using std::sinh; return sinh(a); }
+Packet psinh(const Packet& a) { EIGEN_USING_STD(sinh); return sinh(a); }
/** \internal \returns the hyperbolic cosine of \a a (coeff-wise) */
template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
-Packet pcosh(const Packet& a) { using std::cosh; return cosh(a); }
+Packet pcosh(const Packet& a) { EIGEN_USING_STD(cosh); return cosh(a); }
/** \internal \returns the hyperbolic tan of \a a (coeff-wise) */
template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
-Packet ptanh(const Packet& a) { using std::tanh; return tanh(a); }
+Packet ptanh(const Packet& a) { EIGEN_USING_STD(tanh); return tanh(a); }
/** \internal \returns the exp of \a a (coeff-wise) */
template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
-Packet pexp(const Packet& a) { using std::exp; return exp(a); }
+Packet pexp(const Packet& a) { EIGEN_USING_STD(exp); return exp(a); }
+
+/** \internal \returns the expm1 of \a a (coeff-wise) */
+template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+Packet pexpm1(const Packet& a) { return numext::expm1(a); }
/** \internal \returns the log of \a a (coeff-wise) */
template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
-Packet plog(const Packet& a) { using std::log; return log(a); }
+Packet plog(const Packet& a) { EIGEN_USING_STD(log); return log(a); }
/** \internal \returns the log1p of \a a (coeff-wise) */
template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
@@ -411,16 +801,24 @@
/** \internal \returns the log10 of \a a (coeff-wise) */
template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
-Packet plog10(const Packet& a) { using std::log10; return log10(a); }
+Packet plog10(const Packet& a) { EIGEN_USING_STD(log10); return log10(a); }
+
+/** \internal \returns the log2 of \a a (coeff-wise) */
+template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+Packet plog2(const Packet& a) {
+ typedef typename internal::unpacket_traits<Packet>::type Scalar;
+ return pmul(pset1<Packet>(Scalar(EIGEN_LOG2E)), plog(a));
+}
/** \internal \returns the square-root of \a a (coeff-wise) */
template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
-Packet psqrt(const Packet& a) { using std::sqrt; return sqrt(a); }
+Packet psqrt(const Packet& a) { return numext::sqrt(a); }
/** \internal \returns the reciprocal square-root of \a a (coeff-wise) */
template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
Packet prsqrt(const Packet& a) {
- return pdiv(pset1<Packet>(1), psqrt(a));
+ typedef typename internal::unpacket_traits<Packet>::type Scalar;
+ return pdiv(pset1<Packet>(Scalar(1)), psqrt(a));
}
/** \internal \returns the rounded value of \a a (coeff-wise) */
@@ -431,15 +829,121 @@
template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
Packet pfloor(const Packet& a) { using numext::floor; return floor(a); }
+/** \internal \returns the rounded value of \a a (coeff-wise) with current
+ * rounding mode */
+template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+Packet print(const Packet& a) { using numext::rint; return rint(a); }
+
/** \internal \returns the ceil of \a a (coeff-wise) */
template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
Packet pceil(const Packet& a) { using numext::ceil; return ceil(a); }
+/** \internal \returns the first element of a packet */
+template<typename Packet>
+EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type
+pfirst(const Packet& a)
+{ return a; }
+
+/** \internal \returns the sum of the upper and lower halves of \a a if the packet size of \a a is a multiple of 8.
+ * For a packet {a0, a1, a2, a3, a4, a5, a6, a7}, it returns a half packet {a0+a4, a1+a5, a2+a6, a3+a7}
+ * For packet-size smaller or equal to 4, this boils down to a noop.
+ */
+template<typename Packet>
+EIGEN_DEVICE_FUNC inline typename conditional<(unpacket_traits<Packet>::size%8)==0,typename unpacket_traits<Packet>::half,Packet>::type
+predux_half_dowto4(const Packet& a)
+{ return a; }
+
+// Slow generic implementation of Packet reduction.
+template <typename Packet, typename Op>
+EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type
+predux_helper(const Packet& a, Op op) {
+ typedef typename unpacket_traits<Packet>::type Scalar;
+ const size_t n = unpacket_traits<Packet>::size;
+ EIGEN_ALIGN_TO_BOUNDARY(sizeof(Packet)) Scalar elements[n];
+ pstoreu<Scalar>(elements, a);
+ for(size_t k = n / 2; k > 0; k /= 2) {
+ for(size_t i = 0; i < k; ++i) {
+ elements[i] = op(elements[i], elements[i + k]);
+ }
+ }
+ return elements[0];
+}
+
+/** \internal \returns the sum of the elements of \a a*/
+template<typename Packet>
+EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type
+predux(const Packet& a)
+{
+ return a;
+}
+
+/** \internal \returns the product of the elements of \a a */
+template <typename Packet>
+EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type predux_mul(
+ const Packet& a) {
+ typedef typename unpacket_traits<Packet>::type Scalar;
+ return predux_helper(a, EIGEN_BINARY_OP_NAN_PROPAGATION(Scalar, (pmul<Scalar>)));
+}
+
+/** \internal \returns the min of the elements of \a a */
+template <typename Packet>
+EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type predux_min(
+ const Packet &a) {
+ typedef typename unpacket_traits<Packet>::type Scalar;
+ return predux_helper(a, EIGEN_BINARY_OP_NAN_PROPAGATION(Scalar, (pmin<PropagateFast, Scalar>)));
+}
+
+template <int NaNPropagation, typename Packet>
+EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type predux_min(
+ const Packet& a) {
+ typedef typename unpacket_traits<Packet>::type Scalar;
+ return predux_helper(a, EIGEN_BINARY_OP_NAN_PROPAGATION(Scalar, (pmin<NaNPropagation, Scalar>)));
+}
+
+/** \internal \returns the max of the elements of \a a */
+template <typename Packet>
+EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type predux_max(
+ const Packet &a) {
+ typedef typename unpacket_traits<Packet>::type Scalar;
+ return predux_helper(a, EIGEN_BINARY_OP_NAN_PROPAGATION(Scalar, (pmax<PropagateFast, Scalar>)));
+}
+
+template <int NaNPropagation, typename Packet>
+EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type predux_max(
+ const Packet& a) {
+ typedef typename unpacket_traits<Packet>::type Scalar;
+ return predux_helper(a, EIGEN_BINARY_OP_NAN_PROPAGATION(Scalar, (pmax<NaNPropagation, Scalar>)));
+}
+
+#undef EIGEN_BINARY_OP_NAN_PROPAGATION
+
+/** \internal \returns true if all coeffs of \a a mean "true"
+ * It is supposed to be called on values returned by pcmp_*.
+ */
+// not needed yet
+// template<typename Packet> EIGEN_DEVICE_FUNC inline bool predux_all(const Packet& a)
+// { return bool(a); }
+
+/** \internal \returns true if any coeffs of \a a mean "true"
+ * It is supposed to be called on values returned by pcmp_*.
+ */
+template<typename Packet> EIGEN_DEVICE_FUNC inline bool predux_any(const Packet& a)
+{
+  // Dirty but generic implementation where "true" is assumed to be non 0 and all the same.
+ // It is expected that "true" is either:
+ // - Scalar(1)
+ // - bits full of ones (NaN for floats),
+ // - or first bit equals to 1 (1 for ints, smallest denormal for floats).
+ // For all these cases, taking the sum is just fine, and this boils down to a no-op for scalars.
+ typedef typename unpacket_traits<Packet>::type Scalar;
+ return numext::not_equal_strict(predux(a), Scalar(0));
+}
+
/***************************************************************************
* The following functions might not have to be overwritten for vectorized types
***************************************************************************/
-/** \internal copy a packet with constant coeficient \a a (e.g., [a,a,a,a]) to \a *to. \a to must be 16 bytes aligned */
+/** \internal copy a packet with constant coefficient \a a (e.g., [a,a,a,a]) to \a *to. \a to must be 16 bytes aligned */
// NOTE: this function must really be templated on the packet type (think about different packet types for the same scalar type)
template<typename Packet>
inline void pstore1(typename unpacket_traits<Packet>::type* to, const typename unpacket_traits<Packet>::type& a)
@@ -487,47 +991,18 @@
return ploadt<Packet, LoadMode>(from);
}
-/** \internal default implementation of palign() allowing partial specialization */
-template<int Offset,typename PacketType>
-struct palign_impl
-{
- // by default data are aligned, so there is nothing to be done :)
- static inline void run(PacketType&, const PacketType&) {}
-};
-
-/** \internal update \a first using the concatenation of the packet_size minus \a Offset last elements
- * of \a first and \a Offset first elements of \a second.
- *
- * This function is currently only used to optimize matrix-vector products on unligned matrices.
- * It takes 2 packets that represent a contiguous memory array, and returns a packet starting
- * at the position \a Offset. For instance, for packets of 4 elements, we have:
- * Input:
- * - first = {f0,f1,f2,f3}
- * - second = {s0,s1,s2,s3}
- * Output:
- * - if Offset==0 then {f0,f1,f2,f3}
- * - if Offset==1 then {f1,f2,f3,s0}
- * - if Offset==2 then {f2,f3,s0,s1}
- * - if Offset==3 then {f3,s0,s1,s3}
- */
-template<int Offset,typename PacketType>
-inline void palign(PacketType& first, const PacketType& second)
-{
- palign_impl<Offset,PacketType>::run(first,second);
-}
-
/***************************************************************************
* Fast complex products (GCC generates a function call which is very slow)
***************************************************************************/
// Eigen+CUDA does not support complexes.
-#ifndef __CUDACC__
+#if !defined(EIGEN_GPUCC)
template<> inline std::complex<float> pmul(const std::complex<float>& a, const std::complex<float>& b)
-{ return std::complex<float>(real(a)*real(b) - imag(a)*imag(b), imag(a)*real(b) + real(a)*imag(b)); }
+{ return std::complex<float>(a.real()*b.real() - a.imag()*b.imag(), a.imag()*b.real() + a.real()*b.imag()); }
template<> inline std::complex<double> pmul(const std::complex<double>& a, const std::complex<double>& b)
-{ return std::complex<double>(real(a)*real(b) - imag(a)*imag(b), imag(a)*real(b) + real(a)*imag(b)); }
+{ return std::complex<double>(a.real()*b.real() - a.imag()*b.imag(), a.imag()*b.real() + a.real()*b.imag()); }
#endif
@@ -558,34 +1033,6 @@
return ifPacket.select[0] ? thenPacket : elsePacket;
}
-/** \internal \returns \a a with the first coefficient replaced by the scalar b */
-template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
-pinsertfirst(const Packet& a, typename unpacket_traits<Packet>::type b)
-{
- // Default implementation based on pblend.
- // It must be specialized for higher performance.
- Selector<unpacket_traits<Packet>::size> mask;
- mask.select[0] = true;
- // This for loop should be optimized away by the compiler.
- for(Index i=1; i<unpacket_traits<Packet>::size; ++i)
- mask.select[i] = false;
- return pblend(mask, pset1<Packet>(b), a);
-}
-
-/** \internal \returns \a a with the last coefficient replaced by the scalar b */
-template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
-pinsertlast(const Packet& a, typename unpacket_traits<Packet>::type b)
-{
- // Default implementation based on pblend.
- // It must be specialized for higher performance.
- Selector<unpacket_traits<Packet>::size> mask;
- // This for loop should be optimized away by the compiler.
- for(Index i=0; i<unpacket_traits<Packet>::size-1; ++i)
- mask.select[i] = false;
- mask.select[unpacket_traits<Packet>::size-1] = true;
- return pblend(mask, pset1<Packet>(b), a);
-}
-
} // end namespace internal
} // end namespace Eigen
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/GlobalFunctions.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/GlobalFunctions.h
index 769dc25..629af94 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/GlobalFunctions.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/GlobalFunctions.h
@@ -66,21 +66,31 @@
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(sinh,scalar_sinh_op,hyperbolic sine,\sa ArrayBase::sinh)
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(cosh,scalar_cosh_op,hyperbolic cosine,\sa ArrayBase::cosh)
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(tanh,scalar_tanh_op,hyperbolic tangent,\sa ArrayBase::tanh)
+#if EIGEN_HAS_CXX11_MATH
+ EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(asinh,scalar_asinh_op,inverse hyperbolic sine,\sa ArrayBase::asinh)
+ EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(acosh,scalar_acosh_op,inverse hyperbolic cosine,\sa ArrayBase::acosh)
+ EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(atanh,scalar_atanh_op,inverse hyperbolic tangent,\sa ArrayBase::atanh)
+#endif
+ EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(logistic,scalar_logistic_op,logistic function,\sa ArrayBase::logistic)
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(lgamma,scalar_lgamma_op,natural logarithm of the gamma function,\sa ArrayBase::lgamma)
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(digamma,scalar_digamma_op,derivative of lgamma,\sa ArrayBase::digamma)
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(erf,scalar_erf_op,error function,\sa ArrayBase::erf)
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(erfc,scalar_erfc_op,complement error function,\sa ArrayBase::erfc)
+ EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(ndtri,scalar_ndtri_op,inverse normal distribution function,\sa ArrayBase::ndtri)
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(exp,scalar_exp_op,exponential,\sa ArrayBase::exp)
+ EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(expm1,scalar_expm1_op,exponential of a value minus 1,\sa ArrayBase::expm1)
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(log,scalar_log_op,natural logarithm,\sa Eigen::log10 DOXCOMMA ArrayBase::log)
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(log1p,scalar_log1p_op,natural logarithm of 1 plus the value,\sa ArrayBase::log1p)
- EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(log10,scalar_log10_op,base 10 logarithm,\sa Eigen::log DOXCOMMA ArrayBase::log)
+ EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(log10,scalar_log10_op,base 10 logarithm,\sa Eigen::log DOXCOMMA ArrayBase::log10)
+ EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(log2,scalar_log2_op,base 2 logarithm,\sa Eigen::log DOXCOMMA ArrayBase::log2)
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(abs,scalar_abs_op,absolute value,\sa ArrayBase::abs DOXCOMMA MatrixBase::cwiseAbs)
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(abs2,scalar_abs2_op,squared absolute value,\sa ArrayBase::abs2 DOXCOMMA MatrixBase::cwiseAbs2)
- EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(arg,scalar_arg_op,complex argument,\sa ArrayBase::arg)
+ EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(arg,scalar_arg_op,complex argument,\sa ArrayBase::arg DOXCOMMA MatrixBase::cwiseArg)
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(sqrt,scalar_sqrt_op,square root,\sa ArrayBase::sqrt DOXCOMMA MatrixBase::cwiseSqrt)
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(rsqrt,scalar_rsqrt_op,reciprocal square root,\sa ArrayBase::rsqrt)
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(square,scalar_square_op,square (power 2),\sa Eigen::abs2 DOXCOMMA Eigen::pow DOXCOMMA ArrayBase::square)
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(cube,scalar_cube_op,cube (power 3),\sa Eigen::pow DOXCOMMA ArrayBase::cube)
+ EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(rint,scalar_rint_op,nearest integer,\sa Eigen::floor DOXCOMMA Eigen::ceil DOXCOMMA ArrayBase::round)
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(round,scalar_round_op,nearest integer,\sa Eigen::floor DOXCOMMA Eigen::ceil DOXCOMMA ArrayBase::round)
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(floor,scalar_floor_op,nearest integer not greater than the giben value,\sa Eigen::ceil DOXCOMMA ArrayBase::floor)
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(ceil,scalar_ceil_op,nearest integer not less than the giben value,\sa Eigen::floor DOXCOMMA ArrayBase::ceil)
@@ -88,7 +98,7 @@
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(isinf,scalar_isinf_op,infinite value test,\sa Eigen::isnan DOXCOMMA Eigen::isfinite DOXCOMMA ArrayBase::isinf)
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(isfinite,scalar_isfinite_op,finite value test,\sa Eigen::isinf DOXCOMMA Eigen::isnan DOXCOMMA ArrayBase::isfinite)
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(sign,scalar_sign_op,sign (or 0),\sa ArrayBase::sign)
-
+
/** \returns an expression of the coefficient-wise power of \a x to the given constant \a exponent.
*
* \tparam ScalarExponent is the scalar type of \a exponent. It must be compatible with the scalar type of the given expression (\c Derived::Scalar).
@@ -102,17 +112,18 @@
inline const CwiseBinaryOp<internal::scalar_pow_op<Derived::Scalar,ScalarExponent>,Derived,Constant<ScalarExponent> >
pow(const Eigen::ArrayBase<Derived>& x, const ScalarExponent& exponent);
#else
- template<typename Derived,typename ScalarExponent>
- inline typename internal::enable_if< !(internal::is_same<typename Derived::Scalar,ScalarExponent>::value) && EIGEN_SCALAR_BINARY_SUPPORTED(pow,typename Derived::Scalar,ScalarExponent),
- const EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(Derived,ScalarExponent,pow) >::type
- pow(const Eigen::ArrayBase<Derived>& x, const ScalarExponent& exponent) {
- return x.derived().pow(exponent);
- }
-
- template<typename Derived>
- inline const EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(Derived,typename Derived::Scalar,pow)
- pow(const Eigen::ArrayBase<Derived>& x, const typename Derived::Scalar& exponent) {
- return x.derived().pow(exponent);
+ template <typename Derived,typename ScalarExponent>
+ EIGEN_DEVICE_FUNC inline
+ EIGEN_MSVC10_WORKAROUND_BINARYOP_RETURN_TYPE(
+ const EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(Derived,typename internal::promote_scalar_arg<typename Derived::Scalar
+ EIGEN_COMMA ScalarExponent EIGEN_COMMA
+ EIGEN_SCALAR_BINARY_SUPPORTED(pow,typename Derived::Scalar,ScalarExponent)>::type,pow))
+ pow(const Eigen::ArrayBase<Derived>& x, const ScalarExponent& exponent)
+ {
+ typedef typename internal::promote_scalar_arg<typename Derived::Scalar,ScalarExponent,
+ EIGEN_SCALAR_BINARY_SUPPORTED(pow,typename Derived::Scalar,ScalarExponent)>::type PromotedExponent;
+ return EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(Derived,PromotedExponent,pow)(x.derived(),
+ typename internal::plain_constant_type<Derived,PromotedExponent>::type(x.derived().rows(), x.derived().cols(), internal::scalar_constant_op<PromotedExponent>(exponent)));
}
#endif
@@ -122,21 +133,21 @@
*
* Example: \include Cwise_array_power_array.cpp
* Output: \verbinclude Cwise_array_power_array.out
- *
+ *
* \sa ArrayBase::pow()
*
* \relates ArrayBase
*/
template<typename Derived,typename ExponentDerived>
inline const Eigen::CwiseBinaryOp<Eigen::internal::scalar_pow_op<typename Derived::Scalar, typename ExponentDerived::Scalar>, const Derived, const ExponentDerived>
- pow(const Eigen::ArrayBase<Derived>& x, const Eigen::ArrayBase<ExponentDerived>& exponents)
+ pow(const Eigen::ArrayBase<Derived>& x, const Eigen::ArrayBase<ExponentDerived>& exponents)
{
return Eigen::CwiseBinaryOp<Eigen::internal::scalar_pow_op<typename Derived::Scalar, typename ExponentDerived::Scalar>, const Derived, const ExponentDerived>(
x.derived(),
exponents.derived()
);
}
-
+
/** \returns an expression of the coefficient-wise power of the scalar \a x to the given array of \a exponents.
*
* This function computes the coefficient-wise power between a scalar and an array of exponents.
@@ -145,7 +156,7 @@
*
* Example: \include Cwise_scalar_power_array.cpp
* Output: \verbinclude Cwise_scalar_power_array.out
- *
+ *
* \sa ArrayBase::pow()
*
* \relates ArrayBase
@@ -155,21 +166,17 @@
inline const CwiseBinaryOp<internal::scalar_pow_op<Scalar,Derived::Scalar>,Constant<Scalar>,Derived>
pow(const Scalar& x,const Eigen::ArrayBase<Derived>& x);
#else
- template<typename Scalar, typename Derived>
- inline typename internal::enable_if< !(internal::is_same<typename Derived::Scalar,Scalar>::value) && EIGEN_SCALAR_BINARY_SUPPORTED(pow,Scalar,typename Derived::Scalar),
- const EIGEN_SCALAR_BINARYOP_EXPR_RETURN_TYPE(Scalar,Derived,pow) >::type
- pow(const Scalar& x, const Eigen::ArrayBase<Derived>& exponents)
- {
- return EIGEN_SCALAR_BINARYOP_EXPR_RETURN_TYPE(Scalar,Derived,pow)(
- typename internal::plain_constant_type<Derived,Scalar>::type(exponents.rows(), exponents.cols(), x), exponents.derived() );
- }
-
- template<typename Derived>
- inline const EIGEN_SCALAR_BINARYOP_EXPR_RETURN_TYPE(typename Derived::Scalar,Derived,pow)
- pow(const typename Derived::Scalar& x, const Eigen::ArrayBase<Derived>& exponents)
- {
- return EIGEN_SCALAR_BINARYOP_EXPR_RETURN_TYPE(typename Derived::Scalar,Derived,pow)(
- typename internal::plain_constant_type<Derived,typename Derived::Scalar>::type(exponents.rows(), exponents.cols(), x), exponents.derived() );
+ template <typename Scalar, typename Derived>
+ EIGEN_DEVICE_FUNC inline
+ EIGEN_MSVC10_WORKAROUND_BINARYOP_RETURN_TYPE(
+ const EIGEN_SCALAR_BINARYOP_EXPR_RETURN_TYPE(typename internal::promote_scalar_arg<typename Derived::Scalar
+ EIGEN_COMMA Scalar EIGEN_COMMA
+ EIGEN_SCALAR_BINARY_SUPPORTED(pow,Scalar,typename Derived::Scalar)>::type,Derived,pow))
+ pow(const Scalar& x, const Eigen::ArrayBase<Derived>& exponents) {
+ typedef typename internal::promote_scalar_arg<typename Derived::Scalar,Scalar,
+ EIGEN_SCALAR_BINARY_SUPPORTED(pow,Scalar,typename Derived::Scalar)>::type PromotedScalar;
+ return EIGEN_SCALAR_BINARYOP_EXPR_RETURN_TYPE(PromotedScalar,Derived,pow)(
+ typename internal::plain_constant_type<Derived,PromotedScalar>::type(exponents.derived().rows(), exponents.derived().cols(), internal::scalar_constant_op<PromotedScalar>(x)), exponents.derived());
}
#endif
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/IO.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/IO.h
index da7fd6c..e81c315 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/IO.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/IO.h
@@ -41,6 +41,7 @@
* - \b rowSuffix string printed at the end of each row
* - \b matPrefix string printed at the beginning of the matrix
* - \b matSuffix string printed at the end of the matrix
+ * - \b fill character printed to fill the empty space in aligned columns
*
* Example: \include IOFormat.cpp
* Output: \verbinclude IOFormat.out
@@ -53,9 +54,9 @@
IOFormat(int _precision = StreamPrecision, int _flags = 0,
const std::string& _coeffSeparator = " ",
const std::string& _rowSeparator = "\n", const std::string& _rowPrefix="", const std::string& _rowSuffix="",
- const std::string& _matPrefix="", const std::string& _matSuffix="")
+ const std::string& _matPrefix="", const std::string& _matSuffix="", const char _fill=' ')
: matPrefix(_matPrefix), matSuffix(_matSuffix), rowPrefix(_rowPrefix), rowSuffix(_rowSuffix), rowSeparator(_rowSeparator),
- rowSpacer(""), coeffSeparator(_coeffSeparator), precision(_precision), flags(_flags)
+ rowSpacer(""), coeffSeparator(_coeffSeparator), fill(_fill), precision(_precision), flags(_flags)
{
// TODO check if rowPrefix, rowSuffix or rowSeparator contains a newline
// don't add rowSpacer if columns are not to be aligned
@@ -71,6 +72,7 @@
std::string matPrefix, matSuffix;
std::string rowPrefix, rowSuffix, rowSeparator, rowSpacer;
std::string coeffSeparator;
+ char fill;
int precision;
int flags;
};
@@ -128,6 +130,9 @@
template<typename Derived>
std::ostream & print_matrix(std::ostream & s, const Derived& _m, const IOFormat& fmt)
{
+ using internal::is_same;
+ using internal::conditional;
+
if(_m.size() == 0)
{
s << fmt.matPrefix << fmt.matSuffix;
@@ -136,6 +141,22 @@
typename Derived::Nested m = _m;
typedef typename Derived::Scalar Scalar;
+ typedef typename
+ conditional<
+ is_same<Scalar, char>::value ||
+ is_same<Scalar, unsigned char>::value ||
+ is_same<Scalar, numext::int8_t>::value ||
+ is_same<Scalar, numext::uint8_t>::value,
+ int,
+ typename conditional<
+ is_same<Scalar, std::complex<char> >::value ||
+ is_same<Scalar, std::complex<unsigned char> >::value ||
+ is_same<Scalar, std::complex<numext::int8_t> >::value ||
+ is_same<Scalar, std::complex<numext::uint8_t> >::value,
+ std::complex<int>,
+ const Scalar&
+ >::type
+ >::type PrintType;
Index width = 0;
@@ -172,23 +193,31 @@
{
std::stringstream sstr;
sstr.copyfmt(s);
- sstr << m.coeff(i,j);
+ sstr << static_cast<PrintType>(m.coeff(i,j));
width = std::max<Index>(width, Index(sstr.str().length()));
}
}
+ std::streamsize old_width = s.width();
+ char old_fill_character = s.fill();
s << fmt.matPrefix;
for(Index i = 0; i < m.rows(); ++i)
{
if (i)
s << fmt.rowSpacer;
s << fmt.rowPrefix;
- if(width) s.width(width);
- s << m.coeff(i, 0);
+ if(width) {
+ s.fill(fmt.fill);
+ s.width(width);
+ }
+ s << static_cast<PrintType>(m.coeff(i, 0));
for(Index j = 1; j < m.cols(); ++j)
{
s << fmt.coeffSeparator;
- if (width) s.width(width);
- s << m.coeff(i, j);
+ if(width) {
+ s.fill(fmt.fill);
+ s.width(width);
+ }
+ s << static_cast<PrintType>(m.coeff(i, j));
}
s << fmt.rowSuffix;
if( i < m.rows() - 1)
@@ -196,6 +225,10 @@
}
s << fmt.matSuffix;
if(explicit_precision) s.precision(old_precision);
+ if(width) {
+ s.fill(old_fill_character);
+ s.width(old_width);
+ }
return s;
}
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/IndexedView.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/IndexedView.h
new file mode 100644
index 0000000..0847625
--- /dev/null
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/IndexedView.h
@@ -0,0 +1,237 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2017 Gael Guennebaud <gael.guennebaud@inria.fr>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_INDEXED_VIEW_H
+#define EIGEN_INDEXED_VIEW_H
+
+namespace Eigen {
+
+namespace internal {
+
+template<typename XprType, typename RowIndices, typename ColIndices>
+struct traits<IndexedView<XprType, RowIndices, ColIndices> >
+ : traits<XprType>
+{
+ enum {
+ RowsAtCompileTime = int(array_size<RowIndices>::value),
+ ColsAtCompileTime = int(array_size<ColIndices>::value),
+ MaxRowsAtCompileTime = RowsAtCompileTime != Dynamic ? int(RowsAtCompileTime) : Dynamic,
+ MaxColsAtCompileTime = ColsAtCompileTime != Dynamic ? int(ColsAtCompileTime) : Dynamic,
+
+ XprTypeIsRowMajor = (int(traits<XprType>::Flags)&RowMajorBit) != 0,
+ IsRowMajor = (MaxRowsAtCompileTime==1&&MaxColsAtCompileTime!=1) ? 1
+ : (MaxColsAtCompileTime==1&&MaxRowsAtCompileTime!=1) ? 0
+ : XprTypeIsRowMajor,
+
+ RowIncr = int(get_compile_time_incr<RowIndices>::value),
+ ColIncr = int(get_compile_time_incr<ColIndices>::value),
+ InnerIncr = IsRowMajor ? ColIncr : RowIncr,
+ OuterIncr = IsRowMajor ? RowIncr : ColIncr,
+
+ HasSameStorageOrderAsXprType = (IsRowMajor == XprTypeIsRowMajor),
+ XprInnerStride = HasSameStorageOrderAsXprType ? int(inner_stride_at_compile_time<XprType>::ret) : int(outer_stride_at_compile_time<XprType>::ret),
+ XprOuterstride = HasSameStorageOrderAsXprType ? int(outer_stride_at_compile_time<XprType>::ret) : int(inner_stride_at_compile_time<XprType>::ret),
+
+ InnerSize = XprTypeIsRowMajor ? ColsAtCompileTime : RowsAtCompileTime,
+ IsBlockAlike = InnerIncr==1 && OuterIncr==1,
+ IsInnerPannel = HasSameStorageOrderAsXprType && is_same<AllRange<InnerSize>,typename conditional<XprTypeIsRowMajor,ColIndices,RowIndices>::type>::value,
+
+ InnerStrideAtCompileTime = InnerIncr<0 || InnerIncr==DynamicIndex || XprInnerStride==Dynamic ? Dynamic : XprInnerStride * InnerIncr,
+ OuterStrideAtCompileTime = OuterIncr<0 || OuterIncr==DynamicIndex || XprOuterstride==Dynamic ? Dynamic : XprOuterstride * OuterIncr,
+
+ ReturnAsScalar = is_same<RowIndices,SingleRange>::value && is_same<ColIndices,SingleRange>::value,
+ ReturnAsBlock = (!ReturnAsScalar) && IsBlockAlike,
+ ReturnAsIndexedView = (!ReturnAsScalar) && (!ReturnAsBlock),
+
+ // FIXME we deal with compile-time strides if and only if we have DirectAccessBit flag,
+ // but this is too strict regarding negative strides...
+ DirectAccessMask = (int(InnerIncr)!=UndefinedIncr && int(OuterIncr)!=UndefinedIncr && InnerIncr>=0 && OuterIncr>=0) ? DirectAccessBit : 0,
+ FlagsRowMajorBit = IsRowMajor ? RowMajorBit : 0,
+ FlagsLvalueBit = is_lvalue<XprType>::value ? LvalueBit : 0,
+ FlagsLinearAccessBit = (RowsAtCompileTime == 1 || ColsAtCompileTime == 1) ? LinearAccessBit : 0,
+ Flags = (traits<XprType>::Flags & (HereditaryBits | DirectAccessMask )) | FlagsLvalueBit | FlagsRowMajorBit | FlagsLinearAccessBit
+ };
+
+ typedef Block<XprType,RowsAtCompileTime,ColsAtCompileTime,IsInnerPannel> BlockType;
+};
+
+}
+
+template<typename XprType, typename RowIndices, typename ColIndices, typename StorageKind>
+class IndexedViewImpl;
+
+
+/** \class IndexedView
+ * \ingroup Core_Module
+ *
+ * \brief Expression of a non-sequential sub-matrix defined by arbitrary sequences of row and column indices
+ *
+ * \tparam XprType the type of the expression in which we are taking the intersections of sub-rows and sub-columns
+ * \tparam RowIndices the type of the object defining the sequence of row indices
+ * \tparam ColIndices the type of the object defining the sequence of column indices
+ *
+ * This class represents an expression of a sub-matrix (or sub-vector) defined as the intersection
+ * of sub-sets of rows and columns, that are themself defined by generic sequences of row indices \f$ \{r_0,r_1,..r_{m-1}\} \f$
+ * and column indices \f$ \{c_0,c_1,..c_{n-1} \}\f$. Let \f$ A \f$ be the nested matrix, then the resulting matrix \f$ B \f$ has \c m
+ * rows and \c n columns, and its entries are given by: \f$ B(i,j) = A(r_i,c_j) \f$.
+ *
+ * The \c RowIndices and \c ColIndices types must be compatible with the following API:
+ * \code
+ * <integral type> operator[](Index) const;
+ * Index size() const;
+ * \endcode
+ *
+ * Typical supported types thus include:
+ * - std::vector<int>
+ * - std::valarray<int>
+ * - std::array<int>
+ * - Plain C arrays: int[N]
+ * - Eigen::ArrayXi
+ * - decltype(ArrayXi::LinSpaced(...))
+ * - Any view/expressions of the previous types
+ * - Eigen::ArithmeticSequence
+ * - Eigen::internal::AllRange (helper for Eigen::all)
+ * - Eigen::internal::SingleRange (helper for single index)
+ * - etc.
+ *
+ * In typical usages of %Eigen, this class should never be used directly. It is the return type of
+ * DenseBase::operator()(const RowIndices&, const ColIndices&).
+ *
+ * \sa class Block
+ */
+template<typename XprType, typename RowIndices, typename ColIndices>
+class IndexedView : public IndexedViewImpl<XprType, RowIndices, ColIndices, typename internal::traits<XprType>::StorageKind>
+{
+public:
+ typedef typename IndexedViewImpl<XprType, RowIndices, ColIndices, typename internal::traits<XprType>::StorageKind>::Base Base;
+ EIGEN_GENERIC_PUBLIC_INTERFACE(IndexedView)
+ EIGEN_INHERIT_ASSIGNMENT_OPERATORS(IndexedView)
+
+ typedef typename internal::ref_selector<XprType>::non_const_type MatrixTypeNested;
+ typedef typename internal::remove_all<XprType>::type NestedExpression;
+
+ template<typename T0, typename T1>
+ IndexedView(XprType& xpr, const T0& rowIndices, const T1& colIndices)
+ : m_xpr(xpr), m_rowIndices(rowIndices), m_colIndices(colIndices)
+ {}
+
+ /** \returns number of rows */
+ Index rows() const { return internal::size(m_rowIndices); }
+
+ /** \returns number of columns */
+ Index cols() const { return internal::size(m_colIndices); }
+
+ /** \returns the nested expression */
+ const typename internal::remove_all<XprType>::type&
+ nestedExpression() const { return m_xpr; }
+
+ /** \returns the nested expression */
+ typename internal::remove_reference<XprType>::type&
+ nestedExpression() { return m_xpr; }
+
+ /** \returns a const reference to the object storing/generating the row indices */
+ const RowIndices& rowIndices() const { return m_rowIndices; }
+
+ /** \returns a const reference to the object storing/generating the column indices */
+ const ColIndices& colIndices() const { return m_colIndices; }
+
+protected:
+ MatrixTypeNested m_xpr;
+ RowIndices m_rowIndices;
+ ColIndices m_colIndices;
+};
+
+
+// Generic API dispatcher
+template<typename XprType, typename RowIndices, typename ColIndices, typename StorageKind>
+class IndexedViewImpl
+ : public internal::generic_xpr_base<IndexedView<XprType, RowIndices, ColIndices> >::type
+{
+public:
+ typedef typename internal::generic_xpr_base<IndexedView<XprType, RowIndices, ColIndices> >::type Base;
+};
+
+namespace internal {
+
+
+template<typename ArgType, typename RowIndices, typename ColIndices>
+struct unary_evaluator<IndexedView<ArgType, RowIndices, ColIndices>, IndexBased>
+ : evaluator_base<IndexedView<ArgType, RowIndices, ColIndices> >
+{
+ typedef IndexedView<ArgType, RowIndices, ColIndices> XprType;
+
+ enum {
+ CoeffReadCost = evaluator<ArgType>::CoeffReadCost /* TODO + cost of row/col index */,
+
+ FlagsLinearAccessBit = (traits<XprType>::RowsAtCompileTime == 1 || traits<XprType>::ColsAtCompileTime == 1) ? LinearAccessBit : 0,
+
+ FlagsRowMajorBit = traits<XprType>::FlagsRowMajorBit,
+
+ Flags = (evaluator<ArgType>::Flags & (HereditaryBits & ~RowMajorBit /*| LinearAccessBit | DirectAccessBit*/)) | FlagsLinearAccessBit | FlagsRowMajorBit,
+
+ Alignment = 0
+ };
+
+ EIGEN_DEVICE_FUNC explicit unary_evaluator(const XprType& xpr) : m_argImpl(xpr.nestedExpression()), m_xpr(xpr)
+ {
+ EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
+ }
+
+ typedef typename XprType::Scalar Scalar;
+ typedef typename XprType::CoeffReturnType CoeffReturnType;
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ CoeffReturnType coeff(Index row, Index col) const
+ {
+ return m_argImpl.coeff(m_xpr.rowIndices()[row], m_xpr.colIndices()[col]);
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ Scalar& coeffRef(Index row, Index col)
+ {
+ return m_argImpl.coeffRef(m_xpr.rowIndices()[row], m_xpr.colIndices()[col]);
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ Scalar& coeffRef(Index index)
+ {
+ EIGEN_STATIC_ASSERT_LVALUE(XprType)
+ Index row = XprType::RowsAtCompileTime == 1 ? 0 : index;
+ Index col = XprType::RowsAtCompileTime == 1 ? index : 0;
+ return m_argImpl.coeffRef( m_xpr.rowIndices()[row], m_xpr.colIndices()[col]);
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ const Scalar& coeffRef(Index index) const
+ {
+ Index row = XprType::RowsAtCompileTime == 1 ? 0 : index;
+ Index col = XprType::RowsAtCompileTime == 1 ? index : 0;
+ return m_argImpl.coeffRef( m_xpr.rowIndices()[row], m_xpr.colIndices()[col]);
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ const CoeffReturnType coeff(Index index) const
+ {
+ Index row = XprType::RowsAtCompileTime == 1 ? 0 : index;
+ Index col = XprType::RowsAtCompileTime == 1 ? index : 0;
+ return m_argImpl.coeff( m_xpr.rowIndices()[row], m_xpr.colIndices()[col]);
+ }
+
+protected:
+
+ evaluator<ArgType> m_argImpl;
+ const XprType& m_xpr;
+
+};
+
+} // end namespace internal
+
+} // end namespace Eigen
+
+#endif // EIGEN_INDEXED_VIEW_H
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Inverse.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Inverse.h
index b76f043..c514438 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Inverse.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Inverse.h
@@ -1,7 +1,7 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
-// Copyright (C) 2014 Gael Guennebaud <gael.guennebaud@inria.fr>
+// Copyright (C) 2014-2019 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
@@ -10,7 +10,7 @@
#ifndef EIGEN_INVERSE_H
#define EIGEN_INVERSE_H
-namespace Eigen {
+namespace Eigen {
template<typename XprType,typename StorageKind> class InverseImpl;
@@ -44,19 +44,18 @@
{
public:
typedef typename XprType::StorageIndex StorageIndex;
- typedef typename XprType::PlainObject PlainObject;
typedef typename XprType::Scalar Scalar;
typedef typename internal::ref_selector<XprType>::type XprTypeNested;
typedef typename internal::remove_all<XprTypeNested>::type XprTypeNestedCleaned;
typedef typename internal::ref_selector<Inverse>::type Nested;
typedef typename internal::remove_all<XprType>::type NestedExpression;
-
+
explicit EIGEN_DEVICE_FUNC Inverse(const XprType &xpr)
: m_xpr(xpr)
{}
- EIGEN_DEVICE_FUNC Index rows() const { return m_xpr.rows(); }
- EIGEN_DEVICE_FUNC Index cols() const { return m_xpr.cols(); }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR Index rows() const EIGEN_NOEXCEPT { return m_xpr.cols(); }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR Index cols() const EIGEN_NOEXCEPT { return m_xpr.rows(); }
EIGEN_DEVICE_FUNC const XprTypeNestedCleaned& nestedExpression() const { return m_xpr; }
@@ -82,7 +81,7 @@
/** \internal
* \brief Default evaluator for Inverse expression.
- *
+ *
* This default evaluator for Inverse expression simply evaluate the inverse into a temporary
* by a call to internal::call_assignment_no_alias.
* Therefore, inverse implementers only have to specialize Assignment<Dst,Inverse<...>, ...> for
@@ -97,7 +96,7 @@
typedef Inverse<ArgType> InverseType;
typedef typename InverseType::PlainObject PlainObject;
typedef evaluator<PlainObject> Base;
-
+
enum { Flags = Base::Flags | EvalBeforeNestingBit };
unary_evaluator(const InverseType& inv_xpr)
@@ -106,11 +105,11 @@
::new (static_cast<Base*>(this)) Base(m_result);
internal::call_assignment_no_alias(m_result, inv_xpr);
}
-
+
protected:
PlainObject m_result;
};
-
+
} // end namespace internal
} // end namespace Eigen
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Map.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Map.h
index 548bf9a..218cc15 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Map.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Map.h
@@ -11,7 +11,7 @@
#ifndef EIGEN_MAP_H
#define EIGEN_MAP_H
-namespace Eigen {
+namespace Eigen {
namespace internal {
template<typename PlainObjectType, int MapOptions, typename StrideType>
@@ -47,7 +47,7 @@
* \brief A matrix or vector expression mapping an existing array of data.
*
* \tparam PlainObjectType the equivalent matrix type of the mapped data
- * \tparam MapOptions specifies the pointer alignment in bytes. It can be: \c #Aligned128, , \c #Aligned64, \c #Aligned32, \c #Aligned16, \c #Aligned8 or \c #Unaligned.
+ * \tparam MapOptions specifies the pointer alignment in bytes. It can be: \c #Aligned128, \c #Aligned64, \c #Aligned32, \c #Aligned16, \c #Aligned8 or \c #Unaligned.
* The default is \c #Unaligned.
* \tparam StrideType optionally specifies strides. By default, Map assumes the memory layout
* of an ordinary, contiguous array. This can be overridden by specifying strides.
@@ -104,19 +104,19 @@
EIGEN_DEVICE_FUNC
inline PointerType cast_to_pointer_type(PointerArgType ptr) { return ptr; }
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
inline Index innerStride() const
{
return StrideType::InnerStrideAtCompileTime != 0 ? m_stride.inner() : 1;
}
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
inline Index outerStride() const
{
- return int(StrideType::OuterStrideAtCompileTime) != 0 ? m_stride.outer()
- : int(internal::traits<Map>::OuterStrideAtCompileTime) != Dynamic ? Index(internal::traits<Map>::OuterStrideAtCompileTime)
+ return StrideType::OuterStrideAtCompileTime != 0 ? m_stride.outer()
+ : internal::traits<Map>::OuterStrideAtCompileTime != Dynamic ? Index(internal::traits<Map>::OuterStrideAtCompileTime)
: IsVectorAtCompileTime ? (this->size() * innerStride())
- : (int(Flags)&RowMajorBit) ? (this->cols() * innerStride())
+ : int(Flags)&RowMajorBit ? (this->cols() * innerStride())
: (this->rows() * innerStride());
}
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/MapBase.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/MapBase.h
index 92c3b28..d856447 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/MapBase.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/MapBase.h
@@ -15,7 +15,7 @@
EIGEN_STATIC_ASSERT((int(internal::evaluator<Derived>::Flags) & LinearAccessBit) || Derived::IsVectorAtCompileTime, \
YOU_ARE_TRYING_TO_USE_AN_INDEX_BASED_ACCESSOR_ON_AN_EXPRESSION_THAT_DOES_NOT_SUPPORT_THAT)
-namespace Eigen {
+namespace Eigen {
/** \ingroup Core_Module
*
@@ -87,9 +87,11 @@
typedef typename Base::CoeffReturnType CoeffReturnType;
/** \copydoc DenseBase::rows() */
- EIGEN_DEVICE_FUNC inline Index rows() const { return m_rows.value(); }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ inline Index rows() const EIGEN_NOEXCEPT { return m_rows.value(); }
/** \copydoc DenseBase::cols() */
- EIGEN_DEVICE_FUNC inline Index cols() const { return m_cols.value(); }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ inline Index cols() const EIGEN_NOEXCEPT { return m_cols.value(); }
/** Returns a pointer to the first coefficient of the matrix or vector.
*
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/MathFunctions.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/MathFunctions.h
index b249ce0..61b78f4 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/MathFunctions.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/MathFunctions.h
@@ -2,6 +2,7 @@
// for linear algebra.
//
// Copyright (C) 2006-2010 Benoit Jacob <jacob.benoit.1@gmail.com>
+// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
@@ -10,10 +11,11 @@
#ifndef EIGEN_MATHFUNCTIONS_H
#define EIGEN_MATHFUNCTIONS_H
-// source: http://www.geom.uiuc.edu/~huberty/math5337/groupe/digits.html
// TODO this should better be moved to NumTraits
-#define EIGEN_PI 3.141592653589793238462643383279502884197169399375105820974944592307816406L
-
+// Source: WolframAlpha
+#define EIGEN_PI 3.141592653589793238462643383279502884197169399375105820974944592307816406L
+#define EIGEN_LOG2E 1.442695040888963407359924681001892137426645954152985934135449406931109219L
+#define EIGEN_LN2 0.693147180559945309417232121458176568075500134360255254120680009493393621L
namespace Eigen {
@@ -97,7 +99,7 @@
template<typename Scalar> struct real_impl : real_default_impl<Scalar> {};
-#ifdef __CUDA_ARCH__
+#if defined(EIGEN_GPU_COMPILE_PHASE)
template<typename T>
struct real_impl<std::complex<T> >
{
@@ -145,7 +147,7 @@
template<typename Scalar> struct imag_impl : imag_default_impl<Scalar> {};
-#ifdef __CUDA_ARCH__
+#if defined(EIGEN_GPU_COMPILE_PHASE)
template<typename T>
struct imag_impl<std::complex<T> >
{
@@ -213,12 +215,12 @@
template<typename Scalar>
struct imag_ref_default_impl<Scalar, false>
{
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static inline Scalar run(Scalar&)
{
return Scalar(0);
}
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static inline const Scalar run(const Scalar&)
{
return Scalar(0);
@@ -239,7 +241,7 @@
****************************************************************************/
template<typename Scalar, bool IsComplex = NumTraits<Scalar>::IsComplex>
-struct conj_impl
+struct conj_default_impl
{
EIGEN_DEVICE_FUNC
static inline Scalar run(const Scalar& x)
@@ -249,7 +251,7 @@
};
template<typename Scalar>
-struct conj_impl<Scalar,true>
+struct conj_default_impl<Scalar,true>
{
EIGEN_DEVICE_FUNC
static inline Scalar run(const Scalar& x)
@@ -259,6 +261,9 @@
}
};
+template<typename Scalar, bool IsComplex = NumTraits<Scalar>::IsComplex>
+struct conj_impl : conj_default_impl<Scalar, IsComplex> {};
+
template<typename Scalar>
struct conj_retval
{
@@ -287,7 +292,7 @@
EIGEN_DEVICE_FUNC
static inline RealScalar run(const Scalar& x)
{
- return real(x)*real(x) + imag(x)*imag(x);
+ return x.real()*x.real() + x.imag()*x.imag();
}
};
@@ -309,18 +314,80 @@
};
/****************************************************************************
+* Implementation of sqrt/rsqrt *
+****************************************************************************/
+
+template<typename Scalar>
+struct sqrt_impl
+{
+ EIGEN_DEVICE_FUNC
+ static EIGEN_ALWAYS_INLINE Scalar run(const Scalar& x)
+ {
+ EIGEN_USING_STD(sqrt);
+ return sqrt(x);
+ }
+};
+
+// Complex sqrt defined in MathFunctionsImpl.h.
+template<typename T> EIGEN_DEVICE_FUNC std::complex<T> complex_sqrt(const std::complex<T>& a_x);
+
+// Custom implementation is faster than `std::sqrt`, works on
+// GPU, and correctly handles special cases (unlike MSVC).
+template<typename T>
+struct sqrt_impl<std::complex<T> >
+{
+ EIGEN_DEVICE_FUNC
+ static EIGEN_ALWAYS_INLINE std::complex<T> run(const std::complex<T>& x)
+ {
+ return complex_sqrt<T>(x);
+ }
+};
+
+template<typename Scalar>
+struct sqrt_retval
+{
+ typedef Scalar type;
+};
+
+// Default implementation relies on numext::sqrt, at bottom of file.
+template<typename T>
+struct rsqrt_impl;
+
+// Complex rsqrt defined in MathFunctionsImpl.h.
+template<typename T> EIGEN_DEVICE_FUNC std::complex<T> complex_rsqrt(const std::complex<T>& a_x);
+
+template<typename T>
+struct rsqrt_impl<std::complex<T> >
+{
+ EIGEN_DEVICE_FUNC
+ static EIGEN_ALWAYS_INLINE std::complex<T> run(const std::complex<T>& x)
+ {
+ return complex_rsqrt<T>(x);
+ }
+};
+
+template<typename Scalar>
+struct rsqrt_retval
+{
+ typedef Scalar type;
+};
+
+/****************************************************************************
* Implementation of norm1 *
****************************************************************************/
template<typename Scalar, bool IsComplex>
-struct norm1_default_impl
+struct norm1_default_impl;
+
+template<typename Scalar>
+struct norm1_default_impl<Scalar,true>
{
typedef typename NumTraits<Scalar>::Real RealScalar;
EIGEN_DEVICE_FUNC
static inline RealScalar run(const Scalar& x)
{
- EIGEN_USING_STD_MATH(abs);
- return abs(real(x)) + abs(imag(x));
+ EIGEN_USING_STD(abs);
+ return abs(x.real()) + abs(x.imag());
}
};
@@ -330,7 +397,7 @@
EIGEN_DEVICE_FUNC
static inline Scalar run(const Scalar& x)
{
- EIGEN_USING_STD_MATH(abs);
+ EIGEN_USING_STD(abs);
return abs(x);
}
};
@@ -360,7 +427,7 @@
* Implementation of cast *
****************************************************************************/
-template<typename OldType, typename NewType>
+template<typename OldType, typename NewType, typename EnableIf = void>
struct cast_impl
{
EIGEN_DEVICE_FUNC
@@ -370,6 +437,22 @@
}
};
+// Casting from S -> Complex<T> leads to an implicit conversion from S to T,
+// generating warnings on clang. Here we explicitly cast the real component.
+template<typename OldType, typename NewType>
+struct cast_impl<OldType, NewType,
+ typename internal::enable_if<
+ !NumTraits<OldType>::IsComplex && NumTraits<NewType>::IsComplex
+ >::type>
+{
+ EIGEN_DEVICE_FUNC
+ static inline NewType run(const OldType& x)
+ {
+ typedef typename NumTraits<NewType>::Real NewReal;
+ return static_cast<NewType>(static_cast<NewReal>(x));
+ }
+};
+
// here, for once, we're plainly returning NewType: we don't want cast to do weird things.
template<typename OldType, typename NewType>
@@ -383,29 +466,59 @@
* Implementation of round *
****************************************************************************/
-#if EIGEN_HAS_CXX11_MATH
- template<typename Scalar>
- struct round_impl {
- static inline Scalar run(const Scalar& x)
- {
- EIGEN_STATIC_ASSERT((!NumTraits<Scalar>::IsComplex), NUMERIC_TYPE_MUST_BE_REAL)
- using std::round;
- return round(x);
- }
- };
-#else
- template<typename Scalar>
- struct round_impl
+template<typename Scalar>
+struct round_impl
+{
+ EIGEN_DEVICE_FUNC
+ static inline Scalar run(const Scalar& x)
{
- static inline Scalar run(const Scalar& x)
- {
- EIGEN_STATIC_ASSERT((!NumTraits<Scalar>::IsComplex), NUMERIC_TYPE_MUST_BE_REAL)
- EIGEN_USING_STD_MATH(floor);
- EIGEN_USING_STD_MATH(ceil);
- return (x > Scalar(0)) ? floor(x + Scalar(0.5)) : ceil(x - Scalar(0.5));
- }
- };
+ EIGEN_STATIC_ASSERT((!NumTraits<Scalar>::IsComplex), NUMERIC_TYPE_MUST_BE_REAL)
+#if EIGEN_HAS_CXX11_MATH
+ EIGEN_USING_STD(round);
#endif
+ return Scalar(round(x));
+ }
+};
+
+#if !EIGEN_HAS_CXX11_MATH
+#if EIGEN_HAS_C99_MATH
+// Use ::roundf for float.
+template<>
+struct round_impl<float> {
+ EIGEN_DEVICE_FUNC
+ static inline float run(const float& x)
+ {
+ return ::roundf(x);
+ }
+};
+#else
+template<typename Scalar>
+struct round_using_floor_ceil_impl
+{
+ EIGEN_DEVICE_FUNC
+ static inline Scalar run(const Scalar& x)
+ {
+ EIGEN_STATIC_ASSERT((!NumTraits<Scalar>::IsComplex), NUMERIC_TYPE_MUST_BE_REAL)
+ // Without C99 round/roundf, resort to floor/ceil.
+ EIGEN_USING_STD(floor);
+ EIGEN_USING_STD(ceil);
+ // If not enough precision to resolve a decimal at all, return the input.
+ // Otherwise, adding 0.5 can trigger an increment by 1.
+ const Scalar limit = Scalar(1ull << (NumTraits<Scalar>::digits() - 1));
+ if (x >= limit || x <= -limit) {
+ return x;
+ }
+ return (x > Scalar(0)) ? Scalar(floor(x + Scalar(0.5))) : Scalar(ceil(x - Scalar(0.5)));
+ }
+};
+
+template<>
+struct round_impl<float> : round_using_floor_ceil_impl<float> {};
+
+template<>
+struct round_impl<double> : round_using_floor_ceil_impl<double> {};
+#endif // EIGEN_HAS_C99_MATH
+#endif // !EIGEN_HAS_CXX11_MATH
template<typename Scalar>
struct round_retval
@@ -414,43 +527,112 @@
};
/****************************************************************************
+* Implementation of rint *
+****************************************************************************/
+
+template<typename Scalar>
+struct rint_impl {
+ EIGEN_DEVICE_FUNC
+ static inline Scalar run(const Scalar& x)
+ {
+ EIGEN_STATIC_ASSERT((!NumTraits<Scalar>::IsComplex), NUMERIC_TYPE_MUST_BE_REAL)
+#if EIGEN_HAS_CXX11_MATH
+ EIGEN_USING_STD(rint);
+#endif
+ return rint(x);
+ }
+};
+
+#if !EIGEN_HAS_CXX11_MATH
+template<>
+struct rint_impl<double> {
+ EIGEN_DEVICE_FUNC
+ static inline double run(const double& x)
+ {
+ return ::rint(x);
+ }
+};
+template<>
+struct rint_impl<float> {
+ EIGEN_DEVICE_FUNC
+ static inline float run(const float& x)
+ {
+ return ::rintf(x);
+ }
+};
+#endif
+
+template<typename Scalar>
+struct rint_retval
+{
+ typedef Scalar type;
+};
+
+/****************************************************************************
* Implementation of arg *
****************************************************************************/
-#if EIGEN_HAS_CXX11_MATH
- template<typename Scalar>
- struct arg_impl {
- static inline Scalar run(const Scalar& x)
- {
- EIGEN_USING_STD_MATH(arg);
- return arg(x);
- }
- };
+// Visual Studio 2017 has a bug where arg(float) returns 0 for negative inputs.
+// This seems to be fixed in VS 2019.
+#if EIGEN_HAS_CXX11_MATH && (!EIGEN_COMP_MSVC || EIGEN_COMP_MSVC >= 1920)
+// std::arg is only defined for types of std::complex, or integer types or float/double/long double
+template<typename Scalar,
+ bool HasStdImpl = NumTraits<Scalar>::IsComplex || is_integral<Scalar>::value
+ || is_same<Scalar, float>::value || is_same<Scalar, double>::value
+ || is_same<Scalar, long double>::value >
+struct arg_default_impl;
+
+template<typename Scalar>
+struct arg_default_impl<Scalar, true> {
+ typedef typename NumTraits<Scalar>::Real RealScalar;
+ EIGEN_DEVICE_FUNC
+ static inline RealScalar run(const Scalar& x)
+ {
+ #if defined(EIGEN_HIP_DEVICE_COMPILE)
+ // HIP does not seem to have a native device side implementation for the math routine "arg"
+ using std::arg;
+ #else
+ EIGEN_USING_STD(arg);
+ #endif
+ return static_cast<RealScalar>(arg(x));
+ }
+};
+
+// Must be non-complex floating-point type (e.g. half/bfloat16).
+template<typename Scalar>
+struct arg_default_impl<Scalar, false> {
+ typedef typename NumTraits<Scalar>::Real RealScalar;
+ EIGEN_DEVICE_FUNC
+ static inline RealScalar run(const Scalar& x)
+ {
+ return (x < Scalar(0)) ? RealScalar(EIGEN_PI) : RealScalar(0);
+ }
+};
#else
- template<typename Scalar, bool IsComplex = NumTraits<Scalar>::IsComplex>
- struct arg_default_impl
+template<typename Scalar, bool IsComplex = NumTraits<Scalar>::IsComplex>
+struct arg_default_impl
+{
+ typedef typename NumTraits<Scalar>::Real RealScalar;
+ EIGEN_DEVICE_FUNC
+ static inline RealScalar run(const Scalar& x)
{
- typedef typename NumTraits<Scalar>::Real RealScalar;
- EIGEN_DEVICE_FUNC
- static inline RealScalar run(const Scalar& x)
- {
- return (x < Scalar(0)) ? Scalar(EIGEN_PI) : Scalar(0); }
- };
+ return (x < RealScalar(0)) ? RealScalar(EIGEN_PI) : RealScalar(0);
+ }
+};
- template<typename Scalar>
- struct arg_default_impl<Scalar,true>
+template<typename Scalar>
+struct arg_default_impl<Scalar,true>
+{
+ typedef typename NumTraits<Scalar>::Real RealScalar;
+ EIGEN_DEVICE_FUNC
+ static inline RealScalar run(const Scalar& x)
{
- typedef typename NumTraits<Scalar>::Real RealScalar;
- EIGEN_DEVICE_FUNC
- static inline RealScalar run(const Scalar& x)
- {
- EIGEN_USING_STD_MATH(arg);
- return arg(x);
- }
- };
-
- template<typename Scalar> struct arg_impl : arg_default_impl<Scalar> {};
+ EIGEN_USING_STD(arg);
+ return arg(x);
+ }
+};
#endif
+template<typename Scalar> struct arg_impl : arg_default_impl<Scalar> {};
template<typename Scalar>
struct arg_retval
@@ -459,6 +641,80 @@
};
/****************************************************************************
+* Implementation of expm1 *
+****************************************************************************/
+
+// This implementation is based on GSL Math's expm1.
+namespace std_fallback {
+ // fallback expm1 implementation in case there is no expm1(Scalar) function in namespace of Scalar,
+ // or that there is no suitable std::expm1 function available. Implementation
+ // attributed to Kahan. See: http://www.plunk.org/~hatch/rightway.php.
+ template<typename Scalar>
+ EIGEN_DEVICE_FUNC inline Scalar expm1(const Scalar& x) {
+ EIGEN_STATIC_ASSERT_NON_INTEGER(Scalar)
+ typedef typename NumTraits<Scalar>::Real RealScalar;
+
+ EIGEN_USING_STD(exp);
+ Scalar u = exp(x);
+ if (numext::equal_strict(u, Scalar(1))) {
+ return x;
+ }
+ Scalar um1 = u - RealScalar(1);
+ if (numext::equal_strict(um1, Scalar(-1))) {
+ return RealScalar(-1);
+ }
+
+ EIGEN_USING_STD(log);
+ Scalar logu = log(u);
+ return numext::equal_strict(u, logu) ? u : (u - RealScalar(1)) * x / logu;
+ }
+}
+
+template<typename Scalar>
+struct expm1_impl {
+ EIGEN_DEVICE_FUNC static inline Scalar run(const Scalar& x)
+ {
+ EIGEN_STATIC_ASSERT_NON_INTEGER(Scalar)
+ #if EIGEN_HAS_CXX11_MATH
+ using std::expm1;
+ #else
+ using std_fallback::expm1;
+ #endif
+ return expm1(x);
+ }
+};
+
+template<typename Scalar>
+struct expm1_retval
+{
+ typedef Scalar type;
+};
+
+/****************************************************************************
+* Implementation of log *
+****************************************************************************/
+
+// Complex log defined in MathFunctionsImpl.h.
+template<typename T> EIGEN_DEVICE_FUNC std::complex<T> complex_log(const std::complex<T>& z);
+
+template<typename Scalar>
+struct log_impl {
+ EIGEN_DEVICE_FUNC static inline Scalar run(const Scalar& x)
+ {
+ EIGEN_USING_STD(log);
+ return static_cast<Scalar>(log(x));
+ }
+};
+
+template<typename Scalar>
+struct log_impl<std::complex<Scalar> > {
+ EIGEN_DEVICE_FUNC static inline std::complex<Scalar> run(const std::complex<Scalar>& z)
+ {
+ return complex_log(z);
+ }
+};
+
+/****************************************************************************
* Implementation of log1p *
****************************************************************************/
@@ -469,25 +725,38 @@
EIGEN_DEVICE_FUNC inline Scalar log1p(const Scalar& x) {
EIGEN_STATIC_ASSERT_NON_INTEGER(Scalar)
typedef typename NumTraits<Scalar>::Real RealScalar;
- EIGEN_USING_STD_MATH(log);
+ EIGEN_USING_STD(log);
Scalar x1p = RealScalar(1) + x;
- return numext::equal_strict(x1p, Scalar(1)) ? x : x * ( log(x1p) / (x1p - RealScalar(1)) );
+ Scalar log_1p = log_impl<Scalar>::run(x1p);
+ const bool is_small = numext::equal_strict(x1p, Scalar(1));
+ const bool is_inf = numext::equal_strict(x1p, log_1p);
+ return (is_small || is_inf) ? x : x * (log_1p / (x1p - RealScalar(1)));
}
}
template<typename Scalar>
struct log1p_impl {
- static inline Scalar run(const Scalar& x)
+ EIGEN_DEVICE_FUNC static inline Scalar run(const Scalar& x)
{
EIGEN_STATIC_ASSERT_NON_INTEGER(Scalar)
#if EIGEN_HAS_CXX11_MATH
using std::log1p;
- #endif
+ #else
using std_fallback::log1p;
+ #endif
return log1p(x);
}
};
+// Specialization for complex types that are not supported by std::log1p.
+template <typename RealScalar>
+struct log1p_impl<std::complex<RealScalar> > {
+ EIGEN_DEVICE_FUNC static inline std::complex<RealScalar> run(
+ const std::complex<RealScalar>& x) {
+ EIGEN_STATIC_ASSERT_NON_INTEGER(RealScalar)
+ return std_fallback::log1p(x);
+ }
+};
template<typename Scalar>
struct log1p_retval
@@ -506,7 +775,7 @@
typedef typename ScalarBinaryOpTraits<ScalarX,ScalarY,internal::scalar_pow_op<ScalarX,ScalarY> >::ReturnType result_type;
static EIGEN_DEVICE_FUNC inline result_type run(const ScalarX& x, const ScalarY& y)
{
- EIGEN_USING_STD_MATH(pow);
+ EIGEN_USING_STD(pow);
return pow(x, y);
}
};
@@ -662,8 +931,8 @@
{
static inline Scalar run(const Scalar& x, const Scalar& y)
{
- return Scalar(random(real(x), real(y)),
- random(imag(x), imag(y)));
+ return Scalar(random(x.real(), y.real()),
+ random(x.imag(), y.imag()));
}
static inline Scalar run()
{
@@ -684,7 +953,7 @@
return EIGEN_MATHFUNC_IMPL(random, Scalar)::run();
}
-// Implementatin of is* functions
+// Implementation of is* functions
// std::is* do not work with fast-math and gcc, std::is* are available on MSVC 2013 and newer, as well as in clang.
#if (EIGEN_HAS_CXX11_MATH && !(EIGEN_COMP_GNUC_STRICT && __FINITE_MATH_ONLY__)) || (EIGEN_COMP_MSVC>=1800) || (EIGEN_COMP_CLANG)
@@ -713,7 +982,7 @@
typename internal::enable_if<(!internal::is_integral<T>::value)&&(!NumTraits<T>::IsComplex),bool>::type
isfinite_impl(const T& x)
{
- #ifdef __CUDA_ARCH__
+ #if defined(EIGEN_GPU_COMPILE_PHASE)
return (::isfinite)(x);
#elif EIGEN_USE_STD_FPCLASSIFY
using std::isfinite;
@@ -728,7 +997,7 @@
typename internal::enable_if<(!internal::is_integral<T>::value)&&(!NumTraits<T>::IsComplex),bool>::type
isinf_impl(const T& x)
{
- #ifdef __CUDA_ARCH__
+ #if defined(EIGEN_GPU_COMPILE_PHASE)
return (::isinf)(x);
#elif EIGEN_USE_STD_FPCLASSIFY
using std::isinf;
@@ -743,7 +1012,7 @@
typename internal::enable_if<(!internal::is_integral<T>::value)&&(!NumTraits<T>::IsComplex),bool>::type
isnan_impl(const T& x)
{
- #ifdef __CUDA_ARCH__
+ #if defined(EIGEN_GPU_COMPILE_PHASE)
return (::isnan)(x);
#elif EIGEN_USE_STD_FPCLASSIFY
using std::isnan;
@@ -800,7 +1069,6 @@
template<typename T> EIGEN_DEVICE_FUNC bool isinf_impl(const std::complex<T>& x);
template<typename T> T generic_fast_tanh_float(const T& a_x);
-
} // end namespace internal
/****************************************************************************
@@ -809,12 +1077,12 @@
namespace numext {
-#ifndef __CUDA_ARCH__
+#if (!defined(EIGEN_GPUCC) || defined(EIGEN_CONSTEXPR_ARE_DEVICE_FUNC))
template<typename T>
EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE T mini(const T& x, const T& y)
{
- EIGEN_USING_STD_MATH(min);
+ EIGEN_USING_STD(min)
return min EIGEN_NOT_A_MACRO (x,y);
}
@@ -822,7 +1090,7 @@
EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE T maxi(const T& x, const T& y)
{
- EIGEN_USING_STD_MATH(max);
+ EIGEN_USING_STD(max)
return max EIGEN_NOT_A_MACRO (x,y);
}
#else
@@ -838,6 +1106,24 @@
{
return fminf(x, y);
}
+template<>
+EIGEN_DEVICE_FUNC
+EIGEN_ALWAYS_INLINE double mini(const double& x, const double& y)
+{
+ return fmin(x, y);
+}
+template<>
+EIGEN_DEVICE_FUNC
+EIGEN_ALWAYS_INLINE long double mini(const long double& x, const long double& y)
+{
+#if defined(EIGEN_HIPCC)
+ // no "fminl" on HIP yet
+ return (x < y) ? x : y;
+#else
+ return fminl(x, y);
+#endif
+}
+
template<typename T>
EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE T maxi(const T& x, const T& y)
@@ -850,6 +1136,92 @@
{
return fmaxf(x, y);
}
+template<>
+EIGEN_DEVICE_FUNC
+EIGEN_ALWAYS_INLINE double maxi(const double& x, const double& y)
+{
+ return fmax(x, y);
+}
+template<>
+EIGEN_DEVICE_FUNC
+EIGEN_ALWAYS_INLINE long double maxi(const long double& x, const long double& y)
+{
+#if defined(EIGEN_HIPCC)
+ // no "fmaxl" on HIP yet
+ return (x > y) ? x : y;
+#else
+ return fmaxl(x, y);
+#endif
+}
+#endif
+
+#if defined(SYCL_DEVICE_ONLY)
+
+
+#define SYCL_SPECIALIZE_SIGNED_INTEGER_TYPES_BINARY(NAME, FUNC) \
+ SYCL_SPECIALIZE_BINARY_FUNC(NAME, FUNC, cl::sycl::cl_char) \
+ SYCL_SPECIALIZE_BINARY_FUNC(NAME, FUNC, cl::sycl::cl_short) \
+ SYCL_SPECIALIZE_BINARY_FUNC(NAME, FUNC, cl::sycl::cl_int) \
+ SYCL_SPECIALIZE_BINARY_FUNC(NAME, FUNC, cl::sycl::cl_long)
+#define SYCL_SPECIALIZE_SIGNED_INTEGER_TYPES_UNARY(NAME, FUNC) \
+ SYCL_SPECIALIZE_UNARY_FUNC(NAME, FUNC, cl::sycl::cl_char) \
+ SYCL_SPECIALIZE_UNARY_FUNC(NAME, FUNC, cl::sycl::cl_short) \
+ SYCL_SPECIALIZE_UNARY_FUNC(NAME, FUNC, cl::sycl::cl_int) \
+ SYCL_SPECIALIZE_UNARY_FUNC(NAME, FUNC, cl::sycl::cl_long)
+#define SYCL_SPECIALIZE_UNSIGNED_INTEGER_TYPES_BINARY(NAME, FUNC) \
+ SYCL_SPECIALIZE_BINARY_FUNC(NAME, FUNC, cl::sycl::cl_uchar) \
+ SYCL_SPECIALIZE_BINARY_FUNC(NAME, FUNC, cl::sycl::cl_ushort) \
+ SYCL_SPECIALIZE_BINARY_FUNC(NAME, FUNC, cl::sycl::cl_uint) \
+ SYCL_SPECIALIZE_BINARY_FUNC(NAME, FUNC, cl::sycl::cl_ulong)
+#define SYCL_SPECIALIZE_UNSIGNED_INTEGER_TYPES_UNARY(NAME, FUNC) \
+ SYCL_SPECIALIZE_UNARY_FUNC(NAME, FUNC, cl::sycl::cl_uchar) \
+ SYCL_SPECIALIZE_UNARY_FUNC(NAME, FUNC, cl::sycl::cl_ushort) \
+ SYCL_SPECIALIZE_UNARY_FUNC(NAME, FUNC, cl::sycl::cl_uint) \
+ SYCL_SPECIALIZE_UNARY_FUNC(NAME, FUNC, cl::sycl::cl_ulong)
+#define SYCL_SPECIALIZE_INTEGER_TYPES_BINARY(NAME, FUNC) \
+ SYCL_SPECIALIZE_SIGNED_INTEGER_TYPES_BINARY(NAME, FUNC) \
+ SYCL_SPECIALIZE_UNSIGNED_INTEGER_TYPES_BINARY(NAME, FUNC)
+#define SYCL_SPECIALIZE_INTEGER_TYPES_UNARY(NAME, FUNC) \
+ SYCL_SPECIALIZE_SIGNED_INTEGER_TYPES_UNARY(NAME, FUNC) \
+ SYCL_SPECIALIZE_UNSIGNED_INTEGER_TYPES_UNARY(NAME, FUNC)
+#define SYCL_SPECIALIZE_FLOATING_TYPES_BINARY(NAME, FUNC) \
+ SYCL_SPECIALIZE_BINARY_FUNC(NAME, FUNC, cl::sycl::cl_float) \
+ SYCL_SPECIALIZE_BINARY_FUNC(NAME, FUNC,cl::sycl::cl_double)
+#define SYCL_SPECIALIZE_FLOATING_TYPES_UNARY(NAME, FUNC) \
+ SYCL_SPECIALIZE_UNARY_FUNC(NAME, FUNC, cl::sycl::cl_float) \
+ SYCL_SPECIALIZE_UNARY_FUNC(NAME, FUNC,cl::sycl::cl_double)
+#define SYCL_SPECIALIZE_FLOATING_TYPES_UNARY_FUNC_RET_TYPE(NAME, FUNC, RET_TYPE) \
+ SYCL_SPECIALIZE_GEN_UNARY_FUNC(NAME, FUNC, RET_TYPE, cl::sycl::cl_float) \
+ SYCL_SPECIALIZE_GEN_UNARY_FUNC(NAME, FUNC, RET_TYPE, cl::sycl::cl_double)
+
+#define SYCL_SPECIALIZE_GEN_UNARY_FUNC(NAME, FUNC, RET_TYPE, ARG_TYPE) \
+template<> \
+ EIGEN_DEVICE_FUNC \
+ EIGEN_ALWAYS_INLINE RET_TYPE NAME(const ARG_TYPE& x) { \
+ return cl::sycl::FUNC(x); \
+ }
+
+#define SYCL_SPECIALIZE_UNARY_FUNC(NAME, FUNC, TYPE) \
+ SYCL_SPECIALIZE_GEN_UNARY_FUNC(NAME, FUNC, TYPE, TYPE)
+
+#define SYCL_SPECIALIZE_GEN1_BINARY_FUNC(NAME, FUNC, RET_TYPE, ARG_TYPE1, ARG_TYPE2) \
+ template<> \
+ EIGEN_DEVICE_FUNC \
+ EIGEN_ALWAYS_INLINE RET_TYPE NAME(const ARG_TYPE1& x, const ARG_TYPE2& y) { \
+ return cl::sycl::FUNC(x, y); \
+ }
+
+#define SYCL_SPECIALIZE_GEN2_BINARY_FUNC(NAME, FUNC, RET_TYPE, ARG_TYPE) \
+ SYCL_SPECIALIZE_GEN1_BINARY_FUNC(NAME, FUNC, RET_TYPE, ARG_TYPE, ARG_TYPE)
+
+#define SYCL_SPECIALIZE_BINARY_FUNC(NAME, FUNC, TYPE) \
+ SYCL_SPECIALIZE_GEN2_BINARY_FUNC(NAME, FUNC, TYPE, TYPE)
+
+SYCL_SPECIALIZE_INTEGER_TYPES_BINARY(mini, min)
+SYCL_SPECIALIZE_FLOATING_TYPES_BINARY(mini, fmin)
+SYCL_SPECIALIZE_INTEGER_TYPES_BINARY(maxi, max)
+SYCL_SPECIALIZE_FLOATING_TYPES_BINARY(maxi, fmax)
+
#endif
@@ -916,6 +1288,37 @@
return EIGEN_MATHFUNC_IMPL(abs2, Scalar)::run(x);
}
+EIGEN_DEVICE_FUNC
+inline bool abs2(bool x) { return x; }
+
+template<typename T>
+EIGEN_DEVICE_FUNC
+EIGEN_ALWAYS_INLINE T absdiff(const T& x, const T& y)
+{
+ return x > y ? x - y : y - x;
+}
+template<>
+EIGEN_DEVICE_FUNC
+EIGEN_ALWAYS_INLINE float absdiff(const float& x, const float& y)
+{
+ return fabsf(x - y);
+}
+template<>
+EIGEN_DEVICE_FUNC
+EIGEN_ALWAYS_INLINE double absdiff(const double& x, const double& y)
+{
+ return fabs(x - y);
+}
+
+#if !defined(EIGEN_GPUCC)
+// HIP and CUDA do not support long double.
+template<>
+EIGEN_DEVICE_FUNC
+EIGEN_ALWAYS_INLINE long double absdiff(const long double& x, const long double& y) {
+ return fabsl(x - y);
+}
+#endif
+
template<typename Scalar>
EIGEN_DEVICE_FUNC
inline EIGEN_MATHFUNC_RETVAL(norm1, Scalar) norm1(const Scalar& x)
@@ -930,6 +1333,10 @@
return EIGEN_MATHFUNC_IMPL(hypot, Scalar)::run(x, y);
}
+#if defined(SYCL_DEVICE_ONLY)
+ SYCL_SPECIALIZE_FLOATING_TYPES_BINARY(hypot, hypot)
+#endif
+
template<typename Scalar>
EIGEN_DEVICE_FUNC
inline EIGEN_MATHFUNC_RETVAL(log1p, Scalar) log1p(const Scalar& x)
@@ -937,7 +1344,11 @@
return EIGEN_MATHFUNC_IMPL(log1p, Scalar)::run(x);
}
-#ifdef __CUDACC__
+#if defined(SYCL_DEVICE_ONLY)
+SYCL_SPECIALIZE_FLOATING_TYPES_UNARY(log1p, log1p)
+#endif
+
+#if defined(EIGEN_GPUCC)
template<> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
float log1p(const float &x) { return ::log1pf(x); }
@@ -952,10 +1363,27 @@
return internal::pow_impl<ScalarX,ScalarY>::run(x, y);
}
+#if defined(SYCL_DEVICE_ONLY)
+SYCL_SPECIALIZE_FLOATING_TYPES_BINARY(pow, pow)
+#endif
+
template<typename T> EIGEN_DEVICE_FUNC bool (isnan) (const T &x) { return internal::isnan_impl(x); }
template<typename T> EIGEN_DEVICE_FUNC bool (isinf) (const T &x) { return internal::isinf_impl(x); }
template<typename T> EIGEN_DEVICE_FUNC bool (isfinite)(const T &x) { return internal::isfinite_impl(x); }
+#if defined(SYCL_DEVICE_ONLY)
+SYCL_SPECIALIZE_FLOATING_TYPES_UNARY_FUNC_RET_TYPE(isnan, isnan, bool)
+SYCL_SPECIALIZE_FLOATING_TYPES_UNARY_FUNC_RET_TYPE(isinf, isinf, bool)
+SYCL_SPECIALIZE_FLOATING_TYPES_UNARY_FUNC_RET_TYPE(isfinite, isfinite, bool)
+#endif
+
+template<typename Scalar>
+EIGEN_DEVICE_FUNC
+inline EIGEN_MATHFUNC_RETVAL(rint, Scalar) rint(const Scalar& x)
+{
+ return EIGEN_MATHFUNC_IMPL(rint, Scalar)::run(x);
+}
+
template<typename Scalar>
EIGEN_DEVICE_FUNC
inline EIGEN_MATHFUNC_RETVAL(round, Scalar) round(const Scalar& x)
@@ -963,15 +1391,23 @@
return EIGEN_MATHFUNC_IMPL(round, Scalar)::run(x);
}
+#if defined(SYCL_DEVICE_ONLY)
+SYCL_SPECIALIZE_FLOATING_TYPES_UNARY(round, round)
+#endif
+
template<typename T>
EIGEN_DEVICE_FUNC
T (floor)(const T& x)
{
- EIGEN_USING_STD_MATH(floor);
+ EIGEN_USING_STD(floor)
return floor(x);
}
-#ifdef __CUDACC__
+#if defined(SYCL_DEVICE_ONLY)
+SYCL_SPECIALIZE_FLOATING_TYPES_UNARY(floor, floor)
+#endif
+
+#if defined(EIGEN_GPUCC)
template<> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
float floor(const float &x) { return ::floorf(x); }
@@ -983,11 +1419,15 @@
EIGEN_DEVICE_FUNC
T (ceil)(const T& x)
{
- EIGEN_USING_STD_MATH(ceil);
+ EIGEN_USING_STD(ceil);
return ceil(x);
}
-#ifdef __CUDACC__
+#if defined(SYCL_DEVICE_ONLY)
+SYCL_SPECIALIZE_FLOATING_TYPES_UNARY(ceil, ceil)
+#endif
+
+#if defined(EIGEN_GPUCC)
template<> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
float ceil(const float &x) { return ::ceilf(x); }
@@ -1020,22 +1460,42 @@
*
* It's usage is justified in performance critical functions, like norm/normalize.
*/
+template<typename Scalar>
+EIGEN_DEVICE_FUNC
+EIGEN_ALWAYS_INLINE EIGEN_MATHFUNC_RETVAL(sqrt, Scalar) sqrt(const Scalar& x)
+{
+ return EIGEN_MATHFUNC_IMPL(sqrt, Scalar)::run(x);
+}
+
+// Boolean specialization, avoids implicit float to bool conversion (-Wimplicit-conversion-floating-point-to-bool).
+template<>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_DEVICE_FUNC
+bool sqrt<bool>(const bool &x) { return x; }
+
+#if defined(SYCL_DEVICE_ONLY)
+SYCL_SPECIALIZE_FLOATING_TYPES_UNARY(sqrt, sqrt)
+#endif
+
+/** \returns the reciprocal square root of \a x. **/
template<typename T>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
-T sqrt(const T &x)
+T rsqrt(const T& x)
{
- EIGEN_USING_STD_MATH(sqrt);
- return sqrt(x);
+ return internal::rsqrt_impl<T>::run(x);
}
template<typename T>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
T log(const T &x) {
- EIGEN_USING_STD_MATH(log);
- return log(x);
+ return internal::log_impl<T>::run(x);
}
-#ifdef __CUDACC__
+#if defined(SYCL_DEVICE_ONLY)
+SYCL_SPECIALIZE_FLOATING_TYPES_UNARY(log, log)
+#endif
+
+
+#if defined(EIGEN_GPUCC)
template<> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
float log(const float &x) { return ::logf(x); }
@@ -1047,7 +1507,7 @@
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
typename internal::enable_if<NumTraits<T>::IsSigned || NumTraits<T>::IsComplex,typename NumTraits<T>::Real>::type
abs(const T &x) {
- EIGEN_USING_STD_MATH(abs);
+ EIGEN_USING_STD(abs);
return abs(x);
}
@@ -1058,12 +1518,12 @@
return x;
}
-#if defined(__SYCL_DEVICE_ONLY__)
-EIGEN_ALWAYS_INLINE float abs(float x) { return cl::sycl::fabs(x); }
-EIGEN_ALWAYS_INLINE double abs(double x) { return cl::sycl::fabs(x); }
-#endif // defined(__SYCL_DEVICE_ONLY__)
+#if defined(SYCL_DEVICE_ONLY)
+SYCL_SPECIALIZE_INTEGER_TYPES_UNARY(abs, abs)
+SYCL_SPECIALIZE_FLOATING_TYPES_UNARY(abs, fabs)
+#endif
-#ifdef __CUDACC__
+#if defined(EIGEN_GPUCC)
template<> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
float abs(const float &x) { return ::fabsf(x); }
@@ -1084,26 +1544,69 @@
template<typename T>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
T exp(const T &x) {
- EIGEN_USING_STD_MATH(exp);
+ EIGEN_USING_STD(exp);
return exp(x);
}
-#ifdef __CUDACC__
+#if defined(SYCL_DEVICE_ONLY)
+SYCL_SPECIALIZE_FLOATING_TYPES_UNARY(exp, exp)
+#endif
+
+#if defined(EIGEN_GPUCC)
template<> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
float exp(const float &x) { return ::expf(x); }
template<> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
double exp(const double &x) { return ::exp(x); }
+
+template<> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
+std::complex<float> exp(const std::complex<float>& x) {
+ float com = ::expf(x.real());
+ float res_real = com * ::cosf(x.imag());
+ float res_imag = com * ::sinf(x.imag());
+ return std::complex<float>(res_real, res_imag);
+}
+
+template<> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
+std::complex<double> exp(const std::complex<double>& x) {
+ double com = ::exp(x.real());
+ double res_real = com * ::cos(x.imag());
+ double res_imag = com * ::sin(x.imag());
+ return std::complex<double>(res_real, res_imag);
+}
+#endif
+
+template<typename Scalar>
+EIGEN_DEVICE_FUNC
+inline EIGEN_MATHFUNC_RETVAL(expm1, Scalar) expm1(const Scalar& x)
+{
+ return EIGEN_MATHFUNC_IMPL(expm1, Scalar)::run(x);
+}
+
+#if defined(SYCL_DEVICE_ONLY)
+SYCL_SPECIALIZE_FLOATING_TYPES_UNARY(expm1, expm1)
+#endif
+
+#if defined(EIGEN_GPUCC)
+template<> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
+float expm1(const float &x) { return ::expm1f(x); }
+
+template<> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
+double expm1(const double &x) { return ::expm1(x); }
#endif
template<typename T>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
T cos(const T &x) {
- EIGEN_USING_STD_MATH(cos);
+ EIGEN_USING_STD(cos);
return cos(x);
}
-#ifdef __CUDACC__
+#if defined(SYCL_DEVICE_ONLY)
+SYCL_SPECIALIZE_FLOATING_TYPES_UNARY(cos,cos)
+#endif
+
+#if defined(EIGEN_GPUCC)
template<> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
float cos(const float &x) { return ::cosf(x); }
@@ -1114,11 +1617,15 @@
template<typename T>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
T sin(const T &x) {
- EIGEN_USING_STD_MATH(sin);
+ EIGEN_USING_STD(sin);
return sin(x);
}
-#ifdef __CUDACC__
+#if defined(SYCL_DEVICE_ONLY)
+SYCL_SPECIALIZE_FLOATING_TYPES_UNARY(sin, sin)
+#endif
+
+#if defined(EIGEN_GPUCC)
template<> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
float sin(const float &x) { return ::sinf(x); }
@@ -1129,11 +1636,15 @@
template<typename T>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
T tan(const T &x) {
- EIGEN_USING_STD_MATH(tan);
+ EIGEN_USING_STD(tan);
return tan(x);
}
-#ifdef __CUDACC__
+#if defined(SYCL_DEVICE_ONLY)
+SYCL_SPECIALIZE_FLOATING_TYPES_UNARY(tan, tan)
+#endif
+
+#if defined(EIGEN_GPUCC)
template<> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
float tan(const float &x) { return ::tanf(x); }
@@ -1144,11 +1655,25 @@
template<typename T>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
T acos(const T &x) {
- EIGEN_USING_STD_MATH(acos);
+ EIGEN_USING_STD(acos);
return acos(x);
}
-#ifdef __CUDACC__
+#if EIGEN_HAS_CXX11_MATH
+template<typename T>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
+T acosh(const T &x) {
+ EIGEN_USING_STD(acosh);
+ return static_cast<T>(acosh(x));
+}
+#endif
+
+#if defined(SYCL_DEVICE_ONLY)
+SYCL_SPECIALIZE_FLOATING_TYPES_UNARY(acos, acos)
+SYCL_SPECIALIZE_FLOATING_TYPES_UNARY(acosh, acosh)
+#endif
+
+#if defined(EIGEN_GPUCC)
template<> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
float acos(const float &x) { return ::acosf(x); }
@@ -1159,11 +1684,25 @@
template<typename T>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
T asin(const T &x) {
- EIGEN_USING_STD_MATH(asin);
+ EIGEN_USING_STD(asin);
return asin(x);
}
-#ifdef __CUDACC__
+#if EIGEN_HAS_CXX11_MATH
+template<typename T>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
+T asinh(const T &x) {
+ EIGEN_USING_STD(asinh);
+ return static_cast<T>(asinh(x));
+}
+#endif
+
+#if defined(SYCL_DEVICE_ONLY)
+SYCL_SPECIALIZE_FLOATING_TYPES_UNARY(asin, asin)
+SYCL_SPECIALIZE_FLOATING_TYPES_UNARY(asinh, asinh)
+#endif
+
+#if defined(EIGEN_GPUCC)
template<> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
float asin(const float &x) { return ::asinf(x); }
@@ -1174,11 +1713,25 @@
template<typename T>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
T atan(const T &x) {
- EIGEN_USING_STD_MATH(atan);
- return atan(x);
+ EIGEN_USING_STD(atan);
+ return static_cast<T>(atan(x));
}
-#ifdef __CUDACC__
+#if EIGEN_HAS_CXX11_MATH
+template<typename T>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
+T atanh(const T &x) {
+ EIGEN_USING_STD(atanh);
+ return static_cast<T>(atanh(x));
+}
+#endif
+
+#if defined(SYCL_DEVICE_ONLY)
+SYCL_SPECIALIZE_FLOATING_TYPES_UNARY(atan, atan)
+SYCL_SPECIALIZE_FLOATING_TYPES_UNARY(atanh, atanh)
+#endif
+
+#if defined(EIGEN_GPUCC)
template<> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
float atan(const float &x) { return ::atanf(x); }
@@ -1190,11 +1743,15 @@
template<typename T>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
T cosh(const T &x) {
- EIGEN_USING_STD_MATH(cosh);
- return cosh(x);
+ EIGEN_USING_STD(cosh);
+ return static_cast<T>(cosh(x));
}
-#ifdef __CUDACC__
+#if defined(SYCL_DEVICE_ONLY)
+SYCL_SPECIALIZE_FLOATING_TYPES_UNARY(cosh, cosh)
+#endif
+
+#if defined(EIGEN_GPUCC)
template<> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
float cosh(const float &x) { return ::coshf(x); }
@@ -1205,11 +1762,15 @@
template<typename T>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
T sinh(const T &x) {
- EIGEN_USING_STD_MATH(sinh);
- return sinh(x);
+ EIGEN_USING_STD(sinh);
+ return static_cast<T>(sinh(x));
}
-#ifdef __CUDACC__
+#if defined(SYCL_DEVICE_ONLY)
+SYCL_SPECIALIZE_FLOATING_TYPES_UNARY(sinh, sinh)
+#endif
+
+#if defined(EIGEN_GPUCC)
template<> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
float sinh(const float &x) { return ::sinhf(x); }
@@ -1220,16 +1781,20 @@
template<typename T>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
T tanh(const T &x) {
- EIGEN_USING_STD_MATH(tanh);
+ EIGEN_USING_STD(tanh);
return tanh(x);
}
-#if (!defined(__CUDACC__)) && EIGEN_FAST_MATH
+#if (!defined(EIGEN_GPUCC)) && EIGEN_FAST_MATH && !defined(SYCL_DEVICE_ONLY)
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
float tanh(float x) { return internal::generic_fast_tanh_float(x); }
#endif
-#ifdef __CUDACC__
+#if defined(SYCL_DEVICE_ONLY)
+SYCL_SPECIALIZE_FLOATING_TYPES_UNARY(tanh, tanh)
+#endif
+
+#if defined(EIGEN_GPUCC)
template<> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
float tanh(const float &x) { return ::tanhf(x); }
@@ -1240,11 +1805,15 @@
template <typename T>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
T fmod(const T& a, const T& b) {
- EIGEN_USING_STD_MATH(fmod);
+ EIGEN_USING_STD(fmod);
return fmod(a, b);
}
-#ifdef __CUDACC__
+#if defined(SYCL_DEVICE_ONLY)
+SYCL_SPECIALIZE_FLOATING_TYPES_BINARY(fmod, fmod)
+#endif
+
+#if defined(EIGEN_GPUCC)
template <>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
float fmod(const float& a, const float& b) {
@@ -1258,6 +1827,23 @@
}
#endif
+#if defined(SYCL_DEVICE_ONLY) // Drop every SYCL_SPECIALIZE_* helper macro so none leak out of this header.
+#undef SYCL_SPECIALIZE_SIGNED_INTEGER_TYPES_BINARY
+#undef SYCL_SPECIALIZE_SIGNED_INTEGER_TYPES_UNARY
+#undef SYCL_SPECIALIZE_UNSIGNED_INTEGER_TYPES_BINARY
+#undef SYCL_SPECIALIZE_UNSIGNED_INTEGER_TYPES_UNARY
+#undef SYCL_SPECIALIZE_INTEGER_TYPES_BINARY
+#undef SYCL_SPECIALIZE_INTEGER_TYPES_UNARY
+#undef SYCL_SPECIALIZE_FLOATING_TYPES_BINARY
+#undef SYCL_SPECIALIZE_FLOATING_TYPES_UNARY
+#undef SYCL_SPECIALIZE_FLOATING_TYPES_UNARY_FUNC_RET_TYPE
+#undef SYCL_SPECIALIZE_GEN_UNARY_FUNC
+#undef SYCL_SPECIALIZE_UNARY_FUNC
+#undef SYCL_SPECIALIZE_GEN1_BINARY_FUNC
+#undef SYCL_SPECIALIZE_GEN2_BINARY_FUNC
+#undef SYCL_SPECIALIZE_BINARY_FUNC
+#endif
+
} // end namespace numext
namespace internal {
@@ -1381,18 +1967,23 @@
{
return random<int>(0,1)==0 ? false : true;
}
+
+ static inline bool run(const bool& a, const bool& b)
+ {
+ return random<int>(a, b)==0 ? false : true;
+ }
};
template<> struct scalar_fuzzy_impl<bool>
{
typedef bool RealScalar;
-
+
template<typename OtherScalar> EIGEN_DEVICE_FUNC
static inline bool isMuchSmallerThan(const bool& x, const bool&, const bool&)
{
return !x;
}
-
+
EIGEN_DEVICE_FUNC
static inline bool isApprox(bool x, bool y, bool)
{
@@ -1404,10 +1995,61 @@
{
return (!x) || y;
}
-
+
};
-
+} // end namespace internal
+
+// Default implementations that rely on other numext implementations
+namespace internal {
+
+// Specialization for complex types that are not supported by std::expm1.
+template <typename RealScalar>
+struct expm1_impl<std::complex<RealScalar> > {
+ EIGEN_DEVICE_FUNC static inline std::complex<RealScalar> run(
+ const std::complex<RealScalar>& x) {
+ EIGEN_STATIC_ASSERT_NON_INTEGER(RealScalar)
+ RealScalar xr = x.real();
+ RealScalar xi = x.imag();
+ // expm1(z) = exp(z) - 1
+ // = exp(x + i * y) - 1
+ // = exp(x) * (cos(y) + i * sin(y)) - 1
+ // = exp(x) * cos(y) - 1 + i * exp(x) * sin(y)
+ // Imag(expm1(z)) = exp(x) * sin(y)
+ // Real(expm1(z)) = exp(x) * cos(y) - 1
+ // = exp(x) * cos(y) - 1.
+ // = expm1(x) + exp(x) * (cos(y) - 1)
+ // = expm1(x) + exp(x) * (2 * sin(y / 2) ** 2)
+ RealScalar erm1 = numext::expm1<RealScalar>(xr);
+ RealScalar er = erm1 + RealScalar(1.);
+ RealScalar sin2 = numext::sin(xi / RealScalar(2.));
+ sin2 = sin2 * sin2;
+ RealScalar s = numext::sin(xi);
+ RealScalar real_part = erm1 - RealScalar(2.) * er * sin2;
+ return std::complex<RealScalar>(real_part, er * s);
+ }
+};
+
+template<typename T>
+struct rsqrt_impl {
+ EIGEN_DEVICE_FUNC
+ static EIGEN_ALWAYS_INLINE T run(const T& x) {
+ return T(1)/numext::sqrt(x);
+ }
+};
+
+#if defined(EIGEN_GPU_COMPILE_PHASE)
+template<typename T>
+struct conj_impl<std::complex<T>, true>
+{
+ EIGEN_DEVICE_FUNC
+ static inline std::complex<T> run(const std::complex<T>& x)
+ {
+ return std::complex<T>(numext::real(x), -numext::imag(x));
+ }
+};
+#endif
+
} // end namespace internal
} // end namespace Eigen
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/MathFunctionsImpl.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/MathFunctionsImpl.h
index 9c1ceb0..4eaaaa7 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/MathFunctionsImpl.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/MathFunctionsImpl.h
@@ -17,24 +17,28 @@
/** \internal \returns the hyperbolic tan of \a a (coeff-wise)
Doesn't do anything fancy, just a 13/6-degree rational interpolant which
- is accurate up to a couple of ulp in the range [-9, 9], outside of which
- the tanh(x) = +/-1.
+ is accurate up to a couple of ulps in the (approximate) range [-8, 8],
+ outside of which tanh(x) = +/-1 in single precision. The input is clamped
+ to the range [-c, c]. The value c is chosen as the smallest value where
+ the approximation evaluates to exactly 1. In the range [-0.0004, 0.0004]
+ the approximation tanh(x) ~= x is used for better accuracy as x tends to zero.
This implementation works on both scalars and packets.
*/
template<typename T>
T generic_fast_tanh_float(const T& a_x)
{
- // Clamp the inputs to the range [-9, 9] since anything outside
- // this range is +/-1.0f in single-precision.
- const T plus_9 = pset1<T>(9.f);
- const T minus_9 = pset1<T>(-9.f);
- // NOTE GCC prior to 6.3 might improperly optimize this max/min
- // step such that if a_x is nan, x will be either 9 or -9,
- // and tanh will return 1 or -1 instead of nan.
- // This is supposed to be fixed in gcc6.3,
- // see: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=72867
- const T x = pmax(minus_9,pmin(plus_9,a_x));
+ // Clamp the inputs to the range [-c, c]
+#ifdef EIGEN_VECTORIZE_FMA
+ const T plus_clamp = pset1<T>(7.99881172180175781f);
+ const T minus_clamp = pset1<T>(-7.99881172180175781f);
+#else
+ const T plus_clamp = pset1<T>(7.90531110763549805f);
+ const T minus_clamp = pset1<T>(-7.90531110763549805f);
+#endif
+ const T tiny = pset1<T>(0.0004f);
+ const T x = pmax(pmin(a_x, plus_clamp), minus_clamp);
+ const T tiny_mask = pcmp_lt(pabs(a_x), tiny);
// The monomial coefficients of the numerator polynomial (odd).
const T alpha_1 = pset1<T>(4.89352455891786e-03f);
const T alpha_3 = pset1<T>(6.37261928875436e-04f);
@@ -62,24 +66,30 @@
p = pmadd(x2, p, alpha_1);
p = pmul(x, p);
- // Evaluate the denominator polynomial p.
+ // Evaluate the denominator polynomial q.
T q = pmadd(x2, beta_6, beta_4);
q = pmadd(x2, q, beta_2);
q = pmadd(x2, q, beta_0);
// Divide the numerator by the denominator.
- return pdiv(p, q);
+ return pselect(tiny_mask, x, pdiv(p, q));
}
template<typename RealScalar>
-EIGEN_STRONG_INLINE
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
RealScalar positive_real_hypot(const RealScalar& x, const RealScalar& y)
{
- EIGEN_USING_STD_MATH(sqrt);
+ // IEEE 754 / IEC 60559 special cases: an infinite argument yields +inf even if the other is NaN.
+ if ((numext::isinf)(x) || (numext::isinf)(y))
+ return NumTraits<RealScalar>::infinity();
+ if ((numext::isnan)(x) || (numext::isnan)(y))
+ return NumTraits<RealScalar>::quiet_NaN();
+
+ EIGEN_USING_STD(sqrt);
RealScalar p, qp;
p = numext::maxi(x,y);
if(p==RealScalar(0)) return RealScalar(0);
- qp = numext::mini(y,x) / p;
+ qp = numext::mini(y,x) / p;
return p * sqrt(RealScalar(1) + qp*qp);
}
@@ -87,13 +97,102 @@
struct hypot_impl
{
typedef typename NumTraits<Scalar>::Real RealScalar;
- static inline RealScalar run(const Scalar& x, const Scalar& y)
+ static EIGEN_DEVICE_FUNC
+ inline RealScalar run(const Scalar& x, const Scalar& y)
{
- EIGEN_USING_STD_MATH(abs);
+ EIGEN_USING_STD(abs);
return positive_real_hypot<RealScalar>(abs(x), abs(y));
}
};
+// Generic complex sqrt implementation that correctly handles corner cases
+// according to https://en.cppreference.com/w/cpp/numeric/complex/sqrt
+template<typename T>
+EIGEN_DEVICE_FUNC std::complex<T> complex_sqrt(const std::complex<T>& z) {
+ // Computes the principal sqrt of the input.
+ //
+ // For a complex square root of the number x + i*y, we want to find real
+ // numbers u and v such that
+ // (u + i*v)^2 = x + i*y <=>
+ // u^2 - v^2 + i*2*u*v = x + i*y.
+ // By equating the real and imaginary parts we get:
+ // u^2 - v^2 = x
+ // 2*u*v = y.
+ //
+ // For x >= 0, this has the numerically stable solution
+ // u = sqrt(0.5 * (x + sqrt(x^2 + y^2)))
+ // v = y / (2 * u)
+ // and for x < 0,
+ // v = sign(y) * sqrt(0.5 * (-x + sqrt(x^2 + y^2)))
+ // u = y / (2 * v)
+ //
+ // Letting w = sqrt(0.5 * (|x| + |z|)),
+ // if x == 0: u = w, v = sign(y) * w
+ // if x > 0: u = w, v = y / (2 * w)
+ // if x < 0: u = |y| / (2 * w), v = sign(y) * w
+
+ const T x = numext::real(z);
+ const T y = numext::imag(z);
+ const T zero = T(0);
+ const T w = numext::sqrt(T(0.5) * (numext::abs(x) + numext::hypot(x, y)));
+ // An infinite imaginary part is tested first, so it yields (+inf, y) for any x, including NaN.
+ return
+ (numext::isinf)(y) ? std::complex<T>(NumTraits<T>::infinity(), y)
+ : x == zero ? std::complex<T>(w, y < zero ? -w : w)
+ : x > zero ? std::complex<T>(w, y / (2 * w))
+ : std::complex<T>(numext::abs(y) / (2 * w), y < zero ? -w : w );
+}
+
+// Generic complex rsqrt implementation.
+template<typename T>
+EIGEN_DEVICE_FUNC std::complex<T> complex_rsqrt(const std::complex<T>& z) {
+ // Computes the principal reciprocal sqrt of the input.
+ //
+ // For a complex reciprocal square root of the number z = x + i*y, we want to
+ // find real numbers u and v such that
+ // (u + i*v)^2 = 1 / (x + i*y) <=>
+ // u^2 - v^2 + i*2*u*v = x/|z|^2 - i*y/|z|^2.
+ // By equating the real and imaginary parts we get:
+ // u^2 - v^2 = x/|z|^2
+ // 2*u*v = -y/|z|^2.
+ //
+ // For x >= 0, this has the numerically stable solution
+ // u = sqrt(0.5 * (x + |z|)) / |z|
+ // v = -y / (2 * u * |z|^2)
+ // and for x < 0,
+ // v = -sign(y) * sqrt(0.5 * (-x + |z|)) / |z|
+ // u = -y / (2 * v * |z|^2)
+ //
+ // Letting w = sqrt(0.5 * (|x| + |z|)),
+ // if x == 0: u = w / |z|, v = -sign(y) * w / |z|
+ // if x > 0: u = w / |z|, v = -y / (2 * w * |z|)
+ // if x < 0: u = |y| / (2 * w * |z|), v = -sign(y) * w / |z|
+
+ const T x = numext::real(z);
+ const T y = numext::imag(z);
+ const T zero = T(0);
+
+ const T abs_z = numext::hypot(x, y);
+ const T w = numext::sqrt(T(0.5) * (numext::abs(x) + abs_z));
+ const T woz = w / abs_z;
+ // Corner cases consistent with 1/sqrt(z) on gcc/clang: z == 0 gives (inf, NaN); an infinite component gives (0, 0).
+ return
+ abs_z == zero ? std::complex<T>(NumTraits<T>::infinity(), NumTraits<T>::quiet_NaN())
+ : ((numext::isinf)(x) || (numext::isinf)(y)) ? std::complex<T>(zero, zero)
+ : x == zero ? std::complex<T>(woz, y < zero ? woz : -woz)
+ : x > zero ? std::complex<T>(woz, -y / (2 * w * abs_z))
+ : std::complex<T>(numext::abs(y) / (2 * w * abs_z), y < zero ? woz : -woz );
+}
+
+template<typename T>
+EIGEN_DEVICE_FUNC std::complex<T> complex_log(const std::complex<T>& z) {
+ // Computes complex log.
+ T a = numext::abs(z);
+ EIGEN_USING_STD(atan2);
+ T b = atan2(z.imag(), z.real());
+ return std::complex<T>(numext::log(a), b);
+}
+
} // end namespace internal
} // end namespace Eigen
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Matrix.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Matrix.h
index 7f4a7af..f0e59a9 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Matrix.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Matrix.h
@@ -29,7 +29,7 @@
required_alignment = unpacket_traits<PacketScalar>::alignment,
packet_access_bit = (packet_traits<_Scalar>::Vectorizable && (EIGEN_UNALIGNED_VECTORIZE || (actual_alignment>=required_alignment))) ? PacketAccessBit : 0
};
-
+
public:
typedef _Scalar Scalar;
typedef Dense StorageKind;
@@ -44,7 +44,7 @@
Options = _Options,
InnerStrideAtCompileTime = 1,
OuterStrideAtCompileTime = (Options&RowMajor) ? ColsAtCompileTime : RowsAtCompileTime,
-
+
// FIXME, the following flag in only used to define NeedsToAlign in PlainObjectBase
EvaluatorFlags = LinearAccessBit | DirectAccessBit | packet_access_bit | row_major_bit,
Alignment = actual_alignment
@@ -255,53 +255,93 @@
*
* \sa resize(Index,Index)
*/
- EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE Matrix() : Base()
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ Matrix() : Base()
{
Base::_check_template_params();
EIGEN_INITIALIZE_COEFFS_IF_THAT_OPTION_IS_ENABLED
}
// FIXME is it still needed
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
explicit Matrix(internal::constructor_without_unaligned_array_assert)
: Base(internal::constructor_without_unaligned_array_assert())
{ Base::_check_template_params(); EIGEN_INITIALIZE_COEFFS_IF_THAT_OPTION_IS_ENABLED }
#if EIGEN_HAS_RVALUE_REFERENCES
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
Matrix(Matrix&& other) EIGEN_NOEXCEPT_IF(std::is_nothrow_move_constructible<Scalar>::value)
: Base(std::move(other))
{
Base::_check_template_params();
}
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
Matrix& operator=(Matrix&& other) EIGEN_NOEXCEPT_IF(std::is_nothrow_move_assignable<Scalar>::value)
{
- other.swap(*this);
+ Base::operator=(std::move(other));
return *this;
}
#endif
- #ifndef EIGEN_PARSED_BY_DOXYGEN
+#if EIGEN_HAS_CXX11
+ /** \copydoc PlainObjectBase(const Scalar&, const Scalar&, const Scalar&, const Scalar&, const ArgTypes&... args)
+ *
+ * Example: \include Matrix_variadic_ctor_cxx11.cpp
+ * Output: \verbinclude Matrix_variadic_ctor_cxx11.out
+ *
+ * \sa Matrix(const std::initializer_list<std::initializer_list<Scalar>>&)
+ */
+ template <typename... ArgTypes>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ Matrix(const Scalar& a0, const Scalar& a1, const Scalar& a2, const Scalar& a3, const ArgTypes&... args)
+ : Base(a0, a1, a2, a3, args...) {}
+
+ /** \brief Constructs a Matrix and initializes it from the coefficients given as initializer-lists grouped by row. \cpp11
+ *
+ * In the general case, the constructor takes a list of rows, each row being represented as a list of coefficients:
+ *
+ * Example: \include Matrix_initializer_list_23_cxx11.cpp
+ * Output: \verbinclude Matrix_initializer_list_23_cxx11.out
+ *
+ * Each of the inner initializer lists must contain the exact same number of elements, otherwise an assertion is triggered.
+ *
+ * In the case of a compile-time column vector, implicit transposition from a single row is allowed.
+ * Therefore <code>VectorXd{{1,2,3,4,5}}</code> is legal and the more verbose syntax
+ * <code>RowVectorXd{{1},{2},{3},{4},{5}}</code> can be avoided:
+ *
+ * Example: \include Matrix_initializer_list_vector_cxx11.cpp
+ * Output: \verbinclude Matrix_initializer_list_vector_cxx11.out
+ *
+ * In the case of fixed-sized matrices, the initializer list sizes must exactly match the matrix sizes,
+ * and implicit transposition is allowed for compile-time vectors only.
+ *
+ * \sa Matrix(const Scalar& a0, const Scalar& a1, const Scalar& a2, const Scalar& a3, const ArgTypes&... args)
+ */
+ EIGEN_DEVICE_FUNC
+ explicit EIGEN_STRONG_INLINE Matrix(const std::initializer_list<std::initializer_list<Scalar>>& list) : Base(list) {}
+#endif // end EIGEN_HAS_CXX11
+
+#ifndef EIGEN_PARSED_BY_DOXYGEN
// This constructor is for both 1x1 matrices and dynamic vectors
template<typename T>
- EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE explicit Matrix(const T& x)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ explicit Matrix(const T& x)
{
Base::_check_template_params();
Base::template _init1<T>(x);
}
template<typename T0, typename T1>
- EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE Matrix(const T0& x, const T1& y)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ Matrix(const T0& x, const T1& y)
{
Base::_check_template_params();
Base::template _init2<T0,T1>(x, y);
}
- #else
+
+
+#else
/** \brief Constructs a fixed-sized matrix initialized with coefficients starting at \a data */
EIGEN_DEVICE_FUNC
explicit Matrix(const Scalar *data);
@@ -311,7 +351,7 @@
* This is useful for dynamic-size vectors. For fixed-size vectors,
* it is redundant to pass these parameters, so one should use the default constructor
* Matrix() instead.
- *
+ *
* \warning This constructor is disabled for fixed-size \c 1x1 matrices. For instance,
* calling Matrix<double,1,1>(1) will call the initialization constructor: Matrix(const Scalar&).
* For fixed-size \c 1x1 matrices it is therefore recommended to use the default
@@ -319,14 +359,15 @@
* \c EIGEN_INITIALIZE_MATRICES_BY_{ZERO,\c NAN} macros (see \ref TopicPreprocessorDirectives).
*/
EIGEN_STRONG_INLINE explicit Matrix(Index dim);
- /** \brief Constructs an initialized 1x1 matrix with the given coefficient */
+ /** \brief Constructs an initialized 1x1 matrix with the given coefficient
+ * \sa Matrix(const Scalar&, const Scalar&, const Scalar&, const Scalar&, const ArgTypes&...) */
Matrix(const Scalar& x);
/** \brief Constructs an uninitialized matrix with \a rows rows and \a cols columns.
*
* This is useful for dynamic-size matrices. For fixed-size matrices,
* it is redundant to pass these parameters, so one should use the default constructor
* Matrix() instead.
- *
+ *
* \warning This constructor is disabled for fixed-size \c 1x2 and \c 2x1 vectors. For instance,
* calling Matrix2f(2,1) will call the initialization constructor: Matrix(const Scalar& x, const Scalar& y).
* For fixed-size \c 1x2 or \c 2x1 vectors it is therefore recommended to use the default
@@ -335,12 +376,15 @@
*/
EIGEN_DEVICE_FUNC
Matrix(Index rows, Index cols);
-
- /** \brief Constructs an initialized 2D vector with given coefficients */
- Matrix(const Scalar& x, const Scalar& y);
- #endif
- /** \brief Constructs an initialized 3D vector with given coefficients */
+ /** \brief Constructs an initialized 2D vector with given coefficients
+ * \sa Matrix(const Scalar&, const Scalar&, const Scalar&, const Scalar&, const ArgTypes&...) */
+ Matrix(const Scalar& x, const Scalar& y);
+ #endif // end EIGEN_PARSED_BY_DOXYGEN
+
+ /** \brief Constructs an initialized 3D vector with given coefficients
+ * \sa Matrix(const Scalar&, const Scalar&, const Scalar&, const Scalar&, const ArgTypes&...)
+ */
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Matrix(const Scalar& x, const Scalar& y, const Scalar& z)
{
@@ -350,7 +394,9 @@
m_storage.data()[1] = y;
m_storage.data()[2] = z;
}
- /** \brief Constructs an initialized 4D vector with given coefficients */
+ /** \brief Constructs an initialized 4D vector with given coefficients
+ * \sa Matrix(const Scalar&, const Scalar&, const Scalar&, const Scalar&, const ArgTypes&...)
+ */
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Matrix(const Scalar& x, const Scalar& y, const Scalar& z, const Scalar& w)
{
@@ -377,8 +423,10 @@
: Base(other.derived())
{ }
- EIGEN_DEVICE_FUNC inline Index innerStride() const { return 1; }
- EIGEN_DEVICE_FUNC inline Index outerStride() const { return this->innerSize(); }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ inline Index innerStride() const EIGEN_NOEXCEPT { return 1; }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ inline Index outerStride() const EIGEN_NOEXCEPT { return this->innerSize(); }
/////////// Geometry module ///////////
@@ -405,7 +453,7 @@
*
* \ingroup Core_Module
*
- * Eigen defines several typedef shortcuts for most common matrix and vector types.
+ * %Eigen defines several typedef shortcuts for most common matrix and vector types.
*
* The general patterns are the following:
*
@@ -418,6 +466,15 @@
* There are also \c VectorSizeType and \c RowVectorSizeType which are self-explanatory. For example, \c Vector4cf is
* a fixed-size vector of 4 complex floats.
*
+ * With \cpp11, template aliases are also defined for common sizes.
+ * They follow the same pattern as above except that the scalar type suffix is replaced by a
+ * template parameter, i.e.:
+ * - `MatrixSize<Type>` where `Size` can be \c 2,\c 3,\c 4 for fixed size square matrices or \c X for dynamic size.
+ * - `MatrixXSize<Type>` and `MatrixSizeX<Type>` where `Size` can be \c 2,\c 3,\c 4 for hybrid dynamic/fixed matrices.
+ * - `VectorSize<Type>` and `RowVectorSize<Type>` for column and row vectors.
+ *
+ * With \cpp11, you can also use fully generic column and row vector types: `Vector<Type,Size>` and `RowVector<Type,Size>`.
+ *
* \sa class Matrix
*/
@@ -454,6 +511,55 @@
#undef EIGEN_MAKE_TYPEDEFS
#undef EIGEN_MAKE_FIXED_TYPEDEFS
+#if EIGEN_HAS_CXX11
+
+#define EIGEN_MAKE_TYPEDEFS(Size, SizeSuffix) \
+/** \ingroup matrixtypedefs */ \
+/** \brief \cpp11 */ \
+template <typename Type> \
+using Matrix##SizeSuffix = Matrix<Type, Size, Size>; \
+/** \ingroup matrixtypedefs */ \
+/** \brief \cpp11 */ \
+template <typename Type> \
+using Vector##SizeSuffix = Matrix<Type, Size, 1>; \
+/** \ingroup matrixtypedefs */ \
+/** \brief \cpp11 */ \
+template <typename Type> \
+using RowVector##SizeSuffix = Matrix<Type, 1, Size>;
+
+#define EIGEN_MAKE_FIXED_TYPEDEFS(Size) \
+/** \ingroup matrixtypedefs */ \
+/** \brief \cpp11 */ \
+template <typename Type> \
+using Matrix##Size##X = Matrix<Type, Size, Dynamic>; \
+/** \ingroup matrixtypedefs */ \
+/** \brief \cpp11 */ \
+template <typename Type> \
+using Matrix##X##Size = Matrix<Type, Dynamic, Size>;
+
+EIGEN_MAKE_TYPEDEFS(2, 2)
+EIGEN_MAKE_TYPEDEFS(3, 3)
+EIGEN_MAKE_TYPEDEFS(4, 4)
+EIGEN_MAKE_TYPEDEFS(Dynamic, X)
+EIGEN_MAKE_FIXED_TYPEDEFS(2)
+EIGEN_MAKE_FIXED_TYPEDEFS(3)
+EIGEN_MAKE_FIXED_TYPEDEFS(4)
+
+/** \ingroup matrixtypedefs
+ * \brief \cpp11 */
+template <typename Type, int Size>
+using Vector = Matrix<Type, Size, 1>;
+
+/** \ingroup matrixtypedefs
+ * \brief \cpp11 */
+template <typename Type, int Size>
+using RowVector = Matrix<Type, 1, Size>;
+
+#undef EIGEN_MAKE_TYPEDEFS
+#undef EIGEN_MAKE_FIXED_TYPEDEFS
+
+#endif // EIGEN_HAS_CXX11
+
} // end namespace Eigen
#endif // EIGEN_MATRIX_H
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/MatrixBase.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/MatrixBase.h
index f8bcc8c..45c3a59 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/MatrixBase.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/MatrixBase.h
@@ -76,6 +76,7 @@
using Base::coeffRef;
using Base::lazyAssign;
using Base::eval;
+ using Base::operator-;
using Base::operator+=;
using Base::operator-=;
using Base::operator*=;
@@ -122,7 +123,6 @@
#define EIGEN_CURRENT_STORAGE_BASE_CLASS Eigen::MatrixBase
#define EIGEN_DOC_UNARY_ADDONS(X,Y)
-# include "../plugins/CommonCwiseUnaryOps.h"
# include "../plugins/CommonCwiseBinaryOps.h"
# include "../plugins/MatrixCwiseUnaryOps.h"
# include "../plugins/MatrixCwiseBinaryOps.h"
@@ -268,6 +268,8 @@
Derived& setIdentity();
EIGEN_DEVICE_FUNC
Derived& setIdentity(Index rows, Index cols);
+ EIGEN_DEVICE_FUNC Derived& setUnit(Index i);
+ EIGEN_DEVICE_FUNC Derived& setUnit(Index newSize, Index i);
bool isIdentity(const RealScalar& prec = NumTraits<Scalar>::dummy_precision()) const;
bool isDiagonal(const RealScalar& prec = NumTraits<Scalar>::dummy_precision()) const;
@@ -296,7 +298,7 @@
EIGEN_DEVICE_FUNC inline bool operator!=(const MatrixBase<OtherDerived>& other) const
{ return cwiseNotEqual(other).any(); }
- NoAlias<Derived,Eigen::MatrixBase > noalias();
+ NoAlias<Derived,Eigen::MatrixBase > EIGEN_DEVICE_FUNC noalias();
// TODO forceAlignedAccess is temporarily disabled
// Need to find a nicer workaround.
@@ -326,6 +328,7 @@
inline const PartialPivLU<PlainObject> lu() const;
+ EIGEN_DEVICE_FUNC
inline const Inverse<Derived> inverse() const;
template<typename ResultType>
@@ -335,12 +338,15 @@
bool& invertible,
const RealScalar& absDeterminantThreshold = NumTraits<Scalar>::dummy_precision()
) const;
+
template<typename ResultType>
inline void computeInverseWithCheck(
ResultType& inverse,
bool& invertible,
const RealScalar& absDeterminantThreshold = NumTraits<Scalar>::dummy_precision()
) const;
+
+ EIGEN_DEVICE_FUNC
Scalar determinant() const;
/////////// Cholesky module ///////////
@@ -412,15 +418,19 @@
////////// Householder module ///////////
+ EIGEN_DEVICE_FUNC
void makeHouseholderInPlace(Scalar& tau, RealScalar& beta);
template<typename EssentialPart>
+ EIGEN_DEVICE_FUNC
void makeHouseholder(EssentialPart& essential,
Scalar& tau, RealScalar& beta) const;
template<typename EssentialPart>
+ EIGEN_DEVICE_FUNC
void applyHouseholderOnTheLeft(const EssentialPart& essential,
const Scalar& tau,
Scalar* workspace);
template<typename EssentialPart>
+ EIGEN_DEVICE_FUNC
void applyHouseholderOnTheRight(const EssentialPart& essential,
const Scalar& tau,
Scalar* workspace);
@@ -428,8 +438,10 @@
///////// Jacobi module /////////
template<typename OtherScalar>
+ EIGEN_DEVICE_FUNC
void applyOnTheLeft(Index p, Index q, const JacobiRotation<OtherScalar>& j);
template<typename OtherScalar>
+ EIGEN_DEVICE_FUNC
void applyOnTheRight(Index p, Index q, const JacobiRotation<OtherScalar>& j);
///////// SparseCore module /////////
@@ -456,6 +468,11 @@
const MatrixFunctionReturnValue<Derived> matrixFunction(StemFunction f) const;
EIGEN_MATRIX_FUNCTION(MatrixFunctionReturnValue, cosh, hyperbolic cosine)
EIGEN_MATRIX_FUNCTION(MatrixFunctionReturnValue, sinh, hyperbolic sine)
+#if EIGEN_HAS_CXX11_MATH
+ EIGEN_MATRIX_FUNCTION(MatrixFunctionReturnValue, atanh, inverse hyperbolic tangent)
+ EIGEN_MATRIX_FUNCTION(MatrixFunctionReturnValue, acosh, inverse hyperbolic cosine)
+ EIGEN_MATRIX_FUNCTION(MatrixFunctionReturnValue, asinh, inverse hyperbolic sine)
+#endif
EIGEN_MATRIX_FUNCTION(MatrixFunctionReturnValue, cos, cosine)
EIGEN_MATRIX_FUNCTION(MatrixFunctionReturnValue, sin, sine)
EIGEN_MATRIX_FUNCTION(MatrixSquareRootReturnValue, sqrt, square root)
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/NestByValue.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/NestByValue.h
index 13adf07..b427576 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/NestByValue.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/NestByValue.h
@@ -16,7 +16,11 @@
namespace internal {
template<typename ExpressionType>
struct traits<NestByValue<ExpressionType> > : public traits<ExpressionType>
-{};
+{
+ enum {
+ Flags = traits<ExpressionType>::Flags & ~NestByRefBit
+ };
+};
}
/** \class NestByValue
@@ -41,57 +45,13 @@
EIGEN_DEVICE_FUNC explicit inline NestByValue(const ExpressionType& matrix) : m_expression(matrix) {}
- EIGEN_DEVICE_FUNC inline Index rows() const { return m_expression.rows(); }
- EIGEN_DEVICE_FUNC inline Index cols() const { return m_expression.cols(); }
- EIGEN_DEVICE_FUNC inline Index outerStride() const { return m_expression.outerStride(); }
- EIGEN_DEVICE_FUNC inline Index innerStride() const { return m_expression.innerStride(); }
-
- EIGEN_DEVICE_FUNC inline const CoeffReturnType coeff(Index row, Index col) const
- {
- return m_expression.coeff(row, col);
- }
-
- EIGEN_DEVICE_FUNC inline Scalar& coeffRef(Index row, Index col)
- {
- return m_expression.const_cast_derived().coeffRef(row, col);
- }
-
- EIGEN_DEVICE_FUNC inline const CoeffReturnType coeff(Index index) const
- {
- return m_expression.coeff(index);
- }
-
- EIGEN_DEVICE_FUNC inline Scalar& coeffRef(Index index)
- {
- return m_expression.const_cast_derived().coeffRef(index);
- }
-
- template<int LoadMode>
- inline const PacketScalar packet(Index row, Index col) const
- {
- return m_expression.template packet<LoadMode>(row, col);
- }
-
- template<int LoadMode>
- inline void writePacket(Index row, Index col, const PacketScalar& x)
- {
- m_expression.const_cast_derived().template writePacket<LoadMode>(row, col, x);
- }
-
- template<int LoadMode>
- inline const PacketScalar packet(Index index) const
- {
- return m_expression.template packet<LoadMode>(index);
- }
-
- template<int LoadMode>
- inline void writePacket(Index index, const PacketScalar& x)
- {
- m_expression.const_cast_derived().template writePacket<LoadMode>(index, x);
- }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR inline Index rows() const EIGEN_NOEXCEPT { return m_expression.rows(); }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR inline Index cols() const EIGEN_NOEXCEPT { return m_expression.cols(); }
EIGEN_DEVICE_FUNC operator const ExpressionType&() const { return m_expression; }
+ EIGEN_DEVICE_FUNC const ExpressionType& nestedExpression() const { return m_expression; }
+
protected:
const ExpressionType m_expression;
};
@@ -99,12 +59,27 @@
/** \returns an expression of the temporary version of *this.
*/
template<typename Derived>
-inline const NestByValue<Derived>
+EIGEN_DEVICE_FUNC inline const NestByValue<Derived>
DenseBase<Derived>::nestByValue() const
{
return NestByValue<Derived>(derived());
}
+namespace internal {
+
+// Evaluator of NestByValue -> forwards to the evaluator of the nested expression
+template<typename ArgType>
+struct evaluator<NestByValue<ArgType> >
+ : public evaluator<ArgType>
+{
+ typedef evaluator<ArgType> Base;
+
+ EIGEN_DEVICE_FUNC explicit evaluator(const NestByValue<ArgType>& xpr)
+ : Base(xpr.nestedExpression())
+ {}
+};
+}
+
} // end namespace Eigen
#endif // EIGEN_NESTBYVALUE_H
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/NoAlias.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/NoAlias.h
index 3390801..570283d 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/NoAlias.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/NoAlias.h
@@ -33,6 +33,7 @@
public:
typedef typename ExpressionType::Scalar Scalar;
+ EIGEN_DEVICE_FUNC
explicit NoAlias(ExpressionType& expression) : m_expression(expression) {}
template<typename OtherDerived>
@@ -74,10 +75,10 @@
*
* More precisely, noalias() allows to bypass the EvalBeforeAssignBit flag.
* Currently, even though several expressions may alias, only product
- * expressions have this flag. Therefore, noalias() is only usefull when
+ * expressions have this flag. Therefore, noalias() is only useful when
* the source expression contains a matrix product.
*
- * Here are some examples where noalias is usefull:
+ * Here are some examples where noalias is useful:
* \code
* D.noalias() = A * B;
* D.noalias() += A.transpose() * B;
@@ -98,7 +99,7 @@
* \sa class NoAlias
*/
template<typename Derived>
-NoAlias<Derived,MatrixBase> MatrixBase<Derived>::noalias()
+NoAlias<Derived,MatrixBase> EIGEN_DEVICE_FUNC MatrixBase<Derived>::noalias()
{
return NoAlias<Derived, Eigen::MatrixBase >(derived());
}
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/NumTraits.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/NumTraits.h
index daf4898..72eac5a 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/NumTraits.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/NumTraits.h
@@ -21,12 +21,14 @@
bool is_integer = NumTraits<T>::IsInteger>
struct default_digits10_impl
{
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static int run() { return std::numeric_limits<T>::digits10; }
};
template<typename T>
struct default_digits10_impl<T,false,false> // Floating point
{
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static int run() {
using std::log10;
using std::ceil;
@@ -38,11 +40,64 @@
template<typename T>
struct default_digits10_impl<T,false,true> // Integer
{
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ static int run() { return 0; }
+};
+
+
+// default implementation of digits(), based on numeric_limits if specialized,
+// 0 for integer types, and ceil(-log2(epsilon())) otherwise.
+template< typename T,
+ bool use_numeric_limits = std::numeric_limits<T>::is_specialized,
+ bool is_integer = NumTraits<T>::IsInteger>
+struct default_digits_impl
+{
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ static int run() { return std::numeric_limits<T>::digits; }
+};
+
+template<typename T>
+struct default_digits_impl<T,false,false> // Floating point
+{
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ static int run() {
+ using std::log;
+ using std::ceil;
+ typedef typename NumTraits<T>::Real Real;
+ return int(ceil(-log(NumTraits<Real>::epsilon())/log(static_cast<Real>(2))));
+ }
+};
+
+template<typename T>
+struct default_digits_impl<T,false,true> // Integer
+{
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static int run() { return 0; }
};
} // end namespace internal
+namespace numext {
+/** \internal bit-wise cast without changing the underlying bit representation. */
+
+// TODO: Replace by std::bit_cast (available in C++20)
+template <typename Tgt, typename Src>
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Tgt bit_cast(const Src& src) {
+#if EIGEN_HAS_TYPE_TRAITS
+ // The behaviour of memcpy is not specified for non-trivially copyable types
+ EIGEN_STATIC_ASSERT(std::is_trivially_copyable<Src>::value, THIS_TYPE_IS_NOT_SUPPORTED);
+ EIGEN_STATIC_ASSERT(std::is_trivially_copyable<Tgt>::value && std::is_default_constructible<Tgt>::value,
+ THIS_TYPE_IS_NOT_SUPPORTED);
+#endif
+
+ EIGEN_STATIC_ASSERT(sizeof(Src) == sizeof(Tgt), THIS_TYPE_IS_NOT_SUPPORTED);
+ Tgt tgt;
+ EIGEN_USING_STD(memcpy)
+ memcpy(&tgt, &src, sizeof(Tgt));
+ return tgt;
+}
+} // namespace numext
+
/** \class NumTraits
* \ingroup Core_Module
*
@@ -71,7 +126,7 @@
* and to \c 0 otherwise.
* \li Enum values ReadCost, AddCost and MulCost representing a rough estimate of the number of CPU cycles needed
* to by move / add / mul instructions respectively, assuming the data is already stored in CPU registers.
- * Stay vague here. No need to do architecture-specific stuff.
+ * Stay vague here. No need to do architecture-specific stuff. If you don't know what this means, just use \c Eigen::HugeCost.
* \li An enum value \a IsSigned. It is equal to \c 1 if \a T is a signed type and to 0 if \a T is unsigned.
* \li An enum value \a RequireInitialization. It is equal to \c 1 if the constructor of the numeric type \a T must
* be called, and to 0 if it is safe not to call it. Default is 0 if \a T is an arithmetic type, and 1 otherwise.
@@ -80,9 +135,18 @@
* \li A dummy_precision() function returning a weak epsilon value. It is mainly used as a default
* value by the fuzzy comparison operators.
* \li highest() and lowest() functions returning the highest and lowest possible values respectively.
+ * \li digits() function returning the number of radix digits (non-sign digits for integers, mantissa for floating-point). This is
+ * the analogue of <a href="http://en.cppreference.com/w/cpp/types/numeric_limits/digits">std::numeric_limits<T>::digits</a>
+ * which is used as the default implementation if specialized.
* \li digits10() function returning the number of decimal digits that can be represented without change. This is
* the analogue of <a href="http://en.cppreference.com/w/cpp/types/numeric_limits/digits10">std::numeric_limits<T>::digits10</a>
* which is used as the default implementation if specialized.
+ * \li min_exponent() and max_exponent() functions returning the lowest and highest possible values, respectively,
+ * such that the radix raised to the power exponent-1 is a normalized floating-point number. These are equivalent to
+ * <a href="http://en.cppreference.com/w/cpp/types/numeric_limits/min_exponent">std::numeric_limits<T>::min_exponent</a>/
+ * <a href="http://en.cppreference.com/w/cpp/types/numeric_limits/max_exponent">std::numeric_limits<T>::max_exponent</a>.
+ * \li infinity() function returning a representation of positive infinity, if available.
+ * \li quiet_NaN() function returning a non-signaling "not-a-number", if available.
*/
template<typename T> struct GenericNumTraits
@@ -106,42 +170,60 @@
typedef T Nested;
typedef T Literal;
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static inline Real epsilon()
{
return numext::numeric_limits<T>::epsilon();
}
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static inline int digits10()
{
return internal::default_digits10_impl<T>::run();
}
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ static inline int digits()
+ {
+ return internal::default_digits_impl<T>::run();
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ static inline int min_exponent()
+ {
+ return numext::numeric_limits<T>::min_exponent;
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ static inline int max_exponent()
+ {
+ return numext::numeric_limits<T>::max_exponent;
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static inline Real dummy_precision()
{
// make sure to override this for floating-point types
return Real(0);
}
-
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static inline T highest() {
return (numext::numeric_limits<T>::max)();
}
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static inline T lowest() {
- return IsInteger ? (numext::numeric_limits<T>::min)() : (-(numext::numeric_limits<T>::max)());
+ return IsInteger ? (numext::numeric_limits<T>::min)()
+ : static_cast<T>(-(numext::numeric_limits<T>::max)());
}
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static inline T infinity() {
return numext::numeric_limits<T>::infinity();
}
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static inline T quiet_NaN() {
return numext::numeric_limits<T>::quiet_NaN();
}
@@ -153,19 +235,20 @@
template<> struct NumTraits<float>
: GenericNumTraits<float>
{
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static inline float dummy_precision() { return 1e-5f; }
};
template<> struct NumTraits<double> : GenericNumTraits<double>
{
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static inline double dummy_precision() { return 1e-12; }
};
template<> struct NumTraits<long double>
: GenericNumTraits<long double>
{
+ EIGEN_CONSTEXPR
static inline long double dummy_precision() { return 1e-15l; }
};
@@ -182,11 +265,11 @@
MulCost = 4 * NumTraits<Real>::MulCost + 2 * NumTraits<Real>::AddCost
};
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static inline Real epsilon() { return NumTraits<Real>::epsilon(); }
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static inline Real dummy_precision() { return NumTraits<Real>::dummy_precision(); }
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static inline int digits10() { return NumTraits<Real>::digits10(); }
};
@@ -206,16 +289,17 @@
IsInteger = NumTraits<Scalar>::IsInteger,
IsSigned = NumTraits<Scalar>::IsSigned,
RequireInitialization = 1,
- ReadCost = ArrayType::SizeAtCompileTime==Dynamic ? HugeCost : ArrayType::SizeAtCompileTime * NumTraits<Scalar>::ReadCost,
- AddCost = ArrayType::SizeAtCompileTime==Dynamic ? HugeCost : ArrayType::SizeAtCompileTime * NumTraits<Scalar>::AddCost,
- MulCost = ArrayType::SizeAtCompileTime==Dynamic ? HugeCost : ArrayType::SizeAtCompileTime * NumTraits<Scalar>::MulCost
+ ReadCost = ArrayType::SizeAtCompileTime==Dynamic ? HugeCost : ArrayType::SizeAtCompileTime * int(NumTraits<Scalar>::ReadCost),
+ AddCost = ArrayType::SizeAtCompileTime==Dynamic ? HugeCost : ArrayType::SizeAtCompileTime * int(NumTraits<Scalar>::AddCost),
+ MulCost = ArrayType::SizeAtCompileTime==Dynamic ? HugeCost : ArrayType::SizeAtCompileTime * int(NumTraits<Scalar>::MulCost)
};
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static inline RealScalar epsilon() { return NumTraits<RealScalar>::epsilon(); }
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static inline RealScalar dummy_precision() { return NumTraits<RealScalar>::dummy_precision(); }
+ EIGEN_CONSTEXPR
static inline int digits10() { return NumTraits<Scalar>::digits10(); }
};
@@ -229,6 +313,7 @@
MulCost = HugeCost
};
+ EIGEN_CONSTEXPR
static inline int digits10() { return 0; }
private:
@@ -243,6 +328,8 @@
// Empty specialization for void to allow template specialization based on NumTraits<T>::Real with T==void and SFINAE.
template<> struct NumTraits<void> {};
+template<> struct NumTraits<bool> : GenericNumTraits<bool> {};
+
} // end namespace Eigen
#endif // EIGEN_NUMTRAITS_H
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/PartialReduxEvaluator.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/PartialReduxEvaluator.h
new file mode 100644
index 0000000..29abf35
--- /dev/null
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/PartialReduxEvaluator.h
@@ -0,0 +1,232 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2011-2018 Gael Guennebaud <gael.guennebaud@inria.fr>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_PARTIALREDUX_H
+#define EIGEN_PARTIALREDUX_H
+
+namespace Eigen {
+
+namespace internal {
+
+
+/***************************************************************************
+*
+* This file provides evaluators for partial reductions.
+* There are two modes:
+*
+* - scalar path: simply calls the respective function on the column or row.
+* -> nothing special here, all the tricky part is handled by the return
+* types of VectorwiseOp's members. They embed the functor calling the
+* respective DenseBase's member function.
+*
+* - vectorized path: implements packet-wise reductions followed by
+* some (optional) processing of the outcome, e.g., division by n for mean.
+*
+* For the vectorized path let's observe that the packet-size and outer-unrolling
+* are both decided by the assignment logic. So all we have to do is to decide
+* on the inner unrolling.
+*
+* For the unrolling, we can reuse "internal::redux_vec_unroller" from Redux.h,
+* but we need to be careful to specify the correct increment.
+*
+***************************************************************************/
+
+
+/* logic deciding a strategy for unrolling of vectorized paths */
+template<typename Func, typename Evaluator>
+struct packetwise_redux_traits
+{
+ enum {
+ OuterSize = int(Evaluator::IsRowMajor) ? Evaluator::RowsAtCompileTime : Evaluator::ColsAtCompileTime,
+ Cost = OuterSize == Dynamic ? HugeCost
+ : OuterSize * Evaluator::CoeffReadCost + (OuterSize-1) * functor_traits<Func>::Cost,
+ Unrolling = Cost <= EIGEN_UNROLLING_LIMIT ? CompleteUnrolling : NoUnrolling
+ };
+
+};
+
+/* Value to be returned when size==0; by default let's return 0 */
+template<typename PacketType,typename Func>
+EIGEN_DEVICE_FUNC
+PacketType packetwise_redux_empty_value(const Func& ) { return pset1<PacketType>(0); }
+
+/* For products the default is 1 */
+template<typename PacketType,typename Scalar>
+EIGEN_DEVICE_FUNC
+PacketType packetwise_redux_empty_value(const scalar_product_op<Scalar,Scalar>& ) { return pset1<PacketType>(1); }
+
+/* Perform the actual reduction */
+template<typename Func, typename Evaluator,
+ int Unrolling = packetwise_redux_traits<Func, Evaluator>::Unrolling
+>
+struct packetwise_redux_impl;
+
+/* Perform the actual reduction with unrolling */
+template<typename Func, typename Evaluator>
+struct packetwise_redux_impl<Func, Evaluator, CompleteUnrolling>
+{
+ typedef redux_novec_unroller<Func,Evaluator, 0, Evaluator::SizeAtCompileTime> Base;
+ typedef typename Evaluator::Scalar Scalar;
+
+ template<typename PacketType>
+ EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE
+ PacketType run(const Evaluator &eval, const Func& func, Index /*size*/)
+ {
+ return redux_vec_unroller<Func, Evaluator, 0, packetwise_redux_traits<Func, Evaluator>::OuterSize>::template run<PacketType>(eval,func);
+ }
+};
+
+/* Add a specialization of redux_vec_unroller for size==0 at compile time.
+ * This specialization is not required for general reductions, which is
+ * why it is defined here.
+ */
+template<typename Func, typename Evaluator, int Start>
+struct redux_vec_unroller<Func, Evaluator, Start, 0>
+{
+ template<typename PacketType>
+ EIGEN_DEVICE_FUNC
+ static EIGEN_STRONG_INLINE PacketType run(const Evaluator &, const Func& f)
+ {
+ return packetwise_redux_empty_value<PacketType>(f);
+ }
+};
+
+/* Perform the actual reduction for dynamic sizes */
+template<typename Func, typename Evaluator>
+struct packetwise_redux_impl<Func, Evaluator, NoUnrolling>
+{
+ typedef typename Evaluator::Scalar Scalar;
+ typedef typename redux_traits<Func, Evaluator>::PacketType PacketScalar;
+
+ template<typename PacketType>
+ EIGEN_DEVICE_FUNC
+ static PacketType run(const Evaluator &eval, const Func& func, Index size)
+ {
+ if(size==0)
+ return packetwise_redux_empty_value<PacketType>(func);
+
+ const Index size4 = (size-1)&(~3);
+ PacketType p = eval.template packetByOuterInner<Unaligned,PacketType>(0,0);
+ Index i = 1;
+ // This loop is optimized for instruction pipelining:
+ // - each iteration generates two independent instructions
+ // - thanks to branch prediction and out-of-order execution we have independent instructions across loops
+ for(; i<size4; i+=4)
+ p = func.packetOp(p,
+ func.packetOp(
+ func.packetOp(eval.template packetByOuterInner<Unaligned,PacketType>(i+0,0),eval.template packetByOuterInner<Unaligned,PacketType>(i+1,0)),
+ func.packetOp(eval.template packetByOuterInner<Unaligned,PacketType>(i+2,0),eval.template packetByOuterInner<Unaligned,PacketType>(i+3,0))));
+ for(; i<size; ++i)
+ p = func.packetOp(p, eval.template packetByOuterInner<Unaligned,PacketType>(i,0));
+ return p;
+ }
+};
+
+template< typename ArgType, typename MemberOp, int Direction>
+struct evaluator<PartialReduxExpr<ArgType, MemberOp, Direction> >
+ : evaluator_base<PartialReduxExpr<ArgType, MemberOp, Direction> >
+{
+ typedef PartialReduxExpr<ArgType, MemberOp, Direction> XprType;
+ typedef typename internal::nested_eval<ArgType,1>::type ArgTypeNested;
+ typedef typename internal::add_const_on_value_type<ArgTypeNested>::type ConstArgTypeNested;
+ typedef typename internal::remove_all<ArgTypeNested>::type ArgTypeNestedCleaned;
+ typedef typename ArgType::Scalar InputScalar;
+ typedef typename XprType::Scalar Scalar;
+ enum {
+ TraversalSize = Direction==int(Vertical) ? int(ArgType::RowsAtCompileTime) : int(ArgType::ColsAtCompileTime)
+ };
+ typedef typename MemberOp::template Cost<int(TraversalSize)> CostOpType;
+ enum {
+ CoeffReadCost = TraversalSize==Dynamic ? HugeCost
+ : TraversalSize==0 ? 1
+ : int(TraversalSize) * int(evaluator<ArgType>::CoeffReadCost) + int(CostOpType::value),
+
+ _ArgFlags = evaluator<ArgType>::Flags,
+
+ _Vectorizable = bool(int(_ArgFlags)&PacketAccessBit)
+ && bool(MemberOp::Vectorizable)
+ && (Direction==int(Vertical) ? bool(_ArgFlags&RowMajorBit) : (_ArgFlags&RowMajorBit)==0)
+ && (TraversalSize!=0),
+
+ Flags = (traits<XprType>::Flags&RowMajorBit)
+ | (evaluator<ArgType>::Flags&(HereditaryBits&(~RowMajorBit)))
+ | (_Vectorizable ? PacketAccessBit : 0)
+ | LinearAccessBit,
+
+ Alignment = 0 // FIXME this will need to be improved once PartialReduxExpr is vectorized
+ };
+
+ EIGEN_DEVICE_FUNC explicit evaluator(const XprType xpr)
+ : m_arg(xpr.nestedExpression()), m_functor(xpr.functor())
+ {
+ EIGEN_INTERNAL_CHECK_COST_VALUE(TraversalSize==Dynamic ? HugeCost : (TraversalSize==0 ? 1 : int(CostOpType::value)));
+ EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
+ }
+
+ typedef typename XprType::CoeffReturnType CoeffReturnType;
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ const Scalar coeff(Index i, Index j) const
+ {
+ return coeff(Direction==Vertical ? j : i);
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ const Scalar coeff(Index index) const
+ {
+ return m_functor(m_arg.template subVector<DirectionType(Direction)>(index));
+ }
+
+ template<int LoadMode,typename PacketType>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ PacketType packet(Index i, Index j) const
+ {
+ return packet<LoadMode,PacketType>(Direction==Vertical ? j : i);
+ }
+
+ template<int LoadMode,typename PacketType>
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
+ PacketType packet(Index idx) const
+ {
+ enum { PacketSize = internal::unpacket_traits<PacketType>::size };
+ typedef Block<const ArgTypeNestedCleaned,
+ Direction==Vertical ? int(ArgType::RowsAtCompileTime) : int(PacketSize),
+ Direction==Vertical ? int(PacketSize) : int(ArgType::ColsAtCompileTime),
+ true /* InnerPanel */> PanelType;
+
+ PanelType panel(m_arg,
+ Direction==Vertical ? 0 : idx,
+ Direction==Vertical ? idx : 0,
+ Direction==Vertical ? m_arg.rows() : Index(PacketSize),
+ Direction==Vertical ? Index(PacketSize) : m_arg.cols());
+
+ // FIXME
+ // See bug 1612, currently if PacketSize==1 (i.e. complex<double> with 128bits registers) then the storage-order of panel gets reversed
+ // and methods like packetByOuterInner do not make sense anymore in this context.
+ // So let's just bypass "vectorization" in this case:
+ if(PacketSize==1)
+ return internal::pset1<PacketType>(coeff(idx));
+
+ typedef typename internal::redux_evaluator<PanelType> PanelEvaluator;
+ PanelEvaluator panel_eval(panel);
+ typedef typename MemberOp::BinaryOp BinaryOp;
+ PacketType p = internal::packetwise_redux_impl<BinaryOp,PanelEvaluator>::template run<PacketType>(panel_eval,m_functor.binaryFunc(),m_arg.outerSize());
+ return p;
+ }
+
+protected:
+ ConstArgTypeNested m_arg;
+ const MemberOp m_functor;
+};
+
+} // end namespace internal
+
+} // end namespace Eigen
+
+#endif // EIGEN_PARTIALREDUX_H
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/PermutationMatrix.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/PermutationMatrix.h
index b1fb455..69401bf 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/PermutationMatrix.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/PermutationMatrix.h
@@ -87,25 +87,14 @@
return derived();
}
- #ifndef EIGEN_PARSED_BY_DOXYGEN
- /** This is a special case of the templated operator=. Its purpose is to
- * prevent a default operator= from hiding the templated operator=.
- */
- Derived& operator=(const PermutationBase& other)
- {
- indices() = other.indices();
- return derived();
- }
- #endif
-
/** \returns the number of rows */
- inline Index rows() const { return Index(indices().size()); }
+ inline EIGEN_DEVICE_FUNC Index rows() const { return Index(indices().size()); }
/** \returns the number of columns */
- inline Index cols() const { return Index(indices().size()); }
+ inline EIGEN_DEVICE_FUNC Index cols() const { return Index(indices().size()); }
/** \returns the size of a side of the respective square matrix, i.e., the number of indices */
- inline Index size() const { return Index(indices().size()); }
+ inline EIGEN_DEVICE_FUNC Index size() const { return Index(indices().size()); }
#ifndef EIGEN_PARSED_BY_DOXYGEN
template<typename DenseDerived>
@@ -333,12 +322,6 @@
inline PermutationMatrix(const PermutationBase<OtherDerived>& other)
: m_indices(other.indices()) {}
- #ifndef EIGEN_PARSED_BY_DOXYGEN
- /** Standard copy constructor. Defined only to prevent a default copy constructor
- * from hiding the other templated constructor */
- inline PermutationMatrix(const PermutationMatrix& other) : m_indices(other.indices()) {}
- #endif
-
/** Generic constructor from expression of the indices. The indices
* array has the meaning that the permutations sends each integer i to indices[i].
*
@@ -373,17 +356,6 @@
return Base::operator=(tr.derived());
}
- #ifndef EIGEN_PARSED_BY_DOXYGEN
- /** This is a special case of the templated operator=. Its purpose is to
- * prevent a default operator= from hiding the templated operator=.
- */
- PermutationMatrix& operator=(const PermutationMatrix& other)
- {
- m_indices = other.m_indices;
- return *this;
- }
- #endif
-
/** const version of indices(). */
const IndicesType& indices() const { return m_indices; }
/** \returns a reference to the stored array representing the permutation. */
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/PlainObjectBase.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/PlainObjectBase.h
index 1dc7e22..e2ddbd1 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/PlainObjectBase.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/PlainObjectBase.h
@@ -13,10 +13,10 @@
#if defined(EIGEN_INITIALIZE_MATRICES_BY_ZERO)
# define EIGEN_INITIALIZE_COEFFS
-# define EIGEN_INITIALIZE_COEFFS_IF_THAT_OPTION_IS_ENABLED for(int i=0;i<base().size();++i) coeffRef(i)=Scalar(0);
+# define EIGEN_INITIALIZE_COEFFS_IF_THAT_OPTION_IS_ENABLED for(Index i=0;i<base().size();++i) coeffRef(i)=Scalar(0);
#elif defined(EIGEN_INITIALIZE_MATRICES_BY_NAN)
# define EIGEN_INITIALIZE_COEFFS
-# define EIGEN_INITIALIZE_COEFFS_IF_THAT_OPTION_IS_ENABLED for(int i=0;i<base().size();++i) coeffRef(i)=std::numeric_limits<Scalar>::quiet_NaN();
+# define EIGEN_INITIALIZE_COEFFS_IF_THAT_OPTION_IS_ENABLED for(Index i=0;i<base().size();++i) coeffRef(i)=std::numeric_limits<Scalar>::quiet_NaN();
#else
# undef EIGEN_INITIALIZE_COEFFS
# define EIGEN_INITIALIZE_COEFFS_IF_THAT_OPTION_IS_ENABLED
@@ -104,7 +104,7 @@
typedef typename internal::traits<Derived>::StorageKind StorageKind;
typedef typename internal::traits<Derived>::Scalar Scalar;
-
+
typedef typename internal::packet_traits<Scalar>::type PacketScalar;
typedef typename NumTraits<Scalar>::Real RealScalar;
typedef Derived DenseType;
@@ -118,16 +118,8 @@
using Base::IsVectorAtCompileTime;
using Base::Flags;
- template<typename PlainObjectType, int MapOptions, typename StrideType> friend class Eigen::Map;
- friend class Eigen::Map<Derived, Unaligned>;
typedef Eigen::Map<Derived, Unaligned> MapType;
- friend class Eigen::Map<const Derived, Unaligned>;
typedef const Eigen::Map<const Derived, Unaligned> ConstMapType;
-#if EIGEN_MAX_ALIGN_BYTES>0
- // for EIGEN_MAX_ALIGN_BYTES==0, AlignedMax==Unaligned, and many compilers generate warnings for friend-ing a class twice.
- friend class Eigen::Map<Derived, AlignedMax>;
- friend class Eigen::Map<const Derived, AlignedMax>;
-#endif
typedef Eigen::Map<Derived, AlignedMax> AlignedMapType;
typedef const Eigen::Map<const Derived, AlignedMax> ConstAlignedMapType;
template<typename StrideType> struct StridedMapType { typedef Eigen::Map<Derived, Unaligned, StrideType> type; };
@@ -147,10 +139,10 @@
EIGEN_DEVICE_FUNC
const Base& base() const { return *static_cast<const Base*>(this); }
- EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE Index rows() const { return m_storage.rows(); }
- EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE Index cols() const { return m_storage.cols(); }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_CONSTEXPR
+ Index rows() const EIGEN_NOEXCEPT { return m_storage.rows(); }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_CONSTEXPR
+ Index cols() const EIGEN_NOEXCEPT { return m_storage.cols(); }
/** This is an overloaded version of DenseCoeffsBase<Derived,ReadOnlyAccessors>::coeff(Index,Index) const
* provided to by-pass the creation of an evaluator of the expression, thus saving compilation efforts.
@@ -358,7 +350,7 @@
* remain row-vectors and vectors remain vectors.
*/
template<typename OtherDerived>
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE void resizeLike(const EigenBase<OtherDerived>& _other)
{
const OtherDerived& other = _other.derived();
@@ -383,7 +375,7 @@
* of rows and/or of columns, you can use conservativeResize(NoChange_t, Index) or
* conservativeResize(Index, NoChange_t).
*
- * Matrices are resized relative to the top-left element. In case values need to be
+ * Matrices are resized relative to the top-left element. In case values need to be
* appended to the matrix they will be uninitialized.
*/
EIGEN_DEVICE_FUNC
@@ -440,7 +432,7 @@
* of rows and/or of columns, you can use conservativeResize(NoChange_t, Index) or
* conservativeResize(Index, NoChange_t).
*
- * Matrices are resized relative to the top-left element. In case values need to be
+ * Matrices are resized relative to the top-left element. In case values need to be
* appended to the matrix they will copied from \c other.
*/
template<typename OtherDerived>
@@ -508,8 +500,8 @@
EIGEN_DEVICE_FUNC
PlainObjectBase& operator=(PlainObjectBase&& other) EIGEN_NOEXCEPT
{
- using std::swap;
- swap(m_storage, other.m_storage);
+ _check_template_params();
+ m_storage = std::move(other.m_storage);
return *this;
}
#endif
@@ -526,6 +518,71 @@
// EIGEN_INITIALIZE_COEFFS_IF_THAT_OPTION_IS_ENABLED
}
+ #if EIGEN_HAS_CXX11
+ /** \brief Construct a row of column vector with fixed size from an arbitrary number of coefficients. \cpp11
+ *
+ * \only_for_vectors
+ *
+ * This constructor is for 1D array or vectors with more than 4 coefficients.
+ * There exists C++98 analogue constructors for fixed-size array/vector having 1, 2, 3, or 4 coefficients.
+ *
+ * \warning To construct a column (resp. row) vector of fixed length, the number of values passed to this
+ * constructor must match the the fixed number of rows (resp. columns) of \c *this.
+ */
+ template <typename... ArgTypes>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ PlainObjectBase(const Scalar& a0, const Scalar& a1, const Scalar& a2, const Scalar& a3, const ArgTypes&... args)
+ : m_storage()
+ {
+ _check_template_params();
+ EIGEN_STATIC_ASSERT_VECTOR_SPECIFIC_SIZE(PlainObjectBase, sizeof...(args) + 4);
+ m_storage.data()[0] = a0;
+ m_storage.data()[1] = a1;
+ m_storage.data()[2] = a2;
+ m_storage.data()[3] = a3;
+ Index i = 4;
+ auto x = {(m_storage.data()[i++] = args, 0)...};
+ static_cast<void>(x);
+ }
+
+ /** \brief Constructs a Matrix or Array and initializes it by elements given by an initializer list of initializer
+ * lists \cpp11
+ */
+ EIGEN_DEVICE_FUNC
+ explicit EIGEN_STRONG_INLINE PlainObjectBase(const std::initializer_list<std::initializer_list<Scalar>>& list)
+ : m_storage()
+ {
+ _check_template_params();
+
+ size_t list_size = 0;
+ if (list.begin() != list.end()) {
+ list_size = list.begin()->size();
+ }
+
+ // This is to allow syntax like VectorXi {{1, 2, 3, 4}}
+ if (ColsAtCompileTime == 1 && list.size() == 1) {
+ eigen_assert(list_size == static_cast<size_t>(RowsAtCompileTime) || RowsAtCompileTime == Dynamic);
+ resize(list_size, ColsAtCompileTime);
+ std::copy(list.begin()->begin(), list.begin()->end(), m_storage.data());
+ } else {
+ eigen_assert(list.size() == static_cast<size_t>(RowsAtCompileTime) || RowsAtCompileTime == Dynamic);
+ eigen_assert(list_size == static_cast<size_t>(ColsAtCompileTime) || ColsAtCompileTime == Dynamic);
+ resize(list.size(), list_size);
+
+ Index row_index = 0;
+ for (const std::initializer_list<Scalar>& row : list) {
+ eigen_assert(list_size == row.size());
+ Index col_index = 0;
+ for (const Scalar& e : row) {
+ coeffRef(row_index, col_index) = e;
+ ++col_index;
+ }
+ ++row_index;
+ }
+ }
+ }
+ #endif // end EIGEN_HAS_CXX11
+
/** \sa PlainObjectBase::operator=(const EigenBase<OtherDerived>&) */
template<typename OtherDerived>
EIGEN_DEVICE_FUNC
@@ -564,7 +621,7 @@
* \copydetails DenseBase::operator=(const EigenBase<OtherDerived> &other)
*/
template<typename OtherDerived>
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Derived& operator=(const EigenBase<OtherDerived> &other)
{
_resize_to_match(other);
@@ -652,18 +709,26 @@
using Base::setConstant;
EIGEN_DEVICE_FUNC Derived& setConstant(Index size, const Scalar& val);
EIGEN_DEVICE_FUNC Derived& setConstant(Index rows, Index cols, const Scalar& val);
+ EIGEN_DEVICE_FUNC Derived& setConstant(NoChange_t, Index cols, const Scalar& val);
+ EIGEN_DEVICE_FUNC Derived& setConstant(Index rows, NoChange_t, const Scalar& val);
using Base::setZero;
EIGEN_DEVICE_FUNC Derived& setZero(Index size);
EIGEN_DEVICE_FUNC Derived& setZero(Index rows, Index cols);
+ EIGEN_DEVICE_FUNC Derived& setZero(NoChange_t, Index cols);
+ EIGEN_DEVICE_FUNC Derived& setZero(Index rows, NoChange_t);
using Base::setOnes;
EIGEN_DEVICE_FUNC Derived& setOnes(Index size);
EIGEN_DEVICE_FUNC Derived& setOnes(Index rows, Index cols);
+ EIGEN_DEVICE_FUNC Derived& setOnes(NoChange_t, Index cols);
+ EIGEN_DEVICE_FUNC Derived& setOnes(Index rows, NoChange_t);
using Base::setRandom;
Derived& setRandom(Index size);
Derived& setRandom(Index rows, Index cols);
+ Derived& setRandom(NoChange_t, Index cols);
+ Derived& setRandom(Index rows, NoChange_t);
#ifdef EIGEN_PLAINOBJECTBASE_PLUGIN
#include EIGEN_PLAINOBJECTBASE_PLUGIN
@@ -678,7 +743,7 @@
* remain row-vectors and vectors remain vectors.
*/
template<typename OtherDerived>
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE void _resize_to_match(const EigenBase<OtherDerived>& other)
{
#ifdef EIGEN_NO_AUTOMATIC_RESIZING
@@ -705,10 +770,10 @@
*
* \internal
*/
- // aliasing is dealt once in internall::call_assignment
+ // aliasing is dealt once in internal::call_assignment
// so at this stage we have to assume aliasing... and resising has to be done later.
template<typename OtherDerived>
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Derived& _set(const DenseBase<OtherDerived>& other)
{
internal::call_assignment(this->derived(), other.derived());
@@ -721,7 +786,7 @@
* \sa operator=(const MatrixBase<OtherDerived>&), _set()
*/
template<typename OtherDerived>
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Derived& _set_noalias(const DenseBase<OtherDerived>& other)
{
// I don't think we need this resize call since the lazyAssign will anyways resize
@@ -737,23 +802,25 @@
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE void _init2(Index rows, Index cols, typename internal::enable_if<Base::SizeAtCompileTime!=2,T0>::type* = 0)
{
- EIGEN_STATIC_ASSERT(bool(NumTraits<T0>::IsInteger) &&
- bool(NumTraits<T1>::IsInteger),
+ const bool t0_is_integer_alike = internal::is_valid_index_type<T0>::value;
+ const bool t1_is_integer_alike = internal::is_valid_index_type<T1>::value;
+ EIGEN_STATIC_ASSERT(t0_is_integer_alike &&
+ t1_is_integer_alike,
FLOATING_POINT_ARGUMENT_PASSED__INTEGER_WAS_EXPECTED)
resize(rows,cols);
}
-
+
template<typename T0, typename T1>
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE void _init2(const T0& val0, const T1& val1, typename internal::enable_if<Base::SizeAtCompileTime==2,T0>::type* = 0)
{
EIGEN_STATIC_ASSERT_VECTOR_SPECIFIC_SIZE(PlainObjectBase, 2)
m_storage.data()[0] = Scalar(val0);
m_storage.data()[1] = Scalar(val1);
}
-
+
template<typename T0, typename T1>
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE void _init2(const Index& val0, const Index& val1,
typename internal::enable_if< (!internal::is_same<Index,Scalar>::value)
&& (internal::is_same<T0,Index>::value)
@@ -773,14 +840,14 @@
&& ((!internal::is_same<typename internal::traits<Derived>::XprKind,ArrayXpr>::value || Base::SizeAtCompileTime==Dynamic)),T>::type* = 0)
{
// NOTE MSVC 2008 complains if we directly put bool(NumTraits<T>::IsInteger) as the EIGEN_STATIC_ASSERT argument.
- const bool is_integer = NumTraits<T>::IsInteger;
- EIGEN_UNUSED_VARIABLE(is_integer);
- EIGEN_STATIC_ASSERT(is_integer,
+ const bool is_integer_alike = internal::is_valid_index_type<T>::value;
+ EIGEN_UNUSED_VARIABLE(is_integer_alike);
+ EIGEN_STATIC_ASSERT(is_integer_alike,
FLOATING_POINT_ARGUMENT_PASSED__INTEGER_WAS_EXPECTED)
resize(size);
}
-
- // We have a 1x1 matrix/array => the argument is interpreted as the value of the unique coefficient (case where scalar type can be implicitely converted)
+
+ // We have a 1x1 matrix/array => the argument is interpreted as the value of the unique coefficient (case where scalar type can be implicitly converted)
template<typename T>
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE void _init1(const Scalar& val0, typename internal::enable_if<Base::SizeAtCompileTime==1 && internal::is_convertible<T, Scalar>::value,T>::type* = 0)
@@ -788,7 +855,7 @@
EIGEN_STATIC_ASSERT_VECTOR_SPECIFIC_SIZE(PlainObjectBase, 1)
m_storage.data()[0] = val0;
}
-
+
// We have a 1x1 matrix/array => the argument is interpreted as the value of the unique coefficient (case where scalar type match the index type)
template<typename T>
EIGEN_DEVICE_FUNC
@@ -844,7 +911,7 @@
{
this->derived() = r;
}
-
+
// For fixed-size Array<Scalar,...>
template<typename T>
EIGEN_DEVICE_FUNC
@@ -856,7 +923,7 @@
{
Base::setConstant(val0);
}
-
+
// For fixed-size Array<Index,...>
template<typename T>
EIGEN_DEVICE_FUNC
@@ -870,38 +937,38 @@
{
Base::setConstant(val0);
}
-
+
template<typename MatrixTypeA, typename MatrixTypeB, bool SwapPointers>
friend struct internal::matrix_swap_impl;
public:
-
+
#ifndef EIGEN_PARSED_BY_DOXYGEN
/** \internal
* \brief Override DenseBase::swap() since for dynamic-sized matrices
* of same type it is enough to swap the data pointers.
*/
template<typename OtherDerived>
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void swap(DenseBase<OtherDerived> & other)
{
enum { SwapPointers = internal::is_same<Derived, OtherDerived>::value && Base::SizeAtCompileTime==Dynamic };
internal::matrix_swap_impl<Derived, OtherDerived, bool(SwapPointers)>::run(this->derived(), other.derived());
}
-
+
/** \internal
* \brief const version forwarded to DenseBase::swap
*/
template<typename OtherDerived>
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void swap(DenseBase<OtherDerived> const & other)
{ Base::swap(other.derived()); }
-
- EIGEN_DEVICE_FUNC
+
+ EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE void _check_template_params()
{
- EIGEN_STATIC_ASSERT((EIGEN_IMPLIES(MaxRowsAtCompileTime==1 && MaxColsAtCompileTime!=1, (Options&RowMajor)==RowMajor)
- && EIGEN_IMPLIES(MaxColsAtCompileTime==1 && MaxRowsAtCompileTime!=1, (Options&RowMajor)==0)
+ EIGEN_STATIC_ASSERT((EIGEN_IMPLIES(MaxRowsAtCompileTime==1 && MaxColsAtCompileTime!=1, (int(Options)&RowMajor)==RowMajor)
+ && EIGEN_IMPLIES(MaxColsAtCompileTime==1 && MaxRowsAtCompileTime!=1, (int(Options)&RowMajor)==0)
&& ((RowsAtCompileTime == Dynamic) || (RowsAtCompileTime >= 0))
&& ((ColsAtCompileTime == Dynamic) || (ColsAtCompileTime >= 0))
&& ((MaxRowsAtCompileTime == Dynamic) || (MaxRowsAtCompileTime >= 0))
@@ -914,6 +981,17 @@
enum { IsPlainObjectBase = 1 };
#endif
+ public:
+ // These apparently need to be down here for nvcc+icc to prevent duplicate
+ // Map symbol.
+ template<typename PlainObjectType, int MapOptions, typename StrideType> friend class Eigen::Map;
+ friend class Eigen::Map<Derived, Unaligned>;
+ friend class Eigen::Map<const Derived, Unaligned>;
+#if EIGEN_MAX_ALIGN_BYTES>0
+ // for EIGEN_MAX_ALIGN_BYTES==0, AlignedMax==Unaligned, and many compilers generate warnings for friend-ing a class twice.
+ friend class Eigen::Map<Derived, AlignedMax>;
+ friend class Eigen::Map<const Derived, AlignedMax>;
+#endif
};
namespace internal {
@@ -921,13 +999,19 @@
template <typename Derived, typename OtherDerived, bool IsVector>
struct conservative_resize_like_impl
{
+ #if EIGEN_HAS_TYPE_TRAITS
+ static const bool IsRelocatable = std::is_trivially_copyable<typename Derived::Scalar>::value;
+ #else
+ static const bool IsRelocatable = !NumTraits<typename Derived::Scalar>::RequireInitialization;
+ #endif
static void run(DenseBase<Derived>& _this, Index rows, Index cols)
{
if (_this.rows() == rows && _this.cols() == cols) return;
EIGEN_STATIC_ASSERT_DYNAMIC_SIZE(Derived)
- if ( ( Derived::IsRowMajor && _this.cols() == cols) || // row-major and we change only the number of rows
- (!Derived::IsRowMajor && _this.rows() == rows) ) // column-major and we change only the number of columns
+ if ( IsRelocatable
+ && (( Derived::IsRowMajor && _this.cols() == cols) || // row-major and we change only the number of rows
+ (!Derived::IsRowMajor && _this.rows() == rows) )) // column-major and we change only the number of columns
{
internal::check_rows_cols_for_overflow<Derived::MaxSizeAtCompileTime>::run(rows, cols);
_this.derived().m_storage.conservativeResize(rows*cols,rows,cols);
@@ -935,7 +1019,7 @@
else
{
// The storage order does not allow us to use reallocation.
- typename Derived::PlainObject tmp(rows,cols);
+ Derived tmp(rows,cols);
const Index common_rows = numext::mini(rows, _this.rows());
const Index common_cols = numext::mini(cols, _this.cols());
tmp.block(0,0,common_rows,common_cols) = _this.block(0,0,common_rows,common_cols);
@@ -955,8 +1039,9 @@
EIGEN_STATIC_ASSERT_DYNAMIC_SIZE(Derived)
EIGEN_STATIC_ASSERT_DYNAMIC_SIZE(OtherDerived)
- if ( ( Derived::IsRowMajor && _this.cols() == other.cols()) || // row-major and we change only the number of rows
- (!Derived::IsRowMajor && _this.rows() == other.rows()) ) // column-major and we change only the number of columns
+ if ( IsRelocatable &&
+ (( Derived::IsRowMajor && _this.cols() == other.cols()) || // row-major and we change only the number of rows
+ (!Derived::IsRowMajor && _this.rows() == other.rows()) )) // column-major and we change only the number of columns
{
const Index new_rows = other.rows() - _this.rows();
const Index new_cols = other.cols() - _this.cols();
@@ -969,7 +1054,7 @@
else
{
// The storage order does not allow us to use reallocation.
- typename Derived::PlainObject tmp(other);
+ Derived tmp(other);
const Index common_rows = numext::mini(tmp.rows(), _this.rows());
const Index common_cols = numext::mini(tmp.cols(), _this.cols());
tmp.block(0,0,common_rows,common_cols) = _this.block(0,0,common_rows,common_cols);
@@ -984,13 +1069,18 @@
struct conservative_resize_like_impl<Derived,OtherDerived,true>
: conservative_resize_like_impl<Derived,OtherDerived,false>
{
- using conservative_resize_like_impl<Derived,OtherDerived,false>::run;
-
+ typedef conservative_resize_like_impl<Derived,OtherDerived,false> Base;
+ using Base::run;
+ using Base::IsRelocatable;
+
static void run(DenseBase<Derived>& _this, Index size)
{
const Index new_rows = Derived::RowsAtCompileTime==1 ? 1 : size;
const Index new_cols = Derived::RowsAtCompileTime==1 ? size : 1;
- _this.derived().m_storage.conservativeResize(size,new_rows,new_cols);
+ if(IsRelocatable)
+ _this.derived().m_storage.conservativeResize(size,new_rows,new_cols);
+ else
+ Base::run(_this.derived(), new_rows, new_cols);
}
static void run(DenseBase<Derived>& _this, const DenseBase<OtherDerived>& other)
@@ -1001,7 +1091,10 @@
const Index new_rows = Derived::RowsAtCompileTime==1 ? 1 : other.rows();
const Index new_cols = Derived::RowsAtCompileTime==1 ? other.cols() : 1;
- _this.derived().m_storage.conservativeResize(other.size(),new_rows,new_cols);
+ if(IsRelocatable)
+ _this.derived().m_storage.conservativeResize(other.size(),new_rows,new_cols);
+ else
+ Base::run(_this.derived(), new_rows, new_cols);
if (num_new_elements > 0)
_this.tail(num_new_elements) = other.tail(num_new_elements);
@@ -1012,7 +1105,7 @@
struct matrix_swap_impl
{
EIGEN_DEVICE_FUNC
- static inline void run(MatrixTypeA& a, MatrixTypeB& b)
+ static EIGEN_STRONG_INLINE void run(MatrixTypeA& a, MatrixTypeB& b)
{
a.base().swap(b);
}
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Product.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Product.h
index 676c480..70a6c10 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Product.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Product.h
@@ -23,25 +23,25 @@
typedef typename remove_all<Rhs>::type RhsCleaned;
typedef traits<LhsCleaned> LhsTraits;
typedef traits<RhsCleaned> RhsTraits;
-
+
typedef MatrixXpr XprKind;
-
+
typedef typename ScalarBinaryOpTraits<typename traits<LhsCleaned>::Scalar, typename traits<RhsCleaned>::Scalar>::ReturnType Scalar;
typedef typename product_promote_storage_type<typename LhsTraits::StorageKind,
typename RhsTraits::StorageKind,
internal::product_type<Lhs,Rhs>::ret>::ret StorageKind;
typedef typename promote_index_type<typename LhsTraits::StorageIndex,
typename RhsTraits::StorageIndex>::type StorageIndex;
-
+
enum {
RowsAtCompileTime = LhsTraits::RowsAtCompileTime,
ColsAtCompileTime = RhsTraits::ColsAtCompileTime,
MaxRowsAtCompileTime = LhsTraits::MaxRowsAtCompileTime,
MaxColsAtCompileTime = RhsTraits::MaxColsAtCompileTime,
-
+
// FIXME: only needed by GeneralMatrixMatrixTriangular
InnerSize = EIGEN_SIZE_MIN_PREFER_FIXED(LhsTraits::ColsAtCompileTime, RhsTraits::RowsAtCompileTime),
-
+
// The storage order is somewhat arbitrary here. The correct one will be determined through the evaluator.
Flags = (MaxRowsAtCompileTime==1 && MaxColsAtCompileTime!=1) ? RowMajorBit
: (MaxColsAtCompileTime==1 && MaxRowsAtCompileTime!=1) ? 0
@@ -74,10 +74,10 @@
internal::product_type<_Lhs,_Rhs>::ret>::ret>
{
public:
-
+
typedef _Lhs Lhs;
typedef _Rhs Rhs;
-
+
typedef typename ProductImpl<
Lhs, Rhs, Option,
typename internal::product_promote_storage_type<typename internal::traits<Lhs>::StorageKind,
@@ -90,18 +90,23 @@
typedef typename internal::remove_all<LhsNested>::type LhsNestedCleaned;
typedef typename internal::remove_all<RhsNested>::type RhsNestedCleaned;
- EIGEN_DEVICE_FUNC Product(const Lhs& lhs, const Rhs& rhs) : m_lhs(lhs), m_rhs(rhs)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ Product(const Lhs& lhs, const Rhs& rhs) : m_lhs(lhs), m_rhs(rhs)
{
eigen_assert(lhs.cols() == rhs.rows()
&& "invalid matrix product"
&& "if you wanted a coeff-wise or a dot product use the respective explicit functions");
}
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index rows() const { return m_lhs.rows(); }
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index cols() const { return m_rhs.cols(); }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_CONSTEXPR
+ Index rows() const EIGEN_NOEXCEPT { return m_lhs.rows(); }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_CONSTEXPR
+ Index cols() const EIGEN_NOEXCEPT { return m_rhs.cols(); }
- EIGEN_DEVICE_FUNC const LhsNestedCleaned& lhs() const { return m_lhs; }
- EIGEN_DEVICE_FUNC const RhsNestedCleaned& rhs() const { return m_rhs; }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ const LhsNestedCleaned& lhs() const { return m_lhs; }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ const RhsNestedCleaned& rhs() const { return m_rhs; }
protected:
@@ -110,13 +115,13 @@
};
namespace internal {
-
+
template<typename Lhs, typename Rhs, int Option, int ProductTag = internal::product_type<Lhs,Rhs>::ret>
class dense_product_base
: public internal::dense_xpr_base<Product<Lhs,Rhs,Option> >::type
{};
-/** Convertion to scalar for inner-products */
+/** Conversion to scalar for inner-products */
template<typename Lhs, typename Rhs, int Option>
class dense_product_base<Lhs, Rhs, Option, InnerProduct>
: public internal::dense_xpr_base<Product<Lhs,Rhs,Option> >::type
@@ -126,8 +131,8 @@
public:
using Base::derived;
typedef typename Base::Scalar Scalar;
-
- EIGEN_STRONG_INLINE operator const Scalar() const
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE operator const Scalar() const
{
return internal::evaluator<ProductXpr>(derived()).coeff(0,0);
}
@@ -148,25 +153,25 @@
: public internal::dense_product_base<Lhs,Rhs,Option>
{
typedef Product<Lhs, Rhs, Option> Derived;
-
+
public:
-
+
typedef typename internal::dense_product_base<Lhs, Rhs, Option> Base;
EIGEN_DENSE_PUBLIC_INTERFACE(Derived)
protected:
enum {
- IsOneByOne = (RowsAtCompileTime == 1 || RowsAtCompileTime == Dynamic) &&
+ IsOneByOne = (RowsAtCompileTime == 1 || RowsAtCompileTime == Dynamic) &&
(ColsAtCompileTime == 1 || ColsAtCompileTime == Dynamic),
EnableCoeff = IsOneByOne || Option==LazyProduct
};
-
+
public:
-
+
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(Index row, Index col) const
{
EIGEN_STATIC_ASSERT(EnableCoeff, THIS_METHOD_IS_ONLY_FOR_INNER_OR_LAZY_PRODUCTS);
eigen_assert( (Option==LazyProduct) || (this->rows() == 1 && this->cols() == 1) );
-
+
return internal::evaluator<Derived>(derived()).coeff(row,col);
}
@@ -174,11 +179,11 @@
{
EIGEN_STATIC_ASSERT(EnableCoeff, THIS_METHOD_IS_ONLY_FOR_INNER_OR_LAZY_PRODUCTS);
eigen_assert( (Option==LazyProduct) || (this->rows() == 1 && this->cols() == 1) );
-
+
return internal::evaluator<Derived>(derived()).coeff(i);
}
-
-
+
+
};
} // end namespace Eigen
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/ProductEvaluators.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/ProductEvaluators.h
index 9b99bd7..8cf294b 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/ProductEvaluators.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/ProductEvaluators.h
@@ -14,27 +14,27 @@
#define EIGEN_PRODUCTEVALUATORS_H
namespace Eigen {
-
+
namespace internal {
/** \internal
* Evaluator of a product expression.
* Since products require special treatments to handle all possible cases,
- * we simply deffer the evaluation logic to a product_evaluator class
+ * we simply defer the evaluation logic to a product_evaluator class
* which offers more partial specialization possibilities.
- *
+ *
* \sa class product_evaluator
*/
template<typename Lhs, typename Rhs, int Options>
-struct evaluator<Product<Lhs, Rhs, Options> >
+struct evaluator<Product<Lhs, Rhs, Options> >
: public product_evaluator<Product<Lhs, Rhs, Options> >
{
typedef Product<Lhs, Rhs, Options> XprType;
typedef product_evaluator<XprType> Base;
-
+
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE explicit evaluator(const XprType& xpr) : Base(xpr) {}
};
-
+
// Catch "scalar * ( A * B )" and transform it to "(A*scalar) * B"
// TODO we should apply that rule only if that's really helpful
template<typename Lhs, typename Rhs, typename Scalar1, typename Scalar2, typename Plain1>
@@ -62,12 +62,12 @@
template<typename Lhs, typename Rhs, int DiagIndex>
-struct evaluator<Diagonal<const Product<Lhs, Rhs, DefaultProduct>, DiagIndex> >
+struct evaluator<Diagonal<const Product<Lhs, Rhs, DefaultProduct>, DiagIndex> >
: public evaluator<Diagonal<const Product<Lhs, Rhs, LazyProduct>, DiagIndex> >
{
typedef Diagonal<const Product<Lhs, Rhs, DefaultProduct>, DiagIndex> XprType;
typedef evaluator<Diagonal<const Product<Lhs, Rhs, LazyProduct>, DiagIndex> > Base;
-
+
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE explicit evaluator(const XprType& xpr)
: Base(Diagonal<const Product<Lhs, Rhs, LazyProduct>, DiagIndex>(
Product<Lhs, Rhs, LazyProduct>(xpr.nestedExpression().lhs(), xpr.nestedExpression().rhs()),
@@ -108,27 +108,27 @@
: m_result(xpr.rows(), xpr.cols())
{
::new (static_cast<Base*>(this)) Base(m_result);
-
+
// FIXME shall we handle nested_eval here?,
// if so, then we must take care at removing the call to nested_eval in the specializations (e.g., in permutation_matrix_product, transposition_matrix_product, etc.)
// typedef typename internal::nested_eval<Lhs,Rhs::ColsAtCompileTime>::type LhsNested;
// typedef typename internal::nested_eval<Rhs,Lhs::RowsAtCompileTime>::type RhsNested;
// typedef typename internal::remove_all<LhsNested>::type LhsNestedCleaned;
// typedef typename internal::remove_all<RhsNested>::type RhsNestedCleaned;
-//
+//
// const LhsNested lhs(xpr.lhs());
// const RhsNested rhs(xpr.rhs());
-//
+//
// generic_product_impl<LhsNestedCleaned, RhsNestedCleaned>::evalTo(m_result, lhs, rhs);
generic_product_impl<Lhs, Rhs, LhsShape, RhsShape, ProductTag>::evalTo(m_result, xpr.lhs(), xpr.rhs());
}
-
-protected:
+
+protected:
PlainObject m_result;
};
-// The following three shortcuts are enabled only if the scalar types match excatly.
+// The following three shortcuts are enabled only if the scalar types match exactly.
// TODO: we could enable them for different scalar types when the product is not vectorized.
// Dense = Product
@@ -137,7 +137,7 @@
typename enable_if<(Options==DefaultProduct || Options==AliasFreeProduct)>::type>
{
typedef Product<Lhs,Rhs,Options> SrcXprType;
- static EIGEN_STRONG_INLINE
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op<Scalar,Scalar> &)
{
Index dstRows = src.rows();
@@ -155,7 +155,7 @@
typename enable_if<(Options==DefaultProduct || Options==AliasFreeProduct)>::type>
{
typedef Product<Lhs,Rhs,Options> SrcXprType;
- static EIGEN_STRONG_INLINE
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void run(DstXprType &dst, const SrcXprType &src, const internal::add_assign_op<Scalar,Scalar> &)
{
eigen_assert(dst.rows() == src.rows() && dst.cols() == src.cols());
@@ -170,7 +170,7 @@
typename enable_if<(Options==DefaultProduct || Options==AliasFreeProduct)>::type>
{
typedef Product<Lhs,Rhs,Options> SrcXprType;
- static EIGEN_STRONG_INLINE
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void run(DstXprType &dst, const SrcXprType &src, const internal::sub_assign_op<Scalar,Scalar> &)
{
eigen_assert(dst.rows() == src.rows() && dst.cols() == src.cols());
@@ -190,7 +190,7 @@
typedef CwiseBinaryOp<internal::scalar_product_op<ScalarBis,Scalar>,
const CwiseNullaryOp<internal::scalar_constant_op<ScalarBis>,Plain>,
const Product<Lhs,Rhs,DefaultProduct> > SrcXprType;
- static EIGEN_STRONG_INLINE
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void run(DstXprType &dst, const SrcXprType &src, const AssignFunc& func)
{
call_assignment_no_alias(dst, (src.lhs().functor().m_other * src.rhs().lhs())*src.rhs().rhs(), func);
@@ -217,7 +217,7 @@
struct assignment_from_xpr_op_product
{
template<typename SrcXprType, typename InitialFunc>
- static EIGEN_STRONG_INLINE
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void run(DstXprType &dst, const SrcXprType &src, const InitialFunc& /*func*/)
{
call_assignment_no_alias(dst, src.lhs(), Func1());
@@ -246,19 +246,19 @@
struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,InnerProduct>
{
template<typename Dst>
- static EIGEN_STRONG_INLINE void evalTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
{
dst.coeffRef(0,0) = (lhs.transpose().cwiseProduct(rhs)).sum();
}
-
+
template<typename Dst>
- static EIGEN_STRONG_INLINE void addTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void addTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
{
dst.coeffRef(0,0) += (lhs.transpose().cwiseProduct(rhs)).sum();
}
-
+
template<typename Dst>
- static EIGEN_STRONG_INLINE void subTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void subTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
{ dst.coeffRef(0,0) -= (lhs.transpose().cwiseProduct(rhs)).sum(); }
};
@@ -269,10 +269,10 @@
// Column major result
template<typename Dst, typename Lhs, typename Rhs, typename Func>
-void outer_product_selector_run(Dst& dst, const Lhs &lhs, const Rhs &rhs, const Func& func, const false_type&)
+void EIGEN_DEVICE_FUNC outer_product_selector_run(Dst& dst, const Lhs &lhs, const Rhs &rhs, const Func& func, const false_type&)
{
evaluator<Rhs> rhsEval(rhs);
- typename nested_eval<Lhs,Rhs::SizeAtCompileTime>::type actual_lhs(lhs);
+ ei_declare_local_nested_eval(Lhs,lhs,Rhs::SizeAtCompileTime,actual_lhs);
// FIXME if cols is large enough, then it might be useful to make sure that lhs is sequentially stored
// FIXME not very good if rhs is real and lhs complex while alpha is real too
const Index cols = dst.cols();
@@ -282,10 +282,10 @@
// Row major result
template<typename Dst, typename Lhs, typename Rhs, typename Func>
-void outer_product_selector_run(Dst& dst, const Lhs &lhs, const Rhs &rhs, const Func& func, const true_type&)
+void EIGEN_DEVICE_FUNC outer_product_selector_run(Dst& dst, const Lhs &lhs, const Rhs &rhs, const Func& func, const true_type&)
{
evaluator<Lhs> lhsEval(lhs);
- typename nested_eval<Rhs,Lhs::SizeAtCompileTime>::type actual_rhs(rhs);
+ ei_declare_local_nested_eval(Rhs,rhs,Lhs::SizeAtCompileTime,actual_rhs);
// FIXME if rows is large enough, then it might be useful to make sure that rhs is sequentially stored
// FIXME not very good if lhs is real and rhs complex while alpha is real too
const Index rows = dst.rows();
@@ -298,43 +298,43 @@
{
template<typename T> struct is_row_major : internal::conditional<(int(T::Flags)&RowMajorBit), internal::true_type, internal::false_type>::type {};
typedef typename Product<Lhs,Rhs>::Scalar Scalar;
-
+
// TODO it would be nice to be able to exploit our *_assign_op functors for that purpose
- struct set { template<typename Dst, typename Src> void operator()(const Dst& dst, const Src& src) const { dst.const_cast_derived() = src; } };
- struct add { template<typename Dst, typename Src> void operator()(const Dst& dst, const Src& src) const { dst.const_cast_derived() += src; } };
- struct sub { template<typename Dst, typename Src> void operator()(const Dst& dst, const Src& src) const { dst.const_cast_derived() -= src; } };
+ struct set { template<typename Dst, typename Src> EIGEN_DEVICE_FUNC void operator()(const Dst& dst, const Src& src) const { dst.const_cast_derived() = src; } };
+ struct add { template<typename Dst, typename Src> EIGEN_DEVICE_FUNC void operator()(const Dst& dst, const Src& src) const { dst.const_cast_derived() += src; } };
+ struct sub { template<typename Dst, typename Src> EIGEN_DEVICE_FUNC void operator()(const Dst& dst, const Src& src) const { dst.const_cast_derived() -= src; } };
struct adds {
Scalar m_scale;
explicit adds(const Scalar& s) : m_scale(s) {}
- template<typename Dst, typename Src> void operator()(const Dst& dst, const Src& src) const {
+ template<typename Dst, typename Src> void EIGEN_DEVICE_FUNC operator()(const Dst& dst, const Src& src) const {
dst.const_cast_derived() += m_scale * src;
}
};
-
+
template<typename Dst>
- static EIGEN_STRONG_INLINE void evalTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
{
internal::outer_product_selector_run(dst, lhs, rhs, set(), is_row_major<Dst>());
}
-
+
template<typename Dst>
- static EIGEN_STRONG_INLINE void addTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void addTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
{
internal::outer_product_selector_run(dst, lhs, rhs, add(), is_row_major<Dst>());
}
-
+
template<typename Dst>
- static EIGEN_STRONG_INLINE void subTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void subTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
{
internal::outer_product_selector_run(dst, lhs, rhs, sub(), is_row_major<Dst>());
}
-
+
template<typename Dst>
- static EIGEN_STRONG_INLINE void scaleAndAddTo(Dst& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha)
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void scaleAndAddTo(Dst& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha)
{
internal::outer_product_selector_run(dst, lhs, rhs, adds(alpha), is_row_major<Dst>());
}
-
+
};
@@ -343,21 +343,21 @@
struct generic_product_impl_base
{
typedef typename Product<Lhs,Rhs>::Scalar Scalar;
-
+
template<typename Dst>
- static EIGEN_STRONG_INLINE void evalTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
{ dst.setZero(); scaleAndAddTo(dst, lhs, rhs, Scalar(1)); }
template<typename Dst>
- static EIGEN_STRONG_INLINE void addTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void addTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
{ scaleAndAddTo(dst,lhs, rhs, Scalar(1)); }
template<typename Dst>
- static EIGEN_STRONG_INLINE void subTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void subTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
{ scaleAndAddTo(dst, lhs, rhs, Scalar(-1)); }
-
+
template<typename Dst>
- static EIGEN_STRONG_INLINE void scaleAndAddTo(Dst& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha)
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void scaleAndAddTo(Dst& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha)
{ Derived::scaleAndAddTo(dst,lhs,rhs,alpha); }
};
@@ -373,8 +373,13 @@
typedef typename internal::remove_all<typename internal::conditional<int(Side)==OnTheRight,LhsNested,RhsNested>::type>::type MatrixType;
template<typename Dest>
- static EIGEN_STRONG_INLINE void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha)
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha)
{
+ // Fallback to inner product if both the lhs and rhs are runtime vectors.
+ if (lhs.rows() == 1 && rhs.cols() == 1) {
+ dst.coeffRef(0,0) += alpha * lhs.row(0).conjugate().dot(rhs.col(0));
+ return;
+ }
LhsNested actual_lhs(lhs);
RhsNested actual_rhs(rhs);
internal::gemv_dense_selector<Side,
@@ -385,35 +390,84 @@
};
template<typename Lhs, typename Rhs>
-struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,CoeffBasedProductMode>
+struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,CoeffBasedProductMode>
{
typedef typename Product<Lhs,Rhs>::Scalar Scalar;
-
+
template<typename Dst>
- static EIGEN_STRONG_INLINE void evalTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
{
// Same as: dst.noalias() = lhs.lazyProduct(rhs);
// but easier on the compiler side
call_assignment_no_alias(dst, lhs.lazyProduct(rhs), internal::assign_op<typename Dst::Scalar,Scalar>());
}
-
+
template<typename Dst>
- static EIGEN_STRONG_INLINE void addTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void addTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
{
// dst.noalias() += lhs.lazyProduct(rhs);
call_assignment_no_alias(dst, lhs.lazyProduct(rhs), internal::add_assign_op<typename Dst::Scalar,Scalar>());
}
-
+
template<typename Dst>
- static EIGEN_STRONG_INLINE void subTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void subTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
{
// dst.noalias() -= lhs.lazyProduct(rhs);
call_assignment_no_alias(dst, lhs.lazyProduct(rhs), internal::sub_assign_op<typename Dst::Scalar,Scalar>());
}
-
-// template<typename Dst>
-// static inline void scaleAndAddTo(Dst& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha)
-// { dst.noalias() += alpha * lhs.lazyProduct(rhs); }
+
+ // This is a special evaluation path called from generic_product_impl<...,GemmProduct> in file GeneralMatrixMatrix.h
+ // This variant tries to extract scalar multiples from both the LHS and RHS and factor them out. For instance:
+ // dst {,+,-}= (s1*A)*(B*s2)
+ // will be rewritten as:
+ // dst {,+,-}= (s1*s2) * (A.lazyProduct(B))
+ // There are at least four benefits of doing so:
+ // 1 - huge performance gain for heap-allocated matrix types as it saves costly allocations.
+ // 2 - it is faster than simply by-passing the heap allocation through stack allocation.
+ // 3 - it makes this fallback consistent with the heavy GEMM routine.
+ // 4 - it fully by-passes huge stack allocation attempts when multiplying huge fixed-size matrices.
+ // (see https://stackoverflow.com/questions/54738495)
+ // For small fixed-size matrices, however, the gains are less obvious, it is sometimes x2 faster, but sometimes x3 slower,
+ // and the behavior depends also a lot on the compiler... This is why this rewriting strategy is currently
+ // enabled only when falling back from the main GEMM.
+ template<typename Dst, typename Func>
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ void eval_dynamic(Dst& dst, const Lhs& lhs, const Rhs& rhs, const Func &func)
+ {
+ enum {
+ HasScalarFactor = blas_traits<Lhs>::HasScalarFactor || blas_traits<Rhs>::HasScalarFactor,
+ ConjLhs = blas_traits<Lhs>::NeedToConjugate,
+ ConjRhs = blas_traits<Rhs>::NeedToConjugate
+ };
+ // FIXME: in c++11 this should be auto, and extractScalarFactor should also return auto
+ // this is important for real*complex_mat
+ Scalar actualAlpha = combine_scalar_factors<Scalar>(lhs, rhs);
+
+ eval_dynamic_impl(dst,
+ blas_traits<Lhs>::extract(lhs).template conjugateIf<ConjLhs>(),
+ blas_traits<Rhs>::extract(rhs).template conjugateIf<ConjRhs>(),
+ func,
+ actualAlpha,
+ typename conditional<HasScalarFactor,true_type,false_type>::type());
+ }
+
+protected:
+
+ template<typename Dst, typename LhsT, typename RhsT, typename Func, typename Scalar>
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ void eval_dynamic_impl(Dst& dst, const LhsT& lhs, const RhsT& rhs, const Func &func, const Scalar& s /* == 1 */, false_type)
+ {
+ EIGEN_UNUSED_VARIABLE(s);
+ eigen_internal_assert(s==Scalar(1));
+ call_restricted_packet_assignment_no_alias(dst, lhs.lazyProduct(rhs), func);
+ }
+
+ template<typename Dst, typename LhsT, typename RhsT, typename Func, typename Scalar>
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ void eval_dynamic_impl(Dst& dst, const LhsT& lhs, const RhsT& rhs, const Func &func, const Scalar& s, true_type)
+ {
+ call_restricted_packet_assignment_no_alias(dst, s * lhs.lazyProduct(rhs), func);
+ }
};
// This specialization enforces the use of a coefficient-based evaluation strategy
@@ -471,7 +525,7 @@
typedef typename internal::nested_eval<Lhs,Rhs::ColsAtCompileTime>::type LhsNested;
typedef typename internal::nested_eval<Rhs,Lhs::RowsAtCompileTime>::type RhsNested;
-
+
typedef typename internal::remove_all<LhsNested>::type LhsNestedCleaned;
typedef typename internal::remove_all<RhsNested>::type RhsNestedCleaned;
@@ -490,19 +544,19 @@
typedef typename find_best_packet<Scalar,ColsAtCompileTime>::type RhsVecPacketType;
enum {
-
+
LhsCoeffReadCost = LhsEtorType::CoeffReadCost,
RhsCoeffReadCost = RhsEtorType::CoeffReadCost,
CoeffReadCost = InnerSize==0 ? NumTraits<Scalar>::ReadCost
: InnerSize == Dynamic ? HugeCost
- : InnerSize * (NumTraits<Scalar>::MulCost + LhsCoeffReadCost + RhsCoeffReadCost)
+ : InnerSize * (NumTraits<Scalar>::MulCost + int(LhsCoeffReadCost) + int(RhsCoeffReadCost))
+ (InnerSize - 1) * NumTraits<Scalar>::AddCost,
Unroll = CoeffReadCost <= EIGEN_UNROLLING_LIMIT,
-
+
LhsFlags = LhsEtorType::Flags,
RhsFlags = RhsEtorType::Flags,
-
+
LhsRowMajor = LhsFlags & RowMajorBit,
RhsRowMajor = RhsFlags & RowMajorBit,
@@ -512,7 +566,7 @@
// Here, we don't care about alignment larger than the usable packet size.
LhsAlignment = EIGEN_PLAIN_ENUM_MIN(LhsEtorType::Alignment,LhsVecPacketSize*int(sizeof(typename LhsNestedCleaned::Scalar))),
RhsAlignment = EIGEN_PLAIN_ENUM_MIN(RhsEtorType::Alignment,RhsVecPacketSize*int(sizeof(typename RhsNestedCleaned::Scalar))),
-
+
SameType = is_same<typename LhsNestedCleaned::Scalar,typename RhsNestedCleaned::Scalar>::value,
CanVectorizeRhs = bool(RhsRowMajor) && (RhsFlags & PacketAccessBit) && (ColsAtCompileTime!=1),
@@ -522,12 +576,12 @@
: (MaxColsAtCompileTime==1&&MaxRowsAtCompileTime!=1) ? 0
: (bool(RhsRowMajor) && !CanVectorizeLhs),
- Flags = ((unsigned int)(LhsFlags | RhsFlags) & HereditaryBits & ~RowMajorBit)
+ Flags = ((int(LhsFlags) | int(RhsFlags)) & HereditaryBits & ~RowMajorBit)
| (EvalToRowMajor ? RowMajorBit : 0)
// TODO enable vectorization for mixed types
| (SameType && (CanVectorizeLhs || CanVectorizeRhs) ? PacketAccessBit : 0)
| (XprType::IsVectorAtCompileTime ? LinearAccessBit : 0),
-
+
LhsOuterStrideBytes = int(LhsNestedCleaned::OuterStrideAtCompileTime) * int(sizeof(typename LhsNestedCleaned::Scalar)),
RhsOuterStrideBytes = int(RhsNestedCleaned::OuterStrideAtCompileTime) * int(sizeof(typename RhsNestedCleaned::Scalar)),
@@ -543,10 +597,10 @@
CanVectorizeInner = SameType
&& LhsRowMajor
&& (!RhsRowMajor)
- && (LhsFlags & RhsFlags & ActualPacketAccessBit)
- && (InnerSize % packet_traits<Scalar>::size == 0)
+ && (int(LhsFlags) & int(RhsFlags) & ActualPacketAccessBit)
+ && (int(InnerSize) % packet_traits<Scalar>::size == 0)
};
-
+
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const CoeffReturnType coeff(Index row, Index col) const
{
return (m_lhs.row(row).transpose().cwiseProduct( m_rhs.col(col) )).sum();
@@ -556,7 +610,8 @@
* which is why we don't set the LinearAccessBit.
* TODO: this seems possible when the result is a vector
*/
- EIGEN_DEVICE_FUNC const CoeffReturnType coeff(Index index) const
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ const CoeffReturnType coeff(Index index) const
{
const Index row = (RowsAtCompileTime == 1 || MaxRowsAtCompileTime==1) ? 0 : index;
const Index col = (RowsAtCompileTime == 1 || MaxRowsAtCompileTime==1) ? index : 0;
@@ -564,6 +619,7 @@
}
template<int LoadMode, typename PacketType>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const PacketType packet(Index row, Index col) const
{
PacketType res;
@@ -575,6 +631,7 @@
}
template<int LoadMode, typename PacketType>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const PacketType packet(Index index) const
{
const Index row = (RowsAtCompileTime == 1 || MaxRowsAtCompileTime==1) ? 0 : index;
@@ -585,7 +642,7 @@
protected:
typename internal::add_const_on_value_type<LhsNested>::type m_lhs;
typename internal::add_const_on_value_type<RhsNested>::type m_rhs;
-
+
LhsEtorType m_lhsImpl;
RhsEtorType m_rhsImpl;
@@ -603,7 +660,8 @@
enum {
Flags = Base::Flags | EvalBeforeNestingBit
};
- EIGEN_DEVICE_FUNC explicit product_evaluator(const XprType& xpr)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ explicit product_evaluator(const XprType& xpr)
: Base(BaseProduct(xpr.lhs(),xpr.rhs()))
{}
};
@@ -615,7 +673,7 @@
template<int UnrollingIndex, typename Lhs, typename Rhs, typename Packet, int LoadMode>
struct etor_product_packet_impl<RowMajor, UnrollingIndex, Lhs, Rhs, Packet, LoadMode>
{
- static EIGEN_STRONG_INLINE void run(Index row, Index col, const Lhs& lhs, const Rhs& rhs, Index innerDim, Packet &res)
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(Index row, Index col, const Lhs& lhs, const Rhs& rhs, Index innerDim, Packet &res)
{
etor_product_packet_impl<RowMajor, UnrollingIndex-1, Lhs, Rhs, Packet, LoadMode>::run(row, col, lhs, rhs, innerDim, res);
res = pmadd(pset1<Packet>(lhs.coeff(row, Index(UnrollingIndex-1))), rhs.template packet<LoadMode,Packet>(Index(UnrollingIndex-1), col), res);
@@ -625,7 +683,7 @@
template<int UnrollingIndex, typename Lhs, typename Rhs, typename Packet, int LoadMode>
struct etor_product_packet_impl<ColMajor, UnrollingIndex, Lhs, Rhs, Packet, LoadMode>
{
- static EIGEN_STRONG_INLINE void run(Index row, Index col, const Lhs& lhs, const Rhs& rhs, Index innerDim, Packet &res)
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(Index row, Index col, const Lhs& lhs, const Rhs& rhs, Index innerDim, Packet &res)
{
etor_product_packet_impl<ColMajor, UnrollingIndex-1, Lhs, Rhs, Packet, LoadMode>::run(row, col, lhs, rhs, innerDim, res);
res = pmadd(lhs.template packet<LoadMode,Packet>(row, Index(UnrollingIndex-1)), pset1<Packet>(rhs.coeff(Index(UnrollingIndex-1), col)), res);
@@ -635,7 +693,7 @@
template<typename Lhs, typename Rhs, typename Packet, int LoadMode>
struct etor_product_packet_impl<RowMajor, 1, Lhs, Rhs, Packet, LoadMode>
{
- static EIGEN_STRONG_INLINE void run(Index row, Index col, const Lhs& lhs, const Rhs& rhs, Index /*innerDim*/, Packet &res)
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(Index row, Index col, const Lhs& lhs, const Rhs& rhs, Index /*innerDim*/, Packet &res)
{
res = pmul(pset1<Packet>(lhs.coeff(row, Index(0))),rhs.template packet<LoadMode,Packet>(Index(0), col));
}
@@ -644,7 +702,7 @@
template<typename Lhs, typename Rhs, typename Packet, int LoadMode>
struct etor_product_packet_impl<ColMajor, 1, Lhs, Rhs, Packet, LoadMode>
{
- static EIGEN_STRONG_INLINE void run(Index row, Index col, const Lhs& lhs, const Rhs& rhs, Index /*innerDim*/, Packet &res)
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(Index row, Index col, const Lhs& lhs, const Rhs& rhs, Index /*innerDim*/, Packet &res)
{
res = pmul(lhs.template packet<LoadMode,Packet>(row, Index(0)), pset1<Packet>(rhs.coeff(Index(0), col)));
}
@@ -653,7 +711,7 @@
template<typename Lhs, typename Rhs, typename Packet, int LoadMode>
struct etor_product_packet_impl<RowMajor, 0, Lhs, Rhs, Packet, LoadMode>
{
- static EIGEN_STRONG_INLINE void run(Index /*row*/, Index /*col*/, const Lhs& /*lhs*/, const Rhs& /*rhs*/, Index /*innerDim*/, Packet &res)
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(Index /*row*/, Index /*col*/, const Lhs& /*lhs*/, const Rhs& /*rhs*/, Index /*innerDim*/, Packet &res)
{
res = pset1<Packet>(typename unpacket_traits<Packet>::type(0));
}
@@ -662,7 +720,7 @@
template<typename Lhs, typename Rhs, typename Packet, int LoadMode>
struct etor_product_packet_impl<ColMajor, 0, Lhs, Rhs, Packet, LoadMode>
{
- static EIGEN_STRONG_INLINE void run(Index /*row*/, Index /*col*/, const Lhs& /*lhs*/, const Rhs& /*rhs*/, Index /*innerDim*/, Packet &res)
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(Index /*row*/, Index /*col*/, const Lhs& /*lhs*/, const Rhs& /*rhs*/, Index /*innerDim*/, Packet &res)
{
res = pset1<Packet>(typename unpacket_traits<Packet>::type(0));
}
@@ -671,7 +729,7 @@
template<typename Lhs, typename Rhs, typename Packet, int LoadMode>
struct etor_product_packet_impl<RowMajor, Dynamic, Lhs, Rhs, Packet, LoadMode>
{
- static EIGEN_STRONG_INLINE void run(Index row, Index col, const Lhs& lhs, const Rhs& rhs, Index innerDim, Packet& res)
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(Index row, Index col, const Lhs& lhs, const Rhs& rhs, Index innerDim, Packet& res)
{
res = pset1<Packet>(typename unpacket_traits<Packet>::type(0));
for(Index i = 0; i < innerDim; ++i)
@@ -682,7 +740,7 @@
template<typename Lhs, typename Rhs, typename Packet, int LoadMode>
struct etor_product_packet_impl<ColMajor, Dynamic, Lhs, Rhs, Packet, LoadMode>
{
- static EIGEN_STRONG_INLINE void run(Index row, Index col, const Lhs& lhs, const Rhs& rhs, Index innerDim, Packet& res)
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(Index row, Index col, const Lhs& lhs, const Rhs& rhs, Index innerDim, Packet& res)
{
res = pset1<Packet>(typename unpacket_traits<Packet>::type(0));
for(Index i = 0; i < innerDim; ++i)
@@ -704,7 +762,7 @@
: generic_product_impl_base<Lhs,Rhs,generic_product_impl<Lhs,Rhs,TriangularShape,DenseShape,ProductTag> >
{
typedef typename Product<Lhs,Rhs>::Scalar Scalar;
-
+
template<typename Dest>
static void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha)
{
@@ -718,7 +776,7 @@
: generic_product_impl_base<Lhs,Rhs,generic_product_impl<Lhs,Rhs,DenseShape,TriangularShape,ProductTag> >
{
typedef typename Product<Lhs,Rhs>::Scalar Scalar;
-
+
template<typename Dest>
static void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha)
{
@@ -739,9 +797,10 @@
: generic_product_impl_base<Lhs,Rhs,generic_product_impl<Lhs,Rhs,SelfAdjointShape,DenseShape,ProductTag> >
{
typedef typename Product<Lhs,Rhs>::Scalar Scalar;
-
+
template<typename Dest>
- static void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha)
+ static EIGEN_DEVICE_FUNC
+ void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha)
{
selfadjoint_product_impl<typename Lhs::MatrixType,Lhs::Mode,false,Rhs,0,Rhs::IsVectorAtCompileTime>::run(dst, lhs.nestedExpression(), rhs, alpha);
}
@@ -752,7 +811,7 @@
: generic_product_impl_base<Lhs,Rhs,generic_product_impl<Lhs,Rhs,DenseShape,SelfAdjointShape,ProductTag> >
{
typedef typename Product<Lhs,Rhs>::Scalar Scalar;
-
+
template<typename Dest>
static void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha)
{
@@ -764,7 +823,7 @@
/***************************************************************************
* Diagonal products
***************************************************************************/
-
+
template<typename MatrixType, typename DiagonalType, typename Derived, int ProductOrder>
struct diagonal_product_evaluator_base
: evaluator_base<Derived>
@@ -772,17 +831,25 @@
typedef typename ScalarBinaryOpTraits<typename MatrixType::Scalar, typename DiagonalType::Scalar>::ReturnType Scalar;
public:
enum {
- CoeffReadCost = NumTraits<Scalar>::MulCost + evaluator<MatrixType>::CoeffReadCost + evaluator<DiagonalType>::CoeffReadCost,
-
+ CoeffReadCost = int(NumTraits<Scalar>::MulCost) + int(evaluator<MatrixType>::CoeffReadCost) + int(evaluator<DiagonalType>::CoeffReadCost),
+
MatrixFlags = evaluator<MatrixType>::Flags,
DiagFlags = evaluator<DiagonalType>::Flags,
- _StorageOrder = MatrixFlags & RowMajorBit ? RowMajor : ColMajor,
+
+ _StorageOrder = (Derived::MaxRowsAtCompileTime==1 && Derived::MaxColsAtCompileTime!=1) ? RowMajor
+ : (Derived::MaxColsAtCompileTime==1 && Derived::MaxRowsAtCompileTime!=1) ? ColMajor
+ : MatrixFlags & RowMajorBit ? RowMajor : ColMajor,
+ _SameStorageOrder = _StorageOrder == (MatrixFlags & RowMajorBit ? RowMajor : ColMajor),
+
_ScalarAccessOnDiag = !((int(_StorageOrder) == ColMajor && int(ProductOrder) == OnTheLeft)
||(int(_StorageOrder) == RowMajor && int(ProductOrder) == OnTheRight)),
_SameTypes = is_same<typename MatrixType::Scalar, typename DiagonalType::Scalar>::value,
// FIXME currently we need same types, but in the future the next rule should be the one
//_Vectorizable = bool(int(MatrixFlags)&PacketAccessBit) && ((!_PacketOnDiag) || (_SameTypes && bool(int(DiagFlags)&PacketAccessBit))),
- _Vectorizable = bool(int(MatrixFlags)&PacketAccessBit) && _SameTypes && (_ScalarAccessOnDiag || (bool(int(DiagFlags)&PacketAccessBit))),
+ _Vectorizable = bool(int(MatrixFlags)&PacketAccessBit)
+ && _SameTypes
+ && (_SameStorageOrder || (MatrixFlags&LinearAccessBit)==LinearAccessBit)
+ && (_ScalarAccessOnDiag || (bool(int(DiagFlags)&PacketAccessBit))),
_LinearAccessMask = (MatrixType::RowsAtCompileTime==1 || MatrixType::ColsAtCompileTime==1) ? LinearAccessBit : 0,
Flags = ((HereditaryBits|_LinearAccessMask) & (unsigned int)(MatrixFlags)) | (_Vectorizable ? PacketAccessBit : 0),
Alignment = evaluator<MatrixType>::Alignment,
@@ -791,14 +858,14 @@
|| (DiagonalType::SizeAtCompileTime==Dynamic && MatrixType::RowsAtCompileTime==1 && ProductOrder==OnTheLeft)
|| (DiagonalType::SizeAtCompileTime==Dynamic && MatrixType::ColsAtCompileTime==1 && ProductOrder==OnTheRight)
};
-
- diagonal_product_evaluator_base(const MatrixType &mat, const DiagonalType &diag)
+
+ EIGEN_DEVICE_FUNC diagonal_product_evaluator_base(const MatrixType &mat, const DiagonalType &diag)
: m_diagImpl(diag), m_matImpl(mat)
{
EIGEN_INTERNAL_CHECK_COST_VALUE(NumTraits<Scalar>::MulCost);
EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
}
-
+
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar coeff(Index idx) const
{
if(AsScalarProduct)
@@ -806,7 +873,7 @@
else
return m_diagImpl.coeff(idx) * m_matImpl.coeff(idx);
}
-
+
protected:
template<int LoadMode,typename PacketType>
EIGEN_STRONG_INLINE PacketType packet_impl(Index row, Index col, Index id, internal::true_type) const
@@ -814,7 +881,7 @@
return internal::pmul(m_matImpl.template packet<LoadMode,PacketType>(row, col),
internal::pset1<PacketType>(m_diagImpl.coeff(id)));
}
-
+
template<int LoadMode,typename PacketType>
EIGEN_STRONG_INLINE PacketType packet_impl(Index row, Index col, Index id, internal::false_type) const
{
@@ -825,7 +892,7 @@
return internal::pmul(m_matImpl.template packet<LoadMode,PacketType>(row, col),
m_diagImpl.template packet<DiagonalPacketLoadMode,PacketType>(id));
}
-
+
evaluator<DiagonalType> m_diagImpl;
evaluator<MatrixType> m_matImpl;
};
@@ -840,25 +907,25 @@
using Base::m_matImpl;
using Base::coeff;
typedef typename Base::Scalar Scalar;
-
+
typedef Product<Lhs, Rhs, ProductKind> XprType;
typedef typename XprType::PlainObject PlainObject;
-
- enum {
- StorageOrder = int(Rhs::Flags) & RowMajorBit ? RowMajor : ColMajor
- };
+ typedef typename Lhs::DiagonalVectorType DiagonalType;
+
+
+ enum { StorageOrder = Base::_StorageOrder };
EIGEN_DEVICE_FUNC explicit product_evaluator(const XprType& xpr)
: Base(xpr.rhs(), xpr.lhs().diagonal())
{
}
-
+
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar coeff(Index row, Index col) const
{
return m_diagImpl.coeff(row) * m_matImpl.coeff(row, col);
}
-
-#ifndef __CUDACC__
+
+#ifndef EIGEN_GPUCC
template<int LoadMode,typename PacketType>
EIGEN_STRONG_INLINE PacketType packet(Index row, Index col) const
{
@@ -867,7 +934,7 @@
return this->template packet_impl<LoadMode,PacketType>(row,col, row,
typename internal::conditional<int(StorageOrder)==RowMajor, internal::true_type, internal::false_type>::type());
}
-
+
template<int LoadMode,typename PacketType>
EIGEN_STRONG_INLINE PacketType packet(Index idx) const
{
@@ -886,30 +953,30 @@
using Base::m_matImpl;
using Base::coeff;
typedef typename Base::Scalar Scalar;
-
+
typedef Product<Lhs, Rhs, ProductKind> XprType;
typedef typename XprType::PlainObject PlainObject;
-
- enum { StorageOrder = int(Lhs::Flags) & RowMajorBit ? RowMajor : ColMajor };
+
+ enum { StorageOrder = Base::_StorageOrder };
EIGEN_DEVICE_FUNC explicit product_evaluator(const XprType& xpr)
: Base(xpr.lhs(), xpr.rhs().diagonal())
{
}
-
+
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar coeff(Index row, Index col) const
{
return m_matImpl.coeff(row, col) * m_diagImpl.coeff(col);
}
-
-#ifndef __CUDACC__
+
+#ifndef EIGEN_GPUCC
template<int LoadMode,typename PacketType>
EIGEN_STRONG_INLINE PacketType packet(Index row, Index col) const
{
return this->template packet_impl<LoadMode,PacketType>(row,col, col,
typename internal::conditional<int(StorageOrder)==ColMajor, internal::true_type, internal::false_type>::type());
}
-
+
template<int LoadMode,typename PacketType>
EIGEN_STRONG_INLINE PacketType packet(Index idx) const
{
@@ -937,7 +1004,7 @@
typedef typename remove_all<MatrixType>::type MatrixTypeCleaned;
template<typename Dest, typename PermutationType>
- static inline void run(Dest& dst, const PermutationType& perm, const ExpressionType& xpr)
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(Dest& dst, const PermutationType& perm, const ExpressionType& xpr)
{
MatrixType mat(xpr);
const Index n = Side==OnTheLeft ? mat.rows() : mat.cols();
@@ -991,7 +1058,7 @@
struct generic_product_impl<Lhs, Rhs, PermutationShape, MatrixShape, ProductTag>
{
template<typename Dest>
- static void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs)
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs)
{
permutation_matrix_product<Rhs, OnTheLeft, false, MatrixShape>::run(dst, lhs, rhs);
}
@@ -1001,7 +1068,7 @@
struct generic_product_impl<Lhs, Rhs, MatrixShape, PermutationShape, ProductTag>
{
template<typename Dest>
- static void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs)
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs)
{
permutation_matrix_product<Lhs, OnTheRight, false, MatrixShape>::run(dst, rhs, lhs);
}
@@ -1011,7 +1078,7 @@
struct generic_product_impl<Inverse<Lhs>, Rhs, PermutationShape, MatrixShape, ProductTag>
{
template<typename Dest>
- static void evalTo(Dest& dst, const Inverse<Lhs>& lhs, const Rhs& rhs)
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalTo(Dest& dst, const Inverse<Lhs>& lhs, const Rhs& rhs)
{
permutation_matrix_product<Rhs, OnTheLeft, true, MatrixShape>::run(dst, lhs.nestedExpression(), rhs);
}
@@ -1021,7 +1088,7 @@
struct generic_product_impl<Lhs, Inverse<Rhs>, MatrixShape, PermutationShape, ProductTag>
{
template<typename Dest>
- static void evalTo(Dest& dst, const Lhs& lhs, const Inverse<Rhs>& rhs)
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalTo(Dest& dst, const Lhs& lhs, const Inverse<Rhs>& rhs)
{
permutation_matrix_product<Lhs, OnTheRight, true, MatrixShape>::run(dst, rhs.nestedExpression(), lhs);
}
@@ -1043,9 +1110,9 @@
{
typedef typename nested_eval<ExpressionType, 1>::type MatrixType;
typedef typename remove_all<MatrixType>::type MatrixTypeCleaned;
-
+
template<typename Dest, typename TranspositionType>
- static inline void run(Dest& dst, const TranspositionType& tr, const ExpressionType& xpr)
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(Dest& dst, const TranspositionType& tr, const ExpressionType& xpr)
{
MatrixType mat(xpr);
typedef typename TranspositionType::StorageIndex StorageIndex;
@@ -1068,7 +1135,7 @@
struct generic_product_impl<Lhs, Rhs, TranspositionsShape, MatrixShape, ProductTag>
{
template<typename Dest>
- static void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs)
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs)
{
transposition_matrix_product<Rhs, OnTheLeft, false, MatrixShape>::run(dst, lhs, rhs);
}
@@ -1078,7 +1145,7 @@
struct generic_product_impl<Lhs, Rhs, MatrixShape, TranspositionsShape, ProductTag>
{
template<typename Dest>
- static void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs)
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs)
{
transposition_matrix_product<Lhs, OnTheRight, false, MatrixShape>::run(dst, rhs, lhs);
}
@@ -1089,7 +1156,7 @@
struct generic_product_impl<Transpose<Lhs>, Rhs, TranspositionsShape, MatrixShape, ProductTag>
{
template<typename Dest>
- static void evalTo(Dest& dst, const Transpose<Lhs>& lhs, const Rhs& rhs)
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalTo(Dest& dst, const Transpose<Lhs>& lhs, const Rhs& rhs)
{
transposition_matrix_product<Rhs, OnTheLeft, true, MatrixShape>::run(dst, lhs.nestedExpression(), rhs);
}
@@ -1099,7 +1166,7 @@
struct generic_product_impl<Lhs, Transpose<Rhs>, MatrixShape, TranspositionsShape, ProductTag>
{
template<typename Dest>
- static void evalTo(Dest& dst, const Lhs& lhs, const Transpose<Rhs>& rhs)
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalTo(Dest& dst, const Lhs& lhs, const Transpose<Rhs>& rhs)
{
transposition_matrix_product<Lhs, OnTheRight, true, MatrixShape>::run(dst, rhs.nestedExpression(), lhs);
}
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Random.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Random.h
index 6faf789..dab2ac8 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Random.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Random.h
@@ -128,7 +128,7 @@
* \sa class CwiseNullaryOp, setRandom(Index), setRandom(Index,Index)
*/
template<typename Derived>
-inline Derived& DenseBase<Derived>::setRandom()
+EIGEN_DEVICE_FUNC inline Derived& DenseBase<Derived>::setRandom()
{
return *this = Random(rows(), cols());
}
@@ -177,6 +177,42 @@
return setRandom();
}
+/** Resizes to the given size, changing only the number of columns, and sets all
+ * coefficients in this expression to random values. For the parameter of type
+ * NoChange_t, just pass the special value \c NoChange.
+ *
+ * Numbers are uniformly spread through their whole definition range for integer types,
+ * and in the [-1:1] range for floating point scalar types.
+ *
+ * \not_reentrant
+ *
+ * \sa DenseBase::setRandom(), setRandom(Index), setRandom(Index, NoChange_t), class CwiseNullaryOp, DenseBase::Random()
+ */
+template<typename Derived>
+EIGEN_STRONG_INLINE Derived&
+PlainObjectBase<Derived>::setRandom(NoChange_t, Index cols)
+{
+ return setRandom(rows(), cols);
+}
+
+/** Resizes to the given size, changing only the number of rows, and sets all
+ * coefficients in this expression to random values. For the parameter of type
+ * NoChange_t, just pass the special value \c NoChange.
+ *
+ * Numbers are uniformly spread through their whole definition range for integer types,
+ * and in the [-1:1] range for floating point scalar types.
+ *
+ * \not_reentrant
+ *
+ * \sa DenseBase::setRandom(), setRandom(Index), setRandom(NoChange_t, Index), class CwiseNullaryOp, DenseBase::Random()
+ */
+template<typename Derived>
+EIGEN_STRONG_INLINE Derived&
+PlainObjectBase<Derived>::setRandom(Index rows, NoChange_t)
+{
+ return setRandom(rows, cols());
+}
+
} // end namespace Eigen
#endif // EIGEN_RANDOM_H
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Redux.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Redux.h
index 760e9f8..b6790d1 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Redux.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Redux.h
@@ -23,23 +23,29 @@
* Part 1 : the logic deciding a strategy for vectorization and unrolling
***************************************************************************/
-template<typename Func, typename Derived>
+template<typename Func, typename Evaluator>
struct redux_traits
{
public:
- typedef typename find_best_packet<typename Derived::Scalar,Derived::SizeAtCompileTime>::type PacketType;
+ typedef typename find_best_packet<typename Evaluator::Scalar,Evaluator::SizeAtCompileTime>::type PacketType;
enum {
PacketSize = unpacket_traits<PacketType>::size,
- InnerMaxSize = int(Derived::IsRowMajor)
- ? Derived::MaxColsAtCompileTime
- : Derived::MaxRowsAtCompileTime
+ InnerMaxSize = int(Evaluator::IsRowMajor)
+ ? Evaluator::MaxColsAtCompileTime
+ : Evaluator::MaxRowsAtCompileTime,
+ OuterMaxSize = int(Evaluator::IsRowMajor)
+ ? Evaluator::MaxRowsAtCompileTime
+ : Evaluator::MaxColsAtCompileTime,
+ SliceVectorizedWork = int(InnerMaxSize)==Dynamic ? Dynamic
+ : int(OuterMaxSize)==Dynamic ? (int(InnerMaxSize)>=int(PacketSize) ? Dynamic : 0)
+ : (int(InnerMaxSize)/int(PacketSize)) * int(OuterMaxSize)
};
enum {
- MightVectorize = (int(Derived::Flags)&ActualPacketAccessBit)
+ MightVectorize = (int(Evaluator::Flags)&ActualPacketAccessBit)
&& (functor_traits<Func>::PacketAccess),
- MayLinearVectorize = bool(MightVectorize) && (int(Derived::Flags)&LinearAccessBit),
- MaySliceVectorize = bool(MightVectorize) && int(InnerMaxSize)>=3*PacketSize
+ MayLinearVectorize = bool(MightVectorize) && (int(Evaluator::Flags)&LinearAccessBit),
+ MaySliceVectorize = bool(MightVectorize) && (int(SliceVectorizedWork)==Dynamic || int(SliceVectorizedWork)>=3)
};
public:
@@ -51,8 +57,8 @@
public:
enum {
- Cost = Derived::SizeAtCompileTime == Dynamic ? HugeCost
- : Derived::SizeAtCompileTime * Derived::CoeffReadCost + (Derived::SizeAtCompileTime-1) * functor_traits<Func>::Cost,
+ Cost = Evaluator::SizeAtCompileTime == Dynamic ? HugeCost
+ : int(Evaluator::SizeAtCompileTime) * int(Evaluator::CoeffReadCost) + (Evaluator::SizeAtCompileTime-1) * functor_traits<Func>::Cost,
UnrollingLimit = EIGEN_UNROLLING_LIMIT * (int(Traversal) == int(DefaultTraversal) ? 1 : int(PacketSize))
};
@@ -64,18 +70,20 @@
#ifdef EIGEN_DEBUG_ASSIGN
static void debug()
{
- std::cerr << "Xpr: " << typeid(typename Derived::XprType).name() << std::endl;
+ std::cerr << "Xpr: " << typeid(typename Evaluator::XprType).name() << std::endl;
std::cerr.setf(std::ios::hex, std::ios::basefield);
- EIGEN_DEBUG_VAR(Derived::Flags)
+ EIGEN_DEBUG_VAR(Evaluator::Flags)
std::cerr.unsetf(std::ios::hex);
EIGEN_DEBUG_VAR(InnerMaxSize)
+ EIGEN_DEBUG_VAR(OuterMaxSize)
+ EIGEN_DEBUG_VAR(SliceVectorizedWork)
EIGEN_DEBUG_VAR(PacketSize)
EIGEN_DEBUG_VAR(MightVectorize)
EIGEN_DEBUG_VAR(MayLinearVectorize)
EIGEN_DEBUG_VAR(MaySliceVectorize)
- EIGEN_DEBUG_VAR(Traversal)
+ std::cerr << "Traversal" << " = " << Traversal << " (" << demangle_traversal(Traversal) << ")" << std::endl;
EIGEN_DEBUG_VAR(UnrollingLimit)
- EIGEN_DEBUG_VAR(Unrolling)
+ std::cerr << "Unrolling" << " = " << Unrolling << " (" << demangle_unrolling(Unrolling) << ")" << std::endl;
std::cerr << std::endl;
}
#endif
@@ -87,88 +95,86 @@
/*** no vectorization ***/
-template<typename Func, typename Derived, int Start, int Length>
+template<typename Func, typename Evaluator, int Start, int Length>
struct redux_novec_unroller
{
enum {
HalfLength = Length/2
};
- typedef typename Derived::Scalar Scalar;
+ typedef typename Evaluator::Scalar Scalar;
EIGEN_DEVICE_FUNC
- static EIGEN_STRONG_INLINE Scalar run(const Derived &mat, const Func& func)
+ static EIGEN_STRONG_INLINE Scalar run(const Evaluator &eval, const Func& func)
{
- return func(redux_novec_unroller<Func, Derived, Start, HalfLength>::run(mat,func),
- redux_novec_unroller<Func, Derived, Start+HalfLength, Length-HalfLength>::run(mat,func));
+ return func(redux_novec_unroller<Func, Evaluator, Start, HalfLength>::run(eval,func),
+ redux_novec_unroller<Func, Evaluator, Start+HalfLength, Length-HalfLength>::run(eval,func));
}
};
-template<typename Func, typename Derived, int Start>
-struct redux_novec_unroller<Func, Derived, Start, 1>
+template<typename Func, typename Evaluator, int Start>
+struct redux_novec_unroller<Func, Evaluator, Start, 1>
{
enum {
- outer = Start / Derived::InnerSizeAtCompileTime,
- inner = Start % Derived::InnerSizeAtCompileTime
+ outer = Start / Evaluator::InnerSizeAtCompileTime,
+ inner = Start % Evaluator::InnerSizeAtCompileTime
};
- typedef typename Derived::Scalar Scalar;
+ typedef typename Evaluator::Scalar Scalar;
EIGEN_DEVICE_FUNC
- static EIGEN_STRONG_INLINE Scalar run(const Derived &mat, const Func&)
+ static EIGEN_STRONG_INLINE Scalar run(const Evaluator &eval, const Func&)
{
- return mat.coeffByOuterInner(outer, inner);
+ return eval.coeffByOuterInner(outer, inner);
}
};
// This is actually dead code and will never be called. It is required
// to prevent false warnings regarding failed inlining though
// for 0 length run() will never be called at all.
-template<typename Func, typename Derived, int Start>
-struct redux_novec_unroller<Func, Derived, Start, 0>
+template<typename Func, typename Evaluator, int Start>
+struct redux_novec_unroller<Func, Evaluator, Start, 0>
{
- typedef typename Derived::Scalar Scalar;
+ typedef typename Evaluator::Scalar Scalar;
EIGEN_DEVICE_FUNC
- static EIGEN_STRONG_INLINE Scalar run(const Derived&, const Func&) { return Scalar(); }
+ static EIGEN_STRONG_INLINE Scalar run(const Evaluator&, const Func&) { return Scalar(); }
};
/*** vectorization ***/
-template<typename Func, typename Derived, int Start, int Length>
+template<typename Func, typename Evaluator, int Start, int Length>
struct redux_vec_unroller
{
- enum {
- PacketSize = redux_traits<Func, Derived>::PacketSize,
- HalfLength = Length/2
- };
-
- typedef typename Derived::Scalar Scalar;
- typedef typename redux_traits<Func, Derived>::PacketType PacketScalar;
-
- static EIGEN_STRONG_INLINE PacketScalar run(const Derived &mat, const Func& func)
+ template<typename PacketType>
+ EIGEN_DEVICE_FUNC
+ static EIGEN_STRONG_INLINE PacketType run(const Evaluator &eval, const Func& func)
{
+ enum {
+ PacketSize = unpacket_traits<PacketType>::size,
+ HalfLength = Length/2
+ };
+
return func.packetOp(
- redux_vec_unroller<Func, Derived, Start, HalfLength>::run(mat,func),
- redux_vec_unroller<Func, Derived, Start+HalfLength, Length-HalfLength>::run(mat,func) );
+ redux_vec_unroller<Func, Evaluator, Start, HalfLength>::template run<PacketType>(eval,func),
+ redux_vec_unroller<Func, Evaluator, Start+HalfLength, Length-HalfLength>::template run<PacketType>(eval,func) );
}
};
-template<typename Func, typename Derived, int Start>
-struct redux_vec_unroller<Func, Derived, Start, 1>
+template<typename Func, typename Evaluator, int Start>
+struct redux_vec_unroller<Func, Evaluator, Start, 1>
{
- enum {
- index = Start * redux_traits<Func, Derived>::PacketSize,
- outer = index / int(Derived::InnerSizeAtCompileTime),
- inner = index % int(Derived::InnerSizeAtCompileTime),
- alignment = Derived::Alignment
- };
-
- typedef typename Derived::Scalar Scalar;
- typedef typename redux_traits<Func, Derived>::PacketType PacketScalar;
-
- static EIGEN_STRONG_INLINE PacketScalar run(const Derived &mat, const Func&)
+ template<typename PacketType>
+ EIGEN_DEVICE_FUNC
+ static EIGEN_STRONG_INLINE PacketType run(const Evaluator &eval, const Func&)
{
- return mat.template packetByOuterInner<alignment,PacketScalar>(outer, inner);
+ enum {
+ PacketSize = unpacket_traits<PacketType>::size,
+ index = Start * PacketSize,
+ outer = index / int(Evaluator::InnerSizeAtCompileTime),
+ inner = index % int(Evaluator::InnerSizeAtCompileTime),
+ alignment = Evaluator::Alignment
+ };
+ return eval.template packetByOuterInner<alignment,PacketType>(outer, inner);
}
};
@@ -176,53 +182,65 @@
* Part 3 : implementation of all cases
***************************************************************************/
-template<typename Func, typename Derived,
- int Traversal = redux_traits<Func, Derived>::Traversal,
- int Unrolling = redux_traits<Func, Derived>::Unrolling
+template<typename Func, typename Evaluator,
+ int Traversal = redux_traits<Func, Evaluator>::Traversal,
+ int Unrolling = redux_traits<Func, Evaluator>::Unrolling
>
struct redux_impl;
-template<typename Func, typename Derived>
-struct redux_impl<Func, Derived, DefaultTraversal, NoUnrolling>
+template<typename Func, typename Evaluator>
+struct redux_impl<Func, Evaluator, DefaultTraversal, NoUnrolling>
{
- typedef typename Derived::Scalar Scalar;
- EIGEN_DEVICE_FUNC
- static EIGEN_STRONG_INLINE Scalar run(const Derived &mat, const Func& func)
+ typedef typename Evaluator::Scalar Scalar;
+
+ template<typename XprType>
+ EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE
+ Scalar run(const Evaluator &eval, const Func& func, const XprType& xpr)
{
- eigen_assert(mat.rows()>0 && mat.cols()>0 && "you are using an empty matrix");
+ eigen_assert(xpr.rows()>0 && xpr.cols()>0 && "you are using an empty matrix");
Scalar res;
- res = mat.coeffByOuterInner(0, 0);
- for(Index i = 1; i < mat.innerSize(); ++i)
- res = func(res, mat.coeffByOuterInner(0, i));
- for(Index i = 1; i < mat.outerSize(); ++i)
- for(Index j = 0; j < mat.innerSize(); ++j)
- res = func(res, mat.coeffByOuterInner(i, j));
+ res = eval.coeffByOuterInner(0, 0);
+ for(Index i = 1; i < xpr.innerSize(); ++i)
+ res = func(res, eval.coeffByOuterInner(0, i));
+ for(Index i = 1; i < xpr.outerSize(); ++i)
+ for(Index j = 0; j < xpr.innerSize(); ++j)
+ res = func(res, eval.coeffByOuterInner(i, j));
return res;
}
};
-template<typename Func, typename Derived>
-struct redux_impl<Func,Derived, DefaultTraversal, CompleteUnrolling>
- : public redux_novec_unroller<Func,Derived, 0, Derived::SizeAtCompileTime>
-{};
-
-template<typename Func, typename Derived>
-struct redux_impl<Func, Derived, LinearVectorizedTraversal, NoUnrolling>
+template<typename Func, typename Evaluator>
+struct redux_impl<Func,Evaluator, DefaultTraversal, CompleteUnrolling>
+ : redux_novec_unroller<Func,Evaluator, 0, Evaluator::SizeAtCompileTime>
{
- typedef typename Derived::Scalar Scalar;
- typedef typename redux_traits<Func, Derived>::PacketType PacketScalar;
-
- static Scalar run(const Derived &mat, const Func& func)
+ typedef redux_novec_unroller<Func,Evaluator, 0, Evaluator::SizeAtCompileTime> Base;
+ typedef typename Evaluator::Scalar Scalar;
+ template<typename XprType>
+ EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE
+ Scalar run(const Evaluator &eval, const Func& func, const XprType& /*xpr*/)
{
- const Index size = mat.size();
+ return Base::run(eval,func);
+ }
+};
+
+template<typename Func, typename Evaluator>
+struct redux_impl<Func, Evaluator, LinearVectorizedTraversal, NoUnrolling>
+{
+ typedef typename Evaluator::Scalar Scalar;
+ typedef typename redux_traits<Func, Evaluator>::PacketType PacketScalar;
+
+ template<typename XprType>
+ static Scalar run(const Evaluator &eval, const Func& func, const XprType& xpr)
+ {
+ const Index size = xpr.size();
- const Index packetSize = redux_traits<Func, Derived>::PacketSize;
+ const Index packetSize = redux_traits<Func, Evaluator>::PacketSize;
const int packetAlignment = unpacket_traits<PacketScalar>::alignment;
enum {
- alignment0 = (bool(Derived::Flags & DirectAccessBit) && bool(packet_traits<Scalar>::AlignedOnScalar)) ? int(packetAlignment) : int(Unaligned),
- alignment = EIGEN_PLAIN_ENUM_MAX(alignment0, Derived::Alignment)
+ alignment0 = (bool(Evaluator::Flags & DirectAccessBit) && bool(packet_traits<Scalar>::AlignedOnScalar)) ? int(packetAlignment) : int(Unaligned),
+ alignment = EIGEN_PLAIN_ENUM_MAX(alignment0, Evaluator::Alignment)
};
- const Index alignedStart = internal::first_default_aligned(mat.nestedExpression());
+ const Index alignedStart = internal::first_default_aligned(xpr);
const Index alignedSize2 = ((size-alignedStart)/(2*packetSize))*(2*packetSize);
const Index alignedSize = ((size-alignedStart)/(packetSize))*(packetSize);
const Index alignedEnd2 = alignedStart + alignedSize2;
@@ -230,34 +248,34 @@
Scalar res;
if(alignedSize)
{
- PacketScalar packet_res0 = mat.template packet<alignment,PacketScalar>(alignedStart);
+ PacketScalar packet_res0 = eval.template packet<alignment,PacketScalar>(alignedStart);
if(alignedSize>packetSize) // we have at least two packets to partly unroll the loop
{
- PacketScalar packet_res1 = mat.template packet<alignment,PacketScalar>(alignedStart+packetSize);
+ PacketScalar packet_res1 = eval.template packet<alignment,PacketScalar>(alignedStart+packetSize);
for(Index index = alignedStart + 2*packetSize; index < alignedEnd2; index += 2*packetSize)
{
- packet_res0 = func.packetOp(packet_res0, mat.template packet<alignment,PacketScalar>(index));
- packet_res1 = func.packetOp(packet_res1, mat.template packet<alignment,PacketScalar>(index+packetSize));
+ packet_res0 = func.packetOp(packet_res0, eval.template packet<alignment,PacketScalar>(index));
+ packet_res1 = func.packetOp(packet_res1, eval.template packet<alignment,PacketScalar>(index+packetSize));
}
packet_res0 = func.packetOp(packet_res0,packet_res1);
if(alignedEnd>alignedEnd2)
- packet_res0 = func.packetOp(packet_res0, mat.template packet<alignment,PacketScalar>(alignedEnd2));
+ packet_res0 = func.packetOp(packet_res0, eval.template packet<alignment,PacketScalar>(alignedEnd2));
}
res = func.predux(packet_res0);
for(Index index = 0; index < alignedStart; ++index)
- res = func(res,mat.coeff(index));
+ res = func(res,eval.coeff(index));
for(Index index = alignedEnd; index < size; ++index)
- res = func(res,mat.coeff(index));
+ res = func(res,eval.coeff(index));
}
else // too small to vectorize anything.
// since this is dynamic-size hence inefficient anyway for such small sizes, don't try to optimize.
{
- res = mat.coeff(0);
+ res = eval.coeff(0);
for(Index index = 1; index < size; ++index)
- res = func(res,mat.coeff(index));
+ res = func(res,eval.coeff(index));
}
return res;
@@ -265,130 +283,108 @@
};
// NOTE: for SliceVectorizedTraversal we simply bypass unrolling
-template<typename Func, typename Derived, int Unrolling>
-struct redux_impl<Func, Derived, SliceVectorizedTraversal, Unrolling>
+template<typename Func, typename Evaluator, int Unrolling>
+struct redux_impl<Func, Evaluator, SliceVectorizedTraversal, Unrolling>
{
- typedef typename Derived::Scalar Scalar;
- typedef typename redux_traits<Func, Derived>::PacketType PacketType;
+ typedef typename Evaluator::Scalar Scalar;
+ typedef typename redux_traits<Func, Evaluator>::PacketType PacketType;
- EIGEN_DEVICE_FUNC static Scalar run(const Derived &mat, const Func& func)
+ template<typename XprType>
+ EIGEN_DEVICE_FUNC static Scalar run(const Evaluator &eval, const Func& func, const XprType& xpr)
{
- eigen_assert(mat.rows()>0 && mat.cols()>0 && "you are using an empty matrix");
- const Index innerSize = mat.innerSize();
- const Index outerSize = mat.outerSize();
+ eigen_assert(xpr.rows()>0 && xpr.cols()>0 && "you are using an empty matrix");
+ const Index innerSize = xpr.innerSize();
+ const Index outerSize = xpr.outerSize();
enum {
- packetSize = redux_traits<Func, Derived>::PacketSize
+ packetSize = redux_traits<Func, Evaluator>::PacketSize
};
const Index packetedInnerSize = ((innerSize)/packetSize)*packetSize;
Scalar res;
if(packetedInnerSize)
{
- PacketType packet_res = mat.template packet<Unaligned,PacketType>(0,0);
+ PacketType packet_res = eval.template packet<Unaligned,PacketType>(0,0);
for(Index j=0; j<outerSize; ++j)
for(Index i=(j==0?packetSize:0); i<packetedInnerSize; i+=Index(packetSize))
- packet_res = func.packetOp(packet_res, mat.template packetByOuterInner<Unaligned,PacketType>(j,i));
+ packet_res = func.packetOp(packet_res, eval.template packetByOuterInner<Unaligned,PacketType>(j,i));
res = func.predux(packet_res);
for(Index j=0; j<outerSize; ++j)
for(Index i=packetedInnerSize; i<innerSize; ++i)
- res = func(res, mat.coeffByOuterInner(j,i));
+ res = func(res, eval.coeffByOuterInner(j,i));
}
else // too small to vectorize anything.
// since this is dynamic-size hence inefficient anyway for such small sizes, don't try to optimize.
{
- res = redux_impl<Func, Derived, DefaultTraversal, NoUnrolling>::run(mat, func);
+ res = redux_impl<Func, Evaluator, DefaultTraversal, NoUnrolling>::run(eval, func, xpr);
}
return res;
}
};
-template<typename Func, typename Derived>
-struct redux_impl<Func, Derived, LinearVectorizedTraversal, CompleteUnrolling>
+template<typename Func, typename Evaluator>
+struct redux_impl<Func, Evaluator, LinearVectorizedTraversal, CompleteUnrolling>
{
- typedef typename Derived::Scalar Scalar;
+ typedef typename Evaluator::Scalar Scalar;
- typedef typename redux_traits<Func, Derived>::PacketType PacketScalar;
+ typedef typename redux_traits<Func, Evaluator>::PacketType PacketType;
enum {
- PacketSize = redux_traits<Func, Derived>::PacketSize,
- Size = Derived::SizeAtCompileTime,
- VectorizedSize = (Size / PacketSize) * PacketSize
+ PacketSize = redux_traits<Func, Evaluator>::PacketSize,
+ Size = Evaluator::SizeAtCompileTime,
+ VectorizedSize = (int(Size) / int(PacketSize)) * int(PacketSize)
};
- EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Scalar run(const Derived &mat, const Func& func)
+
+ template<typename XprType>
+ EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE
+ Scalar run(const Evaluator &eval, const Func& func, const XprType &xpr)
{
- eigen_assert(mat.rows()>0 && mat.cols()>0 && "you are using an empty matrix");
+ EIGEN_ONLY_USED_FOR_DEBUG(xpr)
+ eigen_assert(xpr.rows()>0 && xpr.cols()>0 && "you are using an empty matrix");
if (VectorizedSize > 0) {
- Scalar res = func.predux(redux_vec_unroller<Func, Derived, 0, Size / PacketSize>::run(mat,func));
+ Scalar res = func.predux(redux_vec_unroller<Func, Evaluator, 0, Size / PacketSize>::template run<PacketType>(eval,func));
if (VectorizedSize != Size)
- res = func(res,redux_novec_unroller<Func, Derived, VectorizedSize, Size-VectorizedSize>::run(mat,func));
+ res = func(res,redux_novec_unroller<Func, Evaluator, VectorizedSize, Size-VectorizedSize>::run(eval,func));
return res;
}
else {
- return redux_novec_unroller<Func, Derived, 0, Size>::run(mat,func);
+ return redux_novec_unroller<Func, Evaluator, 0, Size>::run(eval,func);
}
}
};
// evaluator adaptor
template<typename _XprType>
-class redux_evaluator
+class redux_evaluator : public internal::evaluator<_XprType>
{
+ typedef internal::evaluator<_XprType> Base;
public:
typedef _XprType XprType;
- EIGEN_DEVICE_FUNC explicit redux_evaluator(const XprType &xpr) : m_evaluator(xpr), m_xpr(xpr) {}
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ explicit redux_evaluator(const XprType &xpr) : Base(xpr) {}
typedef typename XprType::Scalar Scalar;
typedef typename XprType::CoeffReturnType CoeffReturnType;
typedef typename XprType::PacketScalar PacketScalar;
- typedef typename XprType::PacketReturnType PacketReturnType;
enum {
MaxRowsAtCompileTime = XprType::MaxRowsAtCompileTime,
MaxColsAtCompileTime = XprType::MaxColsAtCompileTime,
// TODO we should not remove DirectAccessBit and rather find an elegant way to query the alignment offset at runtime from the evaluator
- Flags = evaluator<XprType>::Flags & ~DirectAccessBit,
+ Flags = Base::Flags & ~DirectAccessBit,
IsRowMajor = XprType::IsRowMajor,
SizeAtCompileTime = XprType::SizeAtCompileTime,
- InnerSizeAtCompileTime = XprType::InnerSizeAtCompileTime,
- CoeffReadCost = evaluator<XprType>::CoeffReadCost,
- Alignment = evaluator<XprType>::Alignment
+ InnerSizeAtCompileTime = XprType::InnerSizeAtCompileTime
};
- EIGEN_DEVICE_FUNC Index rows() const { return m_xpr.rows(); }
- EIGEN_DEVICE_FUNC Index cols() const { return m_xpr.cols(); }
- EIGEN_DEVICE_FUNC Index size() const { return m_xpr.size(); }
- EIGEN_DEVICE_FUNC Index innerSize() const { return m_xpr.innerSize(); }
- EIGEN_DEVICE_FUNC Index outerSize() const { return m_xpr.outerSize(); }
-
- EIGEN_DEVICE_FUNC
- CoeffReturnType coeff(Index row, Index col) const
- { return m_evaluator.coeff(row, col); }
-
- EIGEN_DEVICE_FUNC
- CoeffReturnType coeff(Index index) const
- { return m_evaluator.coeff(index); }
-
- template<int LoadMode, typename PacketType>
- PacketType packet(Index row, Index col) const
- { return m_evaluator.template packet<LoadMode,PacketType>(row, col); }
-
- template<int LoadMode, typename PacketType>
- PacketType packet(Index index) const
- { return m_evaluator.template packet<LoadMode,PacketType>(index); }
-
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
CoeffReturnType coeffByOuterInner(Index outer, Index inner) const
- { return m_evaluator.coeff(IsRowMajor ? outer : inner, IsRowMajor ? inner : outer); }
+ { return Base::coeff(IsRowMajor ? outer : inner, IsRowMajor ? inner : outer); }
template<int LoadMode, typename PacketType>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
PacketType packetByOuterInner(Index outer, Index inner) const
- { return m_evaluator.template packet<LoadMode,PacketType>(IsRowMajor ? outer : inner, IsRowMajor ? inner : outer); }
+ { return Base::template packet<LoadMode,PacketType>(IsRowMajor ? outer : inner, IsRowMajor ? inner : outer); }
- const XprType & nestedExpression() const { return m_xpr; }
-
-protected:
- internal::evaluator<XprType> m_evaluator;
- const XprType &m_xpr;
};
} // end namespace internal
@@ -403,39 +399,53 @@
* The template parameter \a BinaryOp is the type of the functor \a func which must be
* an associative operator. Both current C++98 and C++11 functor styles are handled.
*
+ * \warning the matrix must be not empty, otherwise an assertion is triggered.
+ *
* \sa DenseBase::sum(), DenseBase::minCoeff(), DenseBase::maxCoeff(), MatrixBase::colwise(), MatrixBase::rowwise()
*/
template<typename Derived>
template<typename Func>
-EIGEN_STRONG_INLINE typename internal::traits<Derived>::Scalar
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename internal::traits<Derived>::Scalar
DenseBase<Derived>::redux(const Func& func) const
{
eigen_assert(this->rows()>0 && this->cols()>0 && "you are using an empty matrix");
typedef typename internal::redux_evaluator<Derived> ThisEvaluator;
ThisEvaluator thisEval(derived());
-
- return internal::redux_impl<Func, ThisEvaluator>::run(thisEval, func);
+
+ // The initial expression is passed to the reducer as an additional argument instead of
+ // passing it as a member of redux_evaluator to help
+ return internal::redux_impl<Func, ThisEvaluator>::run(thisEval, func, derived());
}
/** \returns the minimum of all coefficients of \c *this.
- * \warning the result is undefined if \c *this contains NaN.
+ * In case \c *this contains NaN, NaNPropagation determines the behavior:
+ * NaNPropagation == PropagateFast : undefined
+ * NaNPropagation == PropagateNaN : result is NaN
+ * NaNPropagation == PropagateNumbers : result is minimum of elements that are not NaN
+ * \warning the matrix must be not empty, otherwise an assertion is triggered.
*/
template<typename Derived>
-EIGEN_STRONG_INLINE typename internal::traits<Derived>::Scalar
+template<int NaNPropagation>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename internal::traits<Derived>::Scalar
DenseBase<Derived>::minCoeff() const
{
- return derived().redux(Eigen::internal::scalar_min_op<Scalar,Scalar>());
+ return derived().redux(Eigen::internal::scalar_min_op<Scalar,Scalar, NaNPropagation>());
}
-/** \returns the maximum of all coefficients of \c *this.
- * \warning the result is undefined if \c *this contains NaN.
+/** \returns the maximum of all coefficients of \c *this.
+ * In case \c *this contains NaN, NaNPropagation determines the behavior:
+ * NaNPropagation == PropagateFast : undefined
+ * NaNPropagation == PropagateNaN : result is NaN
+ * NaNPropagation == PropagateNumbers : result is maximum of elements that are not NaN
+ * \warning the matrix must be not empty, otherwise an assertion is triggered.
*/
template<typename Derived>
-EIGEN_STRONG_INLINE typename internal::traits<Derived>::Scalar
+template<int NaNPropagation>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename internal::traits<Derived>::Scalar
DenseBase<Derived>::maxCoeff() const
{
- return derived().redux(Eigen::internal::scalar_max_op<Scalar,Scalar>());
+ return derived().redux(Eigen::internal::scalar_max_op<Scalar,Scalar, NaNPropagation>());
}
/** \returns the sum of all coefficients of \c *this
@@ -445,7 +455,7 @@
* \sa trace(), prod(), mean()
*/
template<typename Derived>
-EIGEN_STRONG_INLINE typename internal::traits<Derived>::Scalar
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename internal::traits<Derived>::Scalar
DenseBase<Derived>::sum() const
{
if(SizeAtCompileTime==0 || (SizeAtCompileTime==Dynamic && size()==0))
@@ -458,7 +468,7 @@
* \sa trace(), prod(), sum()
*/
template<typename Derived>
-EIGEN_STRONG_INLINE typename internal::traits<Derived>::Scalar
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename internal::traits<Derived>::Scalar
DenseBase<Derived>::mean() const
{
#ifdef __INTEL_COMPILER
@@ -479,7 +489,7 @@
* \sa sum(), mean(), trace()
*/
template<typename Derived>
-EIGEN_STRONG_INLINE typename internal::traits<Derived>::Scalar
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename internal::traits<Derived>::Scalar
DenseBase<Derived>::prod() const
{
if(SizeAtCompileTime==0 || (SizeAtCompileTime==Dynamic && size()==0))
@@ -494,7 +504,7 @@
* \sa diagonal(), sum()
*/
template<typename Derived>
-EIGEN_STRONG_INLINE typename internal::traits<Derived>::Scalar
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename internal::traits<Derived>::Scalar
MatrixBase<Derived>::trace() const
{
return derived().diagonal().sum();
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Ref.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Ref.h
index 9c6e3c5..c2a37ea 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Ref.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Ref.h
@@ -10,7 +10,7 @@
#ifndef EIGEN_REF_H
#define EIGEN_REF_H
-namespace Eigen {
+namespace Eigen {
namespace internal {
@@ -28,12 +28,13 @@
template<typename Derived> struct match {
enum {
+ IsVectorAtCompileTime = PlainObjectType::IsVectorAtCompileTime || Derived::IsVectorAtCompileTime,
HasDirectAccess = internal::has_direct_access<Derived>::ret,
- StorageOrderMatch = PlainObjectType::IsVectorAtCompileTime || Derived::IsVectorAtCompileTime || ((PlainObjectType::Flags&RowMajorBit)==(Derived::Flags&RowMajorBit)),
+ StorageOrderMatch = IsVectorAtCompileTime || ((PlainObjectType::Flags&RowMajorBit)==(Derived::Flags&RowMajorBit)),
InnerStrideMatch = int(StrideType::InnerStrideAtCompileTime)==int(Dynamic)
|| int(StrideType::InnerStrideAtCompileTime)==int(Derived::InnerStrideAtCompileTime)
|| (int(StrideType::InnerStrideAtCompileTime)==0 && int(Derived::InnerStrideAtCompileTime)==1),
- OuterStrideMatch = Derived::IsVectorAtCompileTime
+ OuterStrideMatch = IsVectorAtCompileTime
|| int(StrideType::OuterStrideAtCompileTime)==int(Dynamic) || int(StrideType::OuterStrideAtCompileTime)==int(Derived::OuterStrideAtCompileTime),
// NOTE, this indirection of evaluator<Derived>::Alignment is needed
// to workaround a very strange bug in MSVC related to the instantiation
@@ -47,7 +48,7 @@
};
typedef typename internal::conditional<MatchAtCompileTime,internal::true_type,internal::false_type>::type type;
};
-
+
};
template<typename Derived>
@@ -66,12 +67,12 @@
typedef MapBase<Derived> Base;
EIGEN_DENSE_PUBLIC_INTERFACE(RefBase)
- EIGEN_DEVICE_FUNC inline Index innerStride() const
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR inline Index innerStride() const
{
return StrideType::InnerStrideAtCompileTime != 0 ? m_stride.inner() : 1;
}
- EIGEN_DEVICE_FUNC inline Index outerStride() const
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR inline Index outerStride() const
{
return StrideType::OuterStrideAtCompileTime != 0 ? m_stride.outer()
: IsVectorAtCompileTime ? this->size()
@@ -85,36 +86,122 @@
m_stride(StrideType::OuterStrideAtCompileTime==Dynamic?0:StrideType::OuterStrideAtCompileTime,
StrideType::InnerStrideAtCompileTime==Dynamic?0:StrideType::InnerStrideAtCompileTime)
{}
-
+
EIGEN_INHERIT_ASSIGNMENT_OPERATORS(RefBase)
protected:
typedef Stride<StrideType::OuterStrideAtCompileTime,StrideType::InnerStrideAtCompileTime> StrideBase;
- template<typename Expression>
- EIGEN_DEVICE_FUNC void construct(Expression& expr)
- {
- EIGEN_STATIC_ASSERT_SAME_MATRIX_SIZE(PlainObjectType,Expression);
+ // Resolves inner stride if default 0.
+ static EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR Index resolveInnerStride(Index inner) {
+ return inner == 0 ? 1 : inner;
+ }
+ // Resolves outer stride if default 0.
+ static EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR Index resolveOuterStride(Index inner, Index outer, Index rows, Index cols, bool isVectorAtCompileTime, bool isRowMajor) {
+ return outer == 0 ? isVectorAtCompileTime ? inner * rows * cols : isRowMajor ? inner * cols : inner * rows : outer;
+ }
+
+ // Returns true if construction is valid, false if there is a stride mismatch,
+ // and fails if there is a size mismatch.
+ template<typename Expression>
+ EIGEN_DEVICE_FUNC bool construct(Expression& expr)
+ {
+ // Check matrix sizes. If this is a compile-time vector, we do allow
+ // implicitly transposing.
+ EIGEN_STATIC_ASSERT(
+ EIGEN_PREDICATE_SAME_MATRIX_SIZE(PlainObjectType, Expression)
+ // If it is a vector, the transpose sizes might match.
+ || ( PlainObjectType::IsVectorAtCompileTime
+ && ((int(PlainObjectType::RowsAtCompileTime)==Eigen::Dynamic
+ || int(Expression::ColsAtCompileTime)==Eigen::Dynamic
+ || int(PlainObjectType::RowsAtCompileTime)==int(Expression::ColsAtCompileTime))
+ && (int(PlainObjectType::ColsAtCompileTime)==Eigen::Dynamic
+ || int(Expression::RowsAtCompileTime)==Eigen::Dynamic
+ || int(PlainObjectType::ColsAtCompileTime)==int(Expression::RowsAtCompileTime)))),
+ YOU_MIXED_MATRICES_OF_DIFFERENT_SIZES
+ )
+
+ // Determine runtime rows and columns.
+ Index rows = expr.rows();
+ Index cols = expr.cols();
if(PlainObjectType::RowsAtCompileTime==1)
{
eigen_assert(expr.rows()==1 || expr.cols()==1);
- ::new (static_cast<Base*>(this)) Base(expr.data(), 1, expr.size());
+ rows = 1;
+ cols = expr.size();
}
else if(PlainObjectType::ColsAtCompileTime==1)
{
eigen_assert(expr.rows()==1 || expr.cols()==1);
- ::new (static_cast<Base*>(this)) Base(expr.data(), expr.size(), 1);
+ rows = expr.size();
+ cols = 1;
}
- else
- ::new (static_cast<Base*>(this)) Base(expr.data(), expr.rows(), expr.cols());
-
- if(Expression::IsVectorAtCompileTime && (!PlainObjectType::IsVectorAtCompileTime) && ((Expression::Flags&RowMajorBit)!=(PlainObjectType::Flags&RowMajorBit)))
- ::new (&m_stride) StrideBase(expr.innerStride(), StrideType::InnerStrideAtCompileTime==0?0:1);
- else
- ::new (&m_stride) StrideBase(StrideType::OuterStrideAtCompileTime==0?0:expr.outerStride(),
- StrideType::InnerStrideAtCompileTime==0?0:expr.innerStride());
+ // Verify that the sizes are valid.
+ eigen_assert(
+ (PlainObjectType::RowsAtCompileTime == Dynamic) || (PlainObjectType::RowsAtCompileTime == rows));
+ eigen_assert(
+ (PlainObjectType::ColsAtCompileTime == Dynamic) || (PlainObjectType::ColsAtCompileTime == cols));
+
+
+ // If this is a vector, we might be transposing, which means that stride should swap.
+ const bool transpose = PlainObjectType::IsVectorAtCompileTime && (rows != expr.rows());
+ // If the storage format differs, we also need to swap the stride.
+ const bool row_major = ((PlainObjectType::Flags)&RowMajorBit) != 0;
+ const bool expr_row_major = (Expression::Flags&RowMajorBit) != 0;
+ const bool storage_differs = (row_major != expr_row_major);
+
+ const bool swap_stride = (transpose != storage_differs);
+
+ // Determine expr's actual strides, resolving any defaults if zero.
+ const Index expr_inner_actual = resolveInnerStride(expr.innerStride());
+ const Index expr_outer_actual = resolveOuterStride(expr_inner_actual,
+ expr.outerStride(),
+ expr.rows(),
+ expr.cols(),
+ Expression::IsVectorAtCompileTime != 0,
+ expr_row_major);
+
+ // If this is a column-major row vector or row-major column vector, the inner-stride
+ // is arbitrary, so set it to either the compile-time inner stride or 1.
+ const bool row_vector = (rows == 1);
+ const bool col_vector = (cols == 1);
+ const Index inner_stride =
+ ( (!row_major && row_vector) || (row_major && col_vector) ) ?
+ ( StrideType::InnerStrideAtCompileTime > 0 ? Index(StrideType::InnerStrideAtCompileTime) : 1)
+ : swap_stride ? expr_outer_actual : expr_inner_actual;
+
+ // If this is a column-major column vector or row-major row vector, the outer-stride
+ // is arbitrary, so set it to either the compile-time outer stride or vector size.
+ const Index outer_stride =
+ ( (!row_major && col_vector) || (row_major && row_vector) ) ?
+ ( StrideType::OuterStrideAtCompileTime > 0 ? Index(StrideType::OuterStrideAtCompileTime) : rows * cols * inner_stride)
+ : swap_stride ? expr_inner_actual : expr_outer_actual;
+
+ // Check if given inner/outer strides are compatible with compile-time strides.
+ const bool inner_valid = (StrideType::InnerStrideAtCompileTime == Dynamic)
+ || (resolveInnerStride(Index(StrideType::InnerStrideAtCompileTime)) == inner_stride);
+ if (!inner_valid) {
+ return false;
+ }
+
+ const bool outer_valid = (StrideType::OuterStrideAtCompileTime == Dynamic)
+ || (resolveOuterStride(
+ inner_stride,
+ Index(StrideType::OuterStrideAtCompileTime),
+ rows, cols, PlainObjectType::IsVectorAtCompileTime != 0,
+ row_major)
+ == outer_stride);
+ if (!outer_valid) {
+ return false;
+ }
+
+ ::new (static_cast<Base*>(this)) Base(expr.data(), rows, cols);
+ ::new (&m_stride) StrideBase(
+ (StrideType::OuterStrideAtCompileTime == 0) ? 0 : outer_stride,
+ (StrideType::InnerStrideAtCompileTime == 0) ? 0 : inner_stride );
+ return true;
}
StrideBase m_stride;
@@ -186,6 +273,8 @@
* void foo(const Ref<MatrixXf,0,Stride<> >& A) { foo_impl(A); }
* \endcode
*
+ * See also the following stackoverflow questions for further references:
+ * - <a href="http://stackoverflow.com/questions/21132538/correct-usage-of-the-eigenref-class">Correct usage of the Eigen::Ref<> class</a>
*
* \sa PlainObjectBase::Map(), \ref TopicStorageOrders
*/
@@ -209,7 +298,10 @@
typename internal::enable_if<bool(Traits::template match<Derived>::MatchAtCompileTime),Derived>::type* = 0)
{
EIGEN_STATIC_ASSERT(bool(Traits::template match<Derived>::MatchAtCompileTime), STORAGE_LAYOUT_DOES_NOT_MATCH);
- Base::construct(expr.derived());
+ // Construction must pass since we will not create temprary storage in the non-const case.
+ const bool success = Base::construct(expr.derived());
+ EIGEN_UNUSED_VARIABLE(success)
+ eigen_assert(success);
}
template<typename Derived>
EIGEN_DEVICE_FUNC inline Ref(const DenseBase<Derived>& expr,
@@ -223,7 +315,10 @@
EIGEN_STATIC_ASSERT(bool(internal::is_lvalue<Derived>::value), THIS_EXPRESSION_IS_NOT_A_LVALUE__IT_IS_READ_ONLY);
EIGEN_STATIC_ASSERT(bool(Traits::template match<Derived>::MatchAtCompileTime), STORAGE_LAYOUT_DOES_NOT_MATCH);
EIGEN_STATIC_ASSERT(!Derived::IsPlainObjectBase,THIS_EXPRESSION_IS_NOT_A_LVALUE__IT_IS_READ_ONLY);
- Base::construct(expr.const_cast_derived());
+ // Construction must pass since we will not create temporary storage in the non-const case.
+ const bool success = Base::construct(expr.const_cast_derived());
+ EIGEN_UNUSED_VARIABLE(success)
+ eigen_assert(success);
}
EIGEN_INHERIT_ASSIGNMENT_OPERATORS(Ref)
@@ -264,7 +359,10 @@
template<typename Expression>
EIGEN_DEVICE_FUNC void construct(const Expression& expr,internal::true_type)
{
- Base::construct(expr);
+ // Check if we can use the underlying expr's storage directly, otherwise call the copy version.
+ if (!Base::construct(expr)) {
+ construct(expr, internal::false_type());
+ }
}
template<typename Expression>
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Replicate.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Replicate.h
index 9960ef8..ab5be7e 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Replicate.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Replicate.h
@@ -10,7 +10,7 @@
#ifndef EIGEN_REPLICATE_H
#define EIGEN_REPLICATE_H
-namespace Eigen {
+namespace Eigen {
namespace internal {
template<typename MatrixType,int RowFactor,int ColFactor>
@@ -35,7 +35,7 @@
IsRowMajor = MaxRowsAtCompileTime==1 && MaxColsAtCompileTime!=1 ? 1
: MaxColsAtCompileTime==1 && MaxRowsAtCompileTime!=1 ? 0
: (MatrixType::Flags & RowMajorBit) ? 1 : 0,
-
+
// FIXME enable DirectAccess with negative strides?
Flags = IsRowMajor ? RowMajorBit : 0
};
@@ -88,15 +88,15 @@
THE_MATRIX_OR_EXPRESSION_THAT_YOU_PASSED_DOES_NOT_HAVE_THE_EXPECTED_TYPE)
}
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
inline Index rows() const { return m_matrix.rows() * m_rowFactor.value(); }
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
inline Index cols() const { return m_matrix.cols() * m_colFactor.value(); }
EIGEN_DEVICE_FUNC
const _MatrixTypeNested& nestedExpression() const
- {
- return m_matrix;
+ {
+ return m_matrix;
}
protected:
@@ -115,7 +115,7 @@
*/
template<typename Derived>
template<int RowFactor, int ColFactor>
-const Replicate<Derived,RowFactor,ColFactor>
+EIGEN_DEVICE_FUNC const Replicate<Derived,RowFactor,ColFactor>
DenseBase<Derived>::replicate() const
{
return Replicate<Derived,RowFactor,ColFactor>(derived());
@@ -130,7 +130,7 @@
* \sa VectorwiseOp::replicate(), DenseBase::replicate(), class Replicate
*/
template<typename ExpressionType, int Direction>
-const typename VectorwiseOp<ExpressionType,Direction>::ReplicateReturnType
+EIGEN_DEVICE_FUNC const typename VectorwiseOp<ExpressionType,Direction>::ReplicateReturnType
VectorwiseOp<ExpressionType,Direction>::replicate(Index factor) const
{
return typename VectorwiseOp<ExpressionType,Direction>::ReplicateReturnType
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Reshaped.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Reshaped.h
new file mode 100644
index 0000000..52de73b
--- /dev/null
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Reshaped.h
@@ -0,0 +1,454 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2008-2017 Gael Guennebaud <gael.guennebaud@inria.fr>
+// Copyright (C) 2014 yoco <peter.xiau@gmail.com>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_RESHAPED_H
+#define EIGEN_RESHAPED_H
+
+namespace Eigen {
+
+/** \class Reshaped
+ * \ingroup Core_Module
+ *
+ * \brief Expression of a fixed-size or dynamic-size reshape
+ *
+ * \tparam XprType the type of the expression in which we are taking a reshape
+ * \tparam Rows the number of rows of the reshape we are taking at compile time (optional)
+ * \tparam Cols the number of columns of the reshape we are taking at compile time (optional)
+ * \tparam Order can be ColMajor or RowMajor, default is ColMajor.
+ *
+ * This class represents an expression of either a fixed-size or dynamic-size reshape.
+ * It is the return type of DenseBase::reshaped(NRowsType,NColsType) and
+ * most of the time this is the only way it is used.
+ *
+ * However, in C++98, if you want to directly maniputate reshaped expressions,
+ * for instance if you want to write a function returning such an expression, you
+ * will need to use this class. In C++11, it is advised to use the \em auto
+ * keyword for such use cases.
+ *
+ * Here is an example illustrating the dynamic case:
+ * \include class_Reshaped.cpp
+ * Output: \verbinclude class_Reshaped.out
+ *
+ * Here is an example illustrating the fixed-size case:
+ * \include class_FixedReshaped.cpp
+ * Output: \verbinclude class_FixedReshaped.out
+ *
+ * \sa DenseBase::reshaped(NRowsType,NColsType)
+ */
+
+namespace internal {
+
+template<typename XprType, int Rows, int Cols, int Order>
+struct traits<Reshaped<XprType, Rows, Cols, Order> > : traits<XprType>
+{
+ typedef typename traits<XprType>::Scalar Scalar;
+ typedef typename traits<XprType>::StorageKind StorageKind;
+ typedef typename traits<XprType>::XprKind XprKind;
+ enum{
+ MatrixRows = traits<XprType>::RowsAtCompileTime,
+ MatrixCols = traits<XprType>::ColsAtCompileTime,
+ RowsAtCompileTime = Rows,
+ ColsAtCompileTime = Cols,
+ MaxRowsAtCompileTime = Rows,
+ MaxColsAtCompileTime = Cols,
+ XpxStorageOrder = ((int(traits<XprType>::Flags) & RowMajorBit) == RowMajorBit) ? RowMajor : ColMajor,
+ ReshapedStorageOrder = (RowsAtCompileTime == 1 && ColsAtCompileTime != 1) ? RowMajor
+ : (ColsAtCompileTime == 1 && RowsAtCompileTime != 1) ? ColMajor
+ : XpxStorageOrder,
+ HasSameStorageOrderAsXprType = (ReshapedStorageOrder == XpxStorageOrder),
+ InnerSize = (ReshapedStorageOrder==int(RowMajor)) ? int(ColsAtCompileTime) : int(RowsAtCompileTime),
+ InnerStrideAtCompileTime = HasSameStorageOrderAsXprType
+ ? int(inner_stride_at_compile_time<XprType>::ret)
+ : Dynamic,
+ OuterStrideAtCompileTime = Dynamic,
+
+ HasDirectAccess = internal::has_direct_access<XprType>::ret
+ && (Order==int(XpxStorageOrder))
+ && ((evaluator<XprType>::Flags&LinearAccessBit)==LinearAccessBit),
+
+ MaskPacketAccessBit = (InnerSize == Dynamic || (InnerSize % packet_traits<Scalar>::size) == 0)
+ && (InnerStrideAtCompileTime == 1)
+ ? PacketAccessBit : 0,
+ //MaskAlignedBit = ((OuterStrideAtCompileTime!=Dynamic) && (((OuterStrideAtCompileTime * int(sizeof(Scalar))) % 16) == 0)) ? AlignedBit : 0,
+ FlagsLinearAccessBit = (RowsAtCompileTime == 1 || ColsAtCompileTime == 1) ? LinearAccessBit : 0,
+ FlagsLvalueBit = is_lvalue<XprType>::value ? LvalueBit : 0,
+ FlagsRowMajorBit = (ReshapedStorageOrder==int(RowMajor)) ? RowMajorBit : 0,
+ FlagsDirectAccessBit = HasDirectAccess ? DirectAccessBit : 0,
+ Flags0 = traits<XprType>::Flags & ( (HereditaryBits & ~RowMajorBit) | MaskPacketAccessBit),
+
+ Flags = (Flags0 | FlagsLinearAccessBit | FlagsLvalueBit | FlagsRowMajorBit | FlagsDirectAccessBit)
+ };
+};
+
+template<typename XprType, int Rows, int Cols, int Order, bool HasDirectAccess> class ReshapedImpl_dense;
+
+} // end namespace internal
+
+template<typename XprType, int Rows, int Cols, int Order, typename StorageKind> class ReshapedImpl;
+
+template<typename XprType, int Rows, int Cols, int Order> class Reshaped
+ : public ReshapedImpl<XprType, Rows, Cols, Order, typename internal::traits<XprType>::StorageKind>
+{
+ typedef ReshapedImpl<XprType, Rows, Cols, Order, typename internal::traits<XprType>::StorageKind> Impl;
+ public:
+ //typedef typename Impl::Base Base;
+ typedef Impl Base;
+ EIGEN_GENERIC_PUBLIC_INTERFACE(Reshaped)
+ EIGEN_INHERIT_ASSIGNMENT_OPERATORS(Reshaped)
+
+ /** Fixed-size constructor
+ */
+ EIGEN_DEVICE_FUNC
+ inline Reshaped(XprType& xpr)
+ : Impl(xpr)
+ {
+ EIGEN_STATIC_ASSERT(RowsAtCompileTime!=Dynamic && ColsAtCompileTime!=Dynamic,THIS_METHOD_IS_ONLY_FOR_FIXED_SIZE)
+ eigen_assert(Rows * Cols == xpr.rows() * xpr.cols());
+ }
+
+ /** Dynamic-size constructor
+ */
+ EIGEN_DEVICE_FUNC
+ inline Reshaped(XprType& xpr,
+ Index reshapeRows, Index reshapeCols)
+ : Impl(xpr, reshapeRows, reshapeCols)
+ {
+ eigen_assert((RowsAtCompileTime==Dynamic || RowsAtCompileTime==reshapeRows)
+ && (ColsAtCompileTime==Dynamic || ColsAtCompileTime==reshapeCols));
+ eigen_assert(reshapeRows * reshapeCols == xpr.rows() * xpr.cols());
+ }
+};
+
+// The generic default implementation for dense reshape simply forward to the internal::ReshapedImpl_dense
+// that must be specialized for direct and non-direct access...
+template<typename XprType, int Rows, int Cols, int Order>
+class ReshapedImpl<XprType, Rows, Cols, Order, Dense>
+ : public internal::ReshapedImpl_dense<XprType, Rows, Cols, Order,internal::traits<Reshaped<XprType,Rows,Cols,Order> >::HasDirectAccess>
+{
+ typedef internal::ReshapedImpl_dense<XprType, Rows, Cols, Order,internal::traits<Reshaped<XprType,Rows,Cols,Order> >::HasDirectAccess> Impl;
+ public:
+ typedef Impl Base;
+ EIGEN_INHERIT_ASSIGNMENT_OPERATORS(ReshapedImpl)
+ EIGEN_DEVICE_FUNC inline ReshapedImpl(XprType& xpr) : Impl(xpr) {}
+ EIGEN_DEVICE_FUNC inline ReshapedImpl(XprType& xpr, Index reshapeRows, Index reshapeCols)
+ : Impl(xpr, reshapeRows, reshapeCols) {}
+};
+
+namespace internal {
+
+/** \internal Internal implementation of dense Reshaped in the general case. */
+template<typename XprType, int Rows, int Cols, int Order>
+class ReshapedImpl_dense<XprType,Rows,Cols,Order,false>
+ : public internal::dense_xpr_base<Reshaped<XprType, Rows, Cols, Order> >::type
+{
+ typedef Reshaped<XprType, Rows, Cols, Order> ReshapedType;
+ public:
+
+ typedef typename internal::dense_xpr_base<ReshapedType>::type Base;
+ EIGEN_DENSE_PUBLIC_INTERFACE(ReshapedType)
+ EIGEN_INHERIT_ASSIGNMENT_OPERATORS(ReshapedImpl_dense)
+
+ typedef typename internal::ref_selector<XprType>::non_const_type MatrixTypeNested;
+ typedef typename internal::remove_all<XprType>::type NestedExpression;
+
+ class InnerIterator;
+
+ /** Fixed-size constructor
+ */
+ EIGEN_DEVICE_FUNC
+ inline ReshapedImpl_dense(XprType& xpr)
+ : m_xpr(xpr), m_rows(Rows), m_cols(Cols)
+ {}
+
+ /** Dynamic-size constructor
+ */
+ EIGEN_DEVICE_FUNC
+ inline ReshapedImpl_dense(XprType& xpr, Index nRows, Index nCols)
+ : m_xpr(xpr), m_rows(nRows), m_cols(nCols)
+ {}
+
+ EIGEN_DEVICE_FUNC Index rows() const { return m_rows; }
+ EIGEN_DEVICE_FUNC Index cols() const { return m_cols; }
+
+ #ifdef EIGEN_PARSED_BY_DOXYGEN
+ /** \sa MapBase::data() */
+ EIGEN_DEVICE_FUNC inline const Scalar* data() const;
+ EIGEN_DEVICE_FUNC inline Index innerStride() const;
+ EIGEN_DEVICE_FUNC inline Index outerStride() const;
+ #endif
+
+ /** \returns the nested expression */
+ EIGEN_DEVICE_FUNC
+ const typename internal::remove_all<XprType>::type&
+ nestedExpression() const { return m_xpr; }
+
+ /** \returns the nested expression */
+ EIGEN_DEVICE_FUNC
+ typename internal::remove_reference<XprType>::type&
+ nestedExpression() { return m_xpr; }
+
+ protected:
+
+ MatrixTypeNested m_xpr;
+ const internal::variable_if_dynamic<Index, Rows> m_rows;
+ const internal::variable_if_dynamic<Index, Cols> m_cols;
+};
+
+
+/** \internal Internal implementation of dense Reshaped in the direct access case. */
+template<typename XprType, int Rows, int Cols, int Order>
+class ReshapedImpl_dense<XprType, Rows, Cols, Order, true>
+ : public MapBase<Reshaped<XprType, Rows, Cols, Order> >
+{
+ typedef Reshaped<XprType, Rows, Cols, Order> ReshapedType;
+ typedef typename internal::ref_selector<XprType>::non_const_type XprTypeNested;
+ public:
+
+ typedef MapBase<ReshapedType> Base;
+ EIGEN_DENSE_PUBLIC_INTERFACE(ReshapedType)
+ EIGEN_INHERIT_ASSIGNMENT_OPERATORS(ReshapedImpl_dense)
+
+ /** Fixed-size constructor
+ */
+ EIGEN_DEVICE_FUNC
+ inline ReshapedImpl_dense(XprType& xpr)
+ : Base(xpr.data()), m_xpr(xpr)
+ {}
+
+ /** Dynamic-size constructor
+ */
+ EIGEN_DEVICE_FUNC
+ inline ReshapedImpl_dense(XprType& xpr, Index nRows, Index nCols)
+ : Base(xpr.data(), nRows, nCols),
+ m_xpr(xpr)
+ {}
+
+ EIGEN_DEVICE_FUNC
+ const typename internal::remove_all<XprTypeNested>::type& nestedExpression() const
+ {
+ return m_xpr;
+ }
+
+ EIGEN_DEVICE_FUNC
+ XprType& nestedExpression() { return m_xpr; }
+
+ /** \sa MapBase::innerStride() */
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ inline Index innerStride() const
+ {
+ return m_xpr.innerStride();
+ }
+
+ /** \sa MapBase::outerStride() */
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ inline Index outerStride() const
+ {
+ return ((Flags&RowMajorBit)==RowMajorBit) ? this->cols() : this->rows();
+ }
+
+ protected:
+
+ XprTypeNested m_xpr;
+};
+
+// Evaluators
+template<typename ArgType, int Rows, int Cols, int Order, bool HasDirectAccess> struct reshaped_evaluator;
+
+template<typename ArgType, int Rows, int Cols, int Order>
+struct evaluator<Reshaped<ArgType, Rows, Cols, Order> >
+ : reshaped_evaluator<ArgType, Rows, Cols, Order, traits<Reshaped<ArgType,Rows,Cols,Order> >::HasDirectAccess>
+{
+ typedef Reshaped<ArgType, Rows, Cols, Order> XprType;
+ typedef typename XprType::Scalar Scalar;
+ // TODO: should check for smaller packet types
+ typedef typename packet_traits<Scalar>::type PacketScalar;
+
+ enum {
+ CoeffReadCost = evaluator<ArgType>::CoeffReadCost,
+ HasDirectAccess = traits<XprType>::HasDirectAccess,
+
+// RowsAtCompileTime = traits<XprType>::RowsAtCompileTime,
+// ColsAtCompileTime = traits<XprType>::ColsAtCompileTime,
+// MaxRowsAtCompileTime = traits<XprType>::MaxRowsAtCompileTime,
+// MaxColsAtCompileTime = traits<XprType>::MaxColsAtCompileTime,
+//
+// InnerStrideAtCompileTime = traits<XprType>::HasSameStorageOrderAsXprType
+// ? int(inner_stride_at_compile_time<ArgType>::ret)
+// : Dynamic,
+// OuterStrideAtCompileTime = Dynamic,
+
+ FlagsLinearAccessBit = (traits<XprType>::RowsAtCompileTime == 1 || traits<XprType>::ColsAtCompileTime == 1 || HasDirectAccess) ? LinearAccessBit : 0,
+ FlagsRowMajorBit = (traits<XprType>::ReshapedStorageOrder==int(RowMajor)) ? RowMajorBit : 0,
+ FlagsDirectAccessBit = HasDirectAccess ? DirectAccessBit : 0,
+ Flags0 = evaluator<ArgType>::Flags & (HereditaryBits & ~RowMajorBit),
+ Flags = Flags0 | FlagsLinearAccessBit | FlagsRowMajorBit | FlagsDirectAccessBit,
+
+ PacketAlignment = unpacket_traits<PacketScalar>::alignment,
+ Alignment = evaluator<ArgType>::Alignment
+ };
+ typedef reshaped_evaluator<ArgType, Rows, Cols, Order, HasDirectAccess> reshaped_evaluator_type;
+ EIGEN_DEVICE_FUNC explicit evaluator(const XprType& xpr) : reshaped_evaluator_type(xpr)
+ {
+ EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
+ }
+};
+
+template<typename ArgType, int Rows, int Cols, int Order>
+struct reshaped_evaluator<ArgType, Rows, Cols, Order, /* HasDirectAccess */ false>
+ : evaluator_base<Reshaped<ArgType, Rows, Cols, Order> >
+{
+ typedef Reshaped<ArgType, Rows, Cols, Order> XprType;
+
+ enum {
+ CoeffReadCost = evaluator<ArgType>::CoeffReadCost /* TODO + cost of index computations */,
+
+ Flags = (evaluator<ArgType>::Flags & (HereditaryBits /*| LinearAccessBit | DirectAccessBit*/)),
+
+ Alignment = 0
+ };
+
+ EIGEN_DEVICE_FUNC explicit reshaped_evaluator(const XprType& xpr) : m_argImpl(xpr.nestedExpression()), m_xpr(xpr)
+ {
+ EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
+ }
+
+ typedef typename XprType::Scalar Scalar;
+ typedef typename XprType::CoeffReturnType CoeffReturnType;
+
+ typedef std::pair<Index, Index> RowCol;
+
+ inline RowCol index_remap(Index rowId, Index colId) const
+ {
+ if(Order==ColMajor)
+ {
+ const Index nth_elem_idx = colId * m_xpr.rows() + rowId;
+ return RowCol(nth_elem_idx % m_xpr.nestedExpression().rows(),
+ nth_elem_idx / m_xpr.nestedExpression().rows());
+ }
+ else
+ {
+ const Index nth_elem_idx = colId + rowId * m_xpr.cols();
+ return RowCol(nth_elem_idx / m_xpr.nestedExpression().cols(),
+ nth_elem_idx % m_xpr.nestedExpression().cols());
+ }
+ }
+
+ EIGEN_DEVICE_FUNC
+ inline Scalar& coeffRef(Index rowId, Index colId)
+ {
+ EIGEN_STATIC_ASSERT_LVALUE(XprType)
+ const RowCol row_col = index_remap(rowId, colId);
+ return m_argImpl.coeffRef(row_col.first, row_col.second);
+ }
+
+ EIGEN_DEVICE_FUNC
+ inline const Scalar& coeffRef(Index rowId, Index colId) const
+ {
+ const RowCol row_col = index_remap(rowId, colId);
+ return m_argImpl.coeffRef(row_col.first, row_col.second);
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE const CoeffReturnType coeff(Index rowId, Index colId) const
+ {
+ const RowCol row_col = index_remap(rowId, colId);
+ return m_argImpl.coeff(row_col.first, row_col.second);
+ }
+
+ EIGEN_DEVICE_FUNC
+ inline Scalar& coeffRef(Index index)
+ {
+ EIGEN_STATIC_ASSERT_LVALUE(XprType)
+ const RowCol row_col = index_remap(Rows == 1 ? 0 : index,
+ Rows == 1 ? index : 0);
+ return m_argImpl.coeffRef(row_col.first, row_col.second);
+
+ }
+
+ EIGEN_DEVICE_FUNC
+ inline const Scalar& coeffRef(Index index) const
+ {
+ const RowCol row_col = index_remap(Rows == 1 ? 0 : index,
+ Rows == 1 ? index : 0);
+ return m_argImpl.coeffRef(row_col.first, row_col.second);
+ }
+
+ EIGEN_DEVICE_FUNC
+ inline const CoeffReturnType coeff(Index index) const
+ {
+ const RowCol row_col = index_remap(Rows == 1 ? 0 : index,
+ Rows == 1 ? index : 0);
+ return m_argImpl.coeff(row_col.first, row_col.second);
+ }
+#if 0
+ EIGEN_DEVICE_FUNC
+ template<int LoadMode>
+ inline PacketScalar packet(Index rowId, Index colId) const
+ {
+ const RowCol row_col = index_remap(rowId, colId);
+ return m_argImpl.template packet<Unaligned>(row_col.first, row_col.second);
+
+ }
+
+ template<int LoadMode>
+ EIGEN_DEVICE_FUNC
+ inline void writePacket(Index rowId, Index colId, const PacketScalar& val)
+ {
+ const RowCol row_col = index_remap(rowId, colId);
+ m_argImpl.const_cast_derived().template writePacket<Unaligned>
+ (row_col.first, row_col.second, val);
+ }
+
+ template<int LoadMode>
+ EIGEN_DEVICE_FUNC
+ inline PacketScalar packet(Index index) const
+ {
+ const RowCol row_col = index_remap(RowsAtCompileTime == 1 ? 0 : index,
+ RowsAtCompileTime == 1 ? index : 0);
+ return m_argImpl.template packet<Unaligned>(row_col.first, row_col.second);
+ }
+
+ template<int LoadMode>
+ EIGEN_DEVICE_FUNC
+ inline void writePacket(Index index, const PacketScalar& val)
+ {
+ const RowCol row_col = index_remap(RowsAtCompileTime == 1 ? 0 : index,
+ RowsAtCompileTime == 1 ? index : 0);
+ return m_argImpl.template packet<Unaligned>(row_col.first, row_col.second, val);
+ }
+#endif
+protected:
+
+ evaluator<ArgType> m_argImpl;
+ const XprType& m_xpr;
+
+};
+
+template<typename ArgType, int Rows, int Cols, int Order>
+struct reshaped_evaluator<ArgType, Rows, Cols, Order, /* HasDirectAccess */ true>
+: mapbase_evaluator<Reshaped<ArgType, Rows, Cols, Order>,
+ typename Reshaped<ArgType, Rows, Cols, Order>::PlainObject>
+{
+ typedef Reshaped<ArgType, Rows, Cols, Order> XprType;
+ typedef typename XprType::Scalar Scalar;
+
+ EIGEN_DEVICE_FUNC explicit reshaped_evaluator(const XprType& xpr)
+ : mapbase_evaluator<XprType, typename XprType::PlainObject>(xpr)
+ {
+ // TODO: for the 3.4 release, this should be turned to an internal assertion, but let's keep it as is for the beta lifetime
+ eigen_assert(((internal::UIntPtr(xpr.data()) % EIGEN_PLAIN_ENUM_MAX(1,evaluator<XprType>::Alignment)) == 0) && "data is not aligned");
+ }
+};
+
+} // end namespace internal
+
+} // end namespace Eigen
+
+#endif // EIGEN_RESHAPED_H
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/ReturnByValue.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/ReturnByValue.h
index c44b767..4dad13e 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/ReturnByValue.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/ReturnByValue.h
@@ -60,8 +60,10 @@
EIGEN_DEVICE_FUNC
inline void evalTo(Dest& dst) const
{ static_cast<const Derived*>(this)->evalTo(dst); }
- EIGEN_DEVICE_FUNC inline Index rows() const { return static_cast<const Derived*>(this)->rows(); }
- EIGEN_DEVICE_FUNC inline Index cols() const { return static_cast<const Derived*>(this)->cols(); }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ inline Index rows() const EIGEN_NOEXCEPT { return static_cast<const Derived*>(this)->rows(); }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ inline Index cols() const EIGEN_NOEXCEPT { return static_cast<const Derived*>(this)->cols(); }
#ifndef EIGEN_PARSED_BY_DOXYGEN
#define Unusable YOU_ARE_TRYING_TO_ACCESS_A_SINGLE_COEFFICIENT_IN_A_SPECIAL_EXPRESSION_WHERE_THAT_IS_NOT_ALLOWED_BECAUSE_THAT_WOULD_BE_INEFFICIENT
@@ -79,7 +81,7 @@
template<typename Derived>
template<typename OtherDerived>
-Derived& DenseBase<Derived>::operator=(const ReturnByValue<OtherDerived>& other)
+EIGEN_DEVICE_FUNC Derived& DenseBase<Derived>::operator=(const ReturnByValue<OtherDerived>& other)
{
other.evalTo(derived());
return derived();
@@ -90,7 +92,7 @@
// Expression is evaluated in a temporary; default implementation of Assignment is bypassed so that
// when a ReturnByValue expression is assigned, the evaluator is not constructed.
// TODO: Finalize port to new regime; ReturnByValue should not exist in the expression world
-
+
template<typename Derived>
struct evaluator<ReturnByValue<Derived> >
: public evaluator<typename internal::traits<Derived>::ReturnType>
@@ -98,7 +100,7 @@
typedef ReturnByValue<Derived> XprType;
typedef typename internal::traits<Derived>::ReturnType PlainObject;
typedef evaluator<PlainObject> Base;
-
+
EIGEN_DEVICE_FUNC explicit evaluator(const XprType& xpr)
: m_result(xpr.rows(), xpr.cols())
{
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Reverse.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Reverse.h
index 0640cda..28cdd76 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Reverse.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Reverse.h
@@ -12,7 +12,7 @@
#ifndef EIGEN_REVERSE_H
#define EIGEN_REVERSE_H
-namespace Eigen {
+namespace Eigen {
namespace internal {
@@ -44,7 +44,7 @@
static inline PacketType run(const PacketType& x) { return x; }
};
-} // end namespace internal
+} // end namespace internal
/** \class Reverse
* \ingroup Core_Module
@@ -89,8 +89,10 @@
EIGEN_INHERIT_ASSIGNMENT_OPERATORS(Reverse)
- EIGEN_DEVICE_FUNC inline Index rows() const { return m_matrix.rows(); }
- EIGEN_DEVICE_FUNC inline Index cols() const { return m_matrix.cols(); }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ inline Index rows() const EIGEN_NOEXCEPT { return m_matrix.rows(); }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ inline Index cols() const EIGEN_NOEXCEPT { return m_matrix.cols(); }
EIGEN_DEVICE_FUNC inline Index innerStride() const
{
@@ -98,7 +100,7 @@
}
EIGEN_DEVICE_FUNC const typename internal::remove_all<typename MatrixType::Nested>::type&
- nestedExpression() const
+ nestedExpression() const
{
return m_matrix;
}
@@ -114,7 +116,7 @@
*
*/
template<typename Derived>
-inline typename DenseBase<Derived>::ReverseReturnType
+EIGEN_DEVICE_FUNC inline typename DenseBase<Derived>::ReverseReturnType
DenseBase<Derived>::reverse()
{
return ReverseReturnType(derived());
@@ -136,7 +138,7 @@
*
* \sa VectorwiseOp::reverseInPlace(), reverse() */
template<typename Derived>
-inline void DenseBase<Derived>::reverseInPlace()
+EIGEN_DEVICE_FUNC inline void DenseBase<Derived>::reverseInPlace()
{
if(cols()>rows())
{
@@ -161,7 +163,7 @@
}
namespace internal {
-
+
template<int Direction>
struct vectorwise_reverse_inplace_impl;
@@ -171,8 +173,10 @@
template<typename ExpressionType>
static void run(ExpressionType &xpr)
{
+ const int HalfAtCompileTime = ExpressionType::RowsAtCompileTime==Dynamic?Dynamic:ExpressionType::RowsAtCompileTime/2;
Index half = xpr.rows()/2;
- xpr.topRows(half).swap(xpr.bottomRows(half).colwise().reverse());
+ xpr.topRows(fix<HalfAtCompileTime>(half))
+ .swap(xpr.bottomRows(fix<HalfAtCompileTime>(half)).colwise().reverse());
}
};
@@ -182,8 +186,10 @@
template<typename ExpressionType>
static void run(ExpressionType &xpr)
{
+ const int HalfAtCompileTime = ExpressionType::ColsAtCompileTime==Dynamic?Dynamic:ExpressionType::ColsAtCompileTime/2;
Index half = xpr.cols()/2;
- xpr.leftCols(half).swap(xpr.rightCols(half).rowwise().reverse());
+ xpr.leftCols(fix<HalfAtCompileTime>(half))
+ .swap(xpr.rightCols(fix<HalfAtCompileTime>(half)).rowwise().reverse());
}
};
@@ -201,9 +207,9 @@
*
* \sa DenseBase::reverseInPlace(), reverse() */
template<typename ExpressionType, int Direction>
-void VectorwiseOp<ExpressionType,Direction>::reverseInPlace()
+EIGEN_DEVICE_FUNC void VectorwiseOp<ExpressionType,Direction>::reverseInPlace()
{
- internal::vectorwise_reverse_inplace_impl<Direction>::run(_expression().const_cast_derived());
+ internal::vectorwise_reverse_inplace_impl<Direction>::run(m_matrix);
}
} // end namespace Eigen
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Select.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Select.h
index 79eec1b..7c86bf8 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Select.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Select.h
@@ -10,7 +10,7 @@
#ifndef EIGEN_SELECT_H
#define EIGEN_SELECT_H
-namespace Eigen {
+namespace Eigen {
/** \class Select
* \ingroup Core_Module
@@ -67,8 +67,10 @@
eigen_assert(m_condition.cols() == m_then.cols() && m_condition.cols() == m_else.cols());
}
- inline EIGEN_DEVICE_FUNC Index rows() const { return m_condition.rows(); }
- inline EIGEN_DEVICE_FUNC Index cols() const { return m_condition.cols(); }
+ inline EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ Index rows() const EIGEN_NOEXCEPT { return m_condition.rows(); }
+ inline EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ Index cols() const EIGEN_NOEXCEPT { return m_condition.cols(); }
inline EIGEN_DEVICE_FUNC
const Scalar coeff(Index i, Index j) const
@@ -120,7 +122,7 @@
*/
template<typename Derived>
template<typename ThenDerived,typename ElseDerived>
-inline const Select<Derived,ThenDerived,ElseDerived>
+inline EIGEN_DEVICE_FUNC const Select<Derived,ThenDerived,ElseDerived>
DenseBase<Derived>::select(const DenseBase<ThenDerived>& thenMatrix,
const DenseBase<ElseDerived>& elseMatrix) const
{
@@ -134,7 +136,7 @@
*/
template<typename Derived>
template<typename ThenDerived>
-inline const Select<Derived,ThenDerived, typename ThenDerived::ConstantReturnType>
+inline EIGEN_DEVICE_FUNC const Select<Derived,ThenDerived, typename ThenDerived::ConstantReturnType>
DenseBase<Derived>::select(const DenseBase<ThenDerived>& thenMatrix,
const typename ThenDerived::Scalar& elseScalar) const
{
@@ -149,7 +151,7 @@
*/
template<typename Derived>
template<typename ElseDerived>
-inline const Select<Derived, typename ElseDerived::ConstantReturnType, ElseDerived >
+inline EIGEN_DEVICE_FUNC const Select<Derived, typename ElseDerived::ConstantReturnType, ElseDerived >
DenseBase<Derived>::select(const typename ElseDerived::Scalar& thenScalar,
const DenseBase<ElseDerived>& elseMatrix) const
{
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/SelfAdjointView.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/SelfAdjointView.h
index b2e51f3..8ce3b37 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/SelfAdjointView.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/SelfAdjointView.h
@@ -10,7 +10,7 @@
#ifndef EIGEN_SELFADJOINTMATRIX_H
#define EIGEN_SELFADJOINTMATRIX_H
-namespace Eigen {
+namespace Eigen {
/** \class SelfAdjointView
* \ingroup Core_Module
@@ -58,14 +58,15 @@
typedef MatrixTypeNestedCleaned NestedExpression;
/** \brief The type of coefficients in this matrix */
- typedef typename internal::traits<SelfAdjointView>::Scalar Scalar;
+ typedef typename internal::traits<SelfAdjointView>::Scalar Scalar;
typedef typename MatrixType::StorageIndex StorageIndex;
typedef typename internal::remove_all<typename MatrixType::ConjugateReturnType>::type MatrixConjugateReturnType;
+ typedef SelfAdjointView<typename internal::add_const<MatrixType>::type, UpLo> ConstSelfAdjointView;
enum {
Mode = internal::traits<SelfAdjointView>::Mode,
Flags = internal::traits<SelfAdjointView>::Flags,
- TransposeMode = ((Mode & Upper) ? Lower : 0) | ((Mode & Lower) ? Upper : 0)
+ TransposeMode = ((int(Mode) & int(Upper)) ? Lower : 0) | ((int(Mode) & int(Lower)) ? Upper : 0)
};
typedef typename MatrixType::PlainObject PlainObject;
@@ -75,14 +76,14 @@
EIGEN_STATIC_ASSERT(UpLo==Lower || UpLo==Upper,SELFADJOINTVIEW_ACCEPTS_UPPER_AND_LOWER_MODE_ONLY);
}
- EIGEN_DEVICE_FUNC
- inline Index rows() const { return m_matrix.rows(); }
- EIGEN_DEVICE_FUNC
- inline Index cols() const { return m_matrix.cols(); }
- EIGEN_DEVICE_FUNC
- inline Index outerStride() const { return m_matrix.outerStride(); }
- EIGEN_DEVICE_FUNC
- inline Index innerStride() const { return m_matrix.innerStride(); }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ inline Index rows() const EIGEN_NOEXCEPT { return m_matrix.rows(); }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ inline Index cols() const EIGEN_NOEXCEPT { return m_matrix.cols(); }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ inline Index outerStride() const EIGEN_NOEXCEPT { return m_matrix.outerStride(); }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ inline Index innerStride() const EIGEN_NOEXCEPT { return m_matrix.innerStride(); }
/** \sa MatrixBase::coeff()
* \warning the coordinates must fit into the referenced triangular part
@@ -131,7 +132,7 @@
{
return Product<OtherDerived,SelfAdjointView>(lhs.derived(),rhs);
}
-
+
friend EIGEN_DEVICE_FUNC
const SelfAdjointView<const EIGEN_SCALAR_BINARYOP_EXPR_RETURN_TYPE(Scalar,MatrixType,product),UpLo>
operator*(const Scalar& s, const SelfAdjointView& mat)
@@ -197,6 +198,18 @@
inline const ConjugateReturnType conjugate() const
{ return ConjugateReturnType(m_matrix.conjugate()); }
+ /** \returns an expression of the complex conjugate of \c *this if Cond==true,
+ * returns \c *this otherwise.
+ */
+ template<bool Cond>
+ EIGEN_DEVICE_FUNC
+ inline typename internal::conditional<Cond,ConjugateReturnType,ConstSelfAdjointView>::type
+ conjugateIf() const
+ {
+ typedef typename internal::conditional<Cond,ConjugateReturnType,ConstSelfAdjointView>::type ReturnType;
+ return ReturnType(m_matrix.template conjugateIf<Cond>());
+ }
+
typedef SelfAdjointView<const typename MatrixType::AdjointReturnType,TransposeMode> AdjointReturnType;
/** \sa MatrixBase::adjoint() const */
EIGEN_DEVICE_FUNC
@@ -287,17 +300,17 @@
using Base::m_src;
using Base::m_functor;
public:
-
+
typedef typename Base::DstEvaluatorType DstEvaluatorType;
typedef typename Base::SrcEvaluatorType SrcEvaluatorType;
typedef typename Base::Scalar Scalar;
typedef typename Base::AssignmentTraits AssignmentTraits;
-
-
+
+
EIGEN_DEVICE_FUNC triangular_dense_assignment_kernel(DstEvaluatorType &dst, const SrcEvaluatorType &src, const Functor &func, DstXprType& dstExpr)
: Base(dst, src, func, dstExpr)
{}
-
+
EIGEN_DEVICE_FUNC void assignCoeff(Index row, Index col)
{
eigen_internal_assert(row!=col);
@@ -305,12 +318,12 @@
m_functor.assignCoeff(m_dst.coeffRef(row,col), tmp);
m_functor.assignCoeff(m_dst.coeffRef(col,row), numext::conj(tmp));
}
-
+
EIGEN_DEVICE_FUNC void assignDiagonalCoeff(Index id)
{
Base::assignCoeff(id,id);
}
-
+
EIGEN_DEVICE_FUNC void assignOppositeCoeff(Index, Index)
{ eigen_internal_assert(false && "should never be called"); }
};
@@ -324,7 +337,7 @@
/** This is the const version of MatrixBase::selfadjointView() */
template<typename Derived>
template<unsigned int UpLo>
-typename MatrixBase<Derived>::template ConstSelfAdjointViewReturnType<UpLo>::Type
+EIGEN_DEVICE_FUNC typename MatrixBase<Derived>::template ConstSelfAdjointViewReturnType<UpLo>::Type
MatrixBase<Derived>::selfadjointView() const
{
return typename ConstSelfAdjointViewReturnType<UpLo>::Type(derived());
@@ -341,7 +354,7 @@
*/
template<typename Derived>
template<unsigned int UpLo>
-typename MatrixBase<Derived>::template SelfAdjointViewReturnType<UpLo>::Type
+EIGEN_DEVICE_FUNC typename MatrixBase<Derived>::template SelfAdjointViewReturnType<UpLo>::Type
MatrixBase<Derived>::selfadjointView()
{
return typename SelfAdjointViewReturnType<UpLo>::Type(derived());
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Solve.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Solve.h
index a8daea5..23d5cb7 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Solve.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Solve.h
@@ -13,13 +13,13 @@
namespace Eigen {
template<typename Decomposition, typename RhsType, typename StorageKind> class SolveImpl;
-
+
/** \class Solve
* \ingroup Core_Module
*
* \brief Pseudo expression representing a solving operation
*
- * \tparam Decomposition the type of the matrix or decomposion object
+ * \tparam Decomposition the type of the matrix or decomposition object
* \tparam Rhstype the type of the right-hand side
*
* This class represents an expression of A.solve(B)
@@ -64,13 +64,13 @@
public:
typedef typename internal::traits<Solve>::PlainObject PlainObject;
typedef typename internal::traits<Solve>::StorageIndex StorageIndex;
-
+
Solve(const Decomposition &dec, const RhsType &rhs)
: m_dec(dec), m_rhs(rhs)
{}
-
- EIGEN_DEVICE_FUNC Index rows() const { return m_dec.cols(); }
- EIGEN_DEVICE_FUNC Index cols() const { return m_rhs.cols(); }
+
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR Index rows() const EIGEN_NOEXCEPT { return m_dec.cols(); }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR Index cols() const EIGEN_NOEXCEPT { return m_rhs.cols(); }
EIGEN_DEVICE_FUNC const Decomposition& dec() const { return m_dec; }
EIGEN_DEVICE_FUNC const RhsType& rhs() const { return m_rhs; }
@@ -87,14 +87,14 @@
: public MatrixBase<Solve<Decomposition,RhsType> >
{
typedef Solve<Decomposition,RhsType> Derived;
-
+
public:
-
+
typedef MatrixBase<Solve<Decomposition,RhsType> > Base;
EIGEN_DENSE_PUBLIC_INTERFACE(Derived)
private:
-
+
Scalar coeff(Index row, Index col) const;
Scalar coeff(Index i) const;
};
@@ -119,15 +119,15 @@
typedef evaluator<PlainObject> Base;
enum { Flags = Base::Flags | EvalBeforeNestingBit };
-
+
EIGEN_DEVICE_FUNC explicit evaluator(const SolveType& solve)
: m_result(solve.rows(), solve.cols())
{
::new (static_cast<Base*>(this)) Base(m_result);
solve.dec()._solve_impl(solve.rhs(), m_result);
}
-
-protected:
+
+protected:
PlainObject m_result;
};
@@ -176,12 +176,12 @@
Index dstCols = src.cols();
if((dst.rows()!=dstRows) || (dst.cols()!=dstCols))
dst.resize(dstRows, dstCols);
-
+
src.dec().nestedExpression().nestedExpression().template _solve_impl_transposed<true>(src.rhs(), dst);
}
};
-} // end namepsace internal
+} // end namespace internal
} // end namespace Eigen
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/SolveTriangular.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/SolveTriangular.h
index 4652e2e..dfbf995 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/SolveTriangular.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/SolveTriangular.h
@@ -10,7 +10,7 @@
#ifndef EIGEN_SOLVETRIANGULAR_H
#define EIGEN_SOLVETRIANGULAR_H
-namespace Eigen {
+namespace Eigen {
namespace internal {
@@ -19,7 +19,7 @@
template<typename LhsScalar, typename RhsScalar, typename Index, int Side, int Mode, bool Conjugate, int StorageOrder>
struct triangular_solve_vector;
-template <typename Scalar, typename Index, int Side, int Mode, bool Conjugate, int TriStorageOrder, int OtherStorageOrder>
+template <typename Scalar, typename Index, int Side, int Mode, bool Conjugate, int TriStorageOrder, int OtherStorageOrder, int OtherInnerStride>
struct triangular_solve_matrix;
// small helper struct extracting some traits on the underlying solver operation
@@ -54,7 +54,7 @@
typedef blas_traits<Lhs> LhsProductTraits;
typedef typename LhsProductTraits::ExtractType ActualLhsType;
typedef Map<Matrix<RhsScalar,Dynamic,1>, Aligned> MappedRhs;
- static void run(const Lhs& lhs, Rhs& rhs)
+ static EIGEN_DEVICE_FUNC void run(const Lhs& lhs, Rhs& rhs)
{
ActualLhsType actualLhs = LhsProductTraits::extract(lhs);
@@ -64,7 +64,7 @@
ei_declare_aligned_stack_constructed_variable(RhsScalar,actualRhs,rhs.size(),
(useRhsDirectly ? rhs.data() : 0));
-
+
if(!useRhsDirectly)
MappedRhs(actualRhs,rhs.size()) = rhs;
@@ -85,7 +85,7 @@
typedef blas_traits<Lhs> LhsProductTraits;
typedef typename LhsProductTraits::DirectLinearAccessType ActualLhsType;
- static void run(const Lhs& lhs, Rhs& rhs)
+ static EIGEN_DEVICE_FUNC void run(const Lhs& lhs, Rhs& rhs)
{
typename internal::add_const_on_value_type<ActualLhsType>::type actualLhs = LhsProductTraits::extract(lhs);
@@ -98,8 +98,8 @@
BlockingType blocking(rhs.rows(), rhs.cols(), size, 1, false);
triangular_solve_matrix<Scalar,Index,Side,Mode,LhsProductTraits::NeedToConjugate,(int(Lhs::Flags) & RowMajorBit) ? RowMajor : ColMajor,
- (Rhs::Flags&RowMajorBit) ? RowMajor : ColMajor>
- ::run(size, othersize, &actualLhs.coeffRef(0,0), actualLhs.outerStride(), &rhs.coeffRef(0,0), rhs.outerStride(), blocking);
+ (Rhs::Flags&RowMajorBit) ? RowMajor : ColMajor, Rhs::InnerStrideAtCompileTime>
+ ::run(size, othersize, &actualLhs.coeffRef(0,0), actualLhs.outerStride(), &rhs.coeffRef(0,0), rhs.innerStride(), rhs.outerStride(), blocking);
}
};
@@ -118,7 +118,7 @@
DiagIndex = IsLower ? LoopIndex : Size - LoopIndex - 1,
StartIndex = IsLower ? 0 : DiagIndex+1
};
- static void run(const Lhs& lhs, Rhs& rhs)
+ static EIGEN_DEVICE_FUNC void run(const Lhs& lhs, Rhs& rhs)
{
if (LoopIndex>0)
rhs.coeffRef(DiagIndex) -= lhs.row(DiagIndex).template segment<LoopIndex>(StartIndex).transpose()
@@ -133,22 +133,22 @@
template<typename Lhs, typename Rhs, int Mode, int LoopIndex, int Size>
struct triangular_solver_unroller<Lhs,Rhs,Mode,LoopIndex,Size,true> {
- static void run(const Lhs&, Rhs&) {}
+ static EIGEN_DEVICE_FUNC void run(const Lhs&, Rhs&) {}
};
template<typename Lhs, typename Rhs, int Mode>
struct triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,CompleteUnrolling,1> {
- static void run(const Lhs& lhs, Rhs& rhs)
+ static EIGEN_DEVICE_FUNC void run(const Lhs& lhs, Rhs& rhs)
{ triangular_solver_unroller<Lhs,Rhs,Mode,0,Rhs::SizeAtCompileTime>::run(lhs,rhs); }
};
template<typename Lhs, typename Rhs, int Mode>
struct triangular_solver_selector<Lhs,Rhs,OnTheRight,Mode,CompleteUnrolling,1> {
- static void run(const Lhs& lhs, Rhs& rhs)
+ static EIGEN_DEVICE_FUNC void run(const Lhs& lhs, Rhs& rhs)
{
Transpose<const Lhs> trLhs(lhs);
Transpose<Rhs> trRhs(rhs);
-
+
triangular_solver_unroller<Transpose<const Lhs>,Transpose<Rhs>,
((Mode&Upper)==Upper ? Lower : Upper) | (Mode&UnitDiag),
0,Rhs::SizeAtCompileTime>::run(trLhs,trRhs);
@@ -164,11 +164,11 @@
#ifndef EIGEN_PARSED_BY_DOXYGEN
template<typename MatrixType, unsigned int Mode>
template<int Side, typename OtherDerived>
-void TriangularViewImpl<MatrixType,Mode,Dense>::solveInPlace(const MatrixBase<OtherDerived>& _other) const
+EIGEN_DEVICE_FUNC void TriangularViewImpl<MatrixType,Mode,Dense>::solveInPlace(const MatrixBase<OtherDerived>& _other) const
{
OtherDerived& other = _other.const_cast_derived();
eigen_assert( derived().cols() == derived().rows() && ((Side==OnTheLeft && derived().cols() == other.rows()) || (Side==OnTheRight && derived().cols() == other.cols())) );
- eigen_assert((!(Mode & ZeroDiag)) && bool(Mode & (Upper|Lower)));
+ eigen_assert((!(int(Mode) & int(ZeroDiag))) && bool(int(Mode) & (int(Upper) | int(Lower))));
// If solving for a 0x0 matrix, nothing to do, simply return.
if (derived().cols() == 0)
return;
@@ -213,8 +213,8 @@
: m_triangularMatrix(tri), m_rhs(rhs)
{}
- inline Index rows() const { return m_rhs.rows(); }
- inline Index cols() const { return m_rhs.cols(); }
+ inline EIGEN_CONSTEXPR Index rows() const EIGEN_NOEXCEPT { return m_rhs.rows(); }
+ inline EIGEN_CONSTEXPR Index cols() const EIGEN_NOEXCEPT { return m_rhs.cols(); }
template<typename Dest> inline void evalTo(Dest& dst) const
{
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/SolverBase.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/SolverBase.h
index 8a4adc2..5014610 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/SolverBase.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/SolverBase.h
@@ -14,8 +14,35 @@
namespace internal {
+template<typename Derived>
+struct solve_assertion {
+ template<bool Transpose_, typename Rhs>
+ static void run(const Derived& solver, const Rhs& b) { solver.template _check_solve_assertion<Transpose_>(b); }
+};
+template<typename Derived>
+struct solve_assertion<Transpose<Derived> >
+{
+ typedef Transpose<Derived> type;
+ template<bool Transpose_, typename Rhs>
+ static void run(const type& transpose, const Rhs& b)
+ {
+ internal::solve_assertion<typename internal::remove_all<Derived>::type>::template run<true>(transpose.nestedExpression(), b);
+ }
+};
+
+template<typename Scalar, typename Derived>
+struct solve_assertion<CwiseUnaryOp<Eigen::internal::scalar_conjugate_op<Scalar>, const Transpose<Derived> > >
+{
+ typedef CwiseUnaryOp<Eigen::internal::scalar_conjugate_op<Scalar>, const Transpose<Derived> > type;
+
+ template<bool Transpose_, typename Rhs>
+ static void run(const type& adjoint, const Rhs& b)
+ {
+ internal::solve_assertion<typename internal::remove_all<Transpose<Derived> >::type>::template run<true>(adjoint.nestedExpression(), b);
+ }
+};
} // end namespace internal
/** \class SolverBase
@@ -35,7 +62,7 @@
*
* \warning Currently, any other usage of transpose() and adjoint() are not supported and will produce compilation errors.
*
- * \sa class PartialPivLU, class FullPivLU
+ * \sa class PartialPivLU, class FullPivLU, class HouseholderQR, class ColPivHouseholderQR, class FullPivHouseholderQR, class CompleteOrthogonalDecomposition, class LLT, class LDLT, class SVDBase
*/
template<typename Derived>
class SolverBase : public EigenBase<Derived>
@@ -46,6 +73,9 @@
typedef typename internal::traits<Derived>::Scalar Scalar;
typedef Scalar CoeffReturnType;
+ template<typename Derived_>
+ friend struct internal::solve_assertion;
+
enum {
RowsAtCompileTime = internal::traits<Derived>::RowsAtCompileTime,
ColsAtCompileTime = internal::traits<Derived>::ColsAtCompileTime,
@@ -56,7 +86,8 @@
MaxSizeAtCompileTime = (internal::size_at_compile_time<internal::traits<Derived>::MaxRowsAtCompileTime,
internal::traits<Derived>::MaxColsAtCompileTime>::ret),
IsVectorAtCompileTime = internal::traits<Derived>::MaxRowsAtCompileTime == 1
- || internal::traits<Derived>::MaxColsAtCompileTime == 1
+ || internal::traits<Derived>::MaxColsAtCompileTime == 1,
+ NumDimensions = int(MaxSizeAtCompileTime) == 1 ? 0 : bool(IsVectorAtCompileTime) ? 1 : 2
};
/** Default constructor */
@@ -74,7 +105,7 @@
inline const Solve<Derived, Rhs>
solve(const MatrixBase<Rhs>& b) const
{
- eigen_assert(derived().rows()==b.rows() && "solve(): invalid number of rows of the right hand side matrix b");
+ internal::solve_assertion<typename internal::remove_all<Derived>::type>::template run<false>(derived(), b);
return Solve<Derived, Rhs>(derived(), b.derived());
}
@@ -112,6 +143,13 @@
}
protected:
+
+ template<bool Transpose_, typename Rhs>
+ void _check_solve_assertion(const Rhs& b) const {
+ EIGEN_ONLY_USED_FOR_DEBUG(b);
+ eigen_assert(derived().m_isInitialized && "Solver is not initialized.");
+ eigen_assert((Transpose_?derived().cols():derived().rows())==b.rows() && "SolverBase::solve(): invalid number of rows of the right hand side matrix b");
+ }
};
namespace internal {
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/StableNorm.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/StableNorm.h
index 88c8d98..4a3f0cc 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/StableNorm.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/StableNorm.h
@@ -50,6 +50,71 @@
ssq += (bl*invScale).squaredNorm();
}
+template<typename VectorType, typename RealScalar>
+void stable_norm_impl_inner_step(const VectorType &vec, RealScalar& ssq, RealScalar& scale, RealScalar& invScale)
+{
+ typedef typename VectorType::Scalar Scalar;
+ const Index blockSize = 4096;
+
+ typedef typename internal::nested_eval<VectorType,2>::type VectorTypeCopy;
+ typedef typename internal::remove_all<VectorTypeCopy>::type VectorTypeCopyClean;
+ const VectorTypeCopy copy(vec);
+
+ enum {
+ CanAlign = ( (int(VectorTypeCopyClean::Flags)&DirectAccessBit)
+ || (int(internal::evaluator<VectorTypeCopyClean>::Alignment)>0) // FIXME Alignment)>0 might not be enough
+ ) && (blockSize*sizeof(Scalar)*2<EIGEN_STACK_ALLOCATION_LIMIT)
+ && (EIGEN_MAX_STATIC_ALIGN_BYTES>0) // if we cannot allocate on the stack, then let's not bother about this optimization
+ };
+ typedef typename internal::conditional<CanAlign, Ref<const Matrix<Scalar,Dynamic,1,0,blockSize,1>, internal::evaluator<VectorTypeCopyClean>::Alignment>,
+ typename VectorTypeCopyClean::ConstSegmentReturnType>::type SegmentWrapper;
+ Index n = vec.size();
+
+ Index bi = internal::first_default_aligned(copy);
+ if (bi>0)
+ internal::stable_norm_kernel(copy.head(bi), ssq, scale, invScale);
+ for (; bi<n; bi+=blockSize)
+ internal::stable_norm_kernel(SegmentWrapper(copy.segment(bi,numext::mini(blockSize, n - bi))), ssq, scale, invScale);
+}
+
+template<typename VectorType>
+typename VectorType::RealScalar
+stable_norm_impl(const VectorType &vec, typename enable_if<VectorType::IsVectorAtCompileTime>::type* = 0 )
+{
+ using std::sqrt;
+ using std::abs;
+
+ Index n = vec.size();
+
+ if(n==1)
+ return abs(vec.coeff(0));
+
+ typedef typename VectorType::RealScalar RealScalar;
+ RealScalar scale(0);
+ RealScalar invScale(1);
+ RealScalar ssq(0); // sum of squares
+
+ stable_norm_impl_inner_step(vec, ssq, scale, invScale);
+
+ return scale * sqrt(ssq);
+}
+
+template<typename MatrixType>
+typename MatrixType::RealScalar
+stable_norm_impl(const MatrixType &mat, typename enable_if<!MatrixType::IsVectorAtCompileTime>::type* = 0 )
+{
+ using std::sqrt;
+
+ typedef typename MatrixType::RealScalar RealScalar;
+ RealScalar scale(0);
+ RealScalar invScale(1);
+ RealScalar ssq(0); // sum of squares
+
+ for(Index j=0; j<mat.outerSize(); ++j)
+ stable_norm_impl_inner_step(mat.innerVector(j), ssq, scale, invScale);
+ return scale * sqrt(ssq);
+}
+
template<typename Derived>
inline typename NumTraits<typename traits<Derived>::Scalar>::Real
blueNorm_impl(const EigenBase<Derived>& _vec)
@@ -58,52 +123,43 @@
using std::pow;
using std::sqrt;
using std::abs;
+
+ // This program calculates the machine-dependent constants
+ // bl, b2, slm, s2m, relerr overfl
+ // from the "basic" machine-dependent numbers
+ // nbig, ibeta, it, iemin, iemax, rbig.
+ // The following define the basic machine-dependent constants.
+ // For portability, the PORT subprograms "ilmaeh" and "rlmach"
+ // are used. For any specific computer, each of the assignment
+ // statements can be replaced
+ static const int ibeta = std::numeric_limits<RealScalar>::radix; // base for floating-point numbers
+ static const int it = NumTraits<RealScalar>::digits(); // number of base-beta digits in mantissa
+ static const int iemin = NumTraits<RealScalar>::min_exponent(); // minimum exponent
+ static const int iemax = NumTraits<RealScalar>::max_exponent(); // maximum exponent
+ static const RealScalar rbig = NumTraits<RealScalar>::highest(); // largest floating-point number
+ static const RealScalar b1 = RealScalar(pow(RealScalar(ibeta),RealScalar(-((1-iemin)/2)))); // lower boundary of midrange
+ static const RealScalar b2 = RealScalar(pow(RealScalar(ibeta),RealScalar((iemax + 1 - it)/2))); // upper boundary of midrange
+ static const RealScalar s1m = RealScalar(pow(RealScalar(ibeta),RealScalar((2-iemin)/2))); // scaling factor for lower range
+ static const RealScalar s2m = RealScalar(pow(RealScalar(ibeta),RealScalar(- ((iemax+it)/2)))); // scaling factor for upper range
+ static const RealScalar eps = RealScalar(pow(double(ibeta), 1-it));
+ static const RealScalar relerr = sqrt(eps); // tolerance for neglecting asml
+
const Derived& vec(_vec.derived());
- static bool initialized = false;
- static RealScalar b1, b2, s1m, s2m, rbig, relerr;
- if(!initialized)
- {
- int ibeta, it, iemin, iemax, iexp;
- RealScalar eps;
- // This program calculates the machine-dependent constants
- // bl, b2, slm, s2m, relerr overfl
- // from the "basic" machine-dependent numbers
- // nbig, ibeta, it, iemin, iemax, rbig.
- // The following define the basic machine-dependent constants.
- // For portability, the PORT subprograms "ilmaeh" and "rlmach"
- // are used. For any specific computer, each of the assignment
- // statements can be replaced
- ibeta = std::numeric_limits<RealScalar>::radix; // base for floating-point numbers
- it = std::numeric_limits<RealScalar>::digits; // number of base-beta digits in mantissa
- iemin = std::numeric_limits<RealScalar>::min_exponent; // minimum exponent
- iemax = std::numeric_limits<RealScalar>::max_exponent; // maximum exponent
- rbig = (std::numeric_limits<RealScalar>::max)(); // largest floating-point number
-
- iexp = -((1-iemin)/2);
- b1 = RealScalar(pow(RealScalar(ibeta),RealScalar(iexp))); // lower boundary of midrange
- iexp = (iemax + 1 - it)/2;
- b2 = RealScalar(pow(RealScalar(ibeta),RealScalar(iexp))); // upper boundary of midrange
-
- iexp = (2-iemin)/2;
- s1m = RealScalar(pow(RealScalar(ibeta),RealScalar(iexp))); // scaling factor for lower range
- iexp = - ((iemax+it)/2);
- s2m = RealScalar(pow(RealScalar(ibeta),RealScalar(iexp))); // scaling factor for upper range
-
- eps = RealScalar(pow(double(ibeta), 1-it));
- relerr = sqrt(eps); // tolerance for neglecting asml
- initialized = true;
- }
Index n = vec.size();
RealScalar ab2 = b2 / RealScalar(n);
RealScalar asml = RealScalar(0);
RealScalar amed = RealScalar(0);
RealScalar abig = RealScalar(0);
- for(typename Derived::InnerIterator it(vec, 0); it; ++it)
+
+ for(Index j=0; j<vec.outerSize(); ++j)
{
- RealScalar ax = abs(it.value());
- if(ax > ab2) abig += numext::abs2(ax*s2m);
- else if(ax < b1) asml += numext::abs2(ax*s1m);
- else amed += numext::abs2(ax);
+ for(typename Derived::InnerIterator iter(vec, j); iter; ++iter)
+ {
+ RealScalar ax = abs(iter.value());
+ if(ax > ab2) abig += numext::abs2(ax*s2m);
+ else if(ax < b1) asml += numext::abs2(ax*s1m);
+ else amed += numext::abs2(ax);
+ }
}
if(amed!=amed)
return amed; // we got a NaN
@@ -156,36 +212,7 @@
inline typename NumTraits<typename internal::traits<Derived>::Scalar>::Real
MatrixBase<Derived>::stableNorm() const
{
- using std::sqrt;
- using std::abs;
- const Index blockSize = 4096;
- RealScalar scale(0);
- RealScalar invScale(1);
- RealScalar ssq(0); // sum of square
-
- typedef typename internal::nested_eval<Derived,2>::type DerivedCopy;
- typedef typename internal::remove_all<DerivedCopy>::type DerivedCopyClean;
- const DerivedCopy copy(derived());
-
- enum {
- CanAlign = ( (int(DerivedCopyClean::Flags)&DirectAccessBit)
- || (int(internal::evaluator<DerivedCopyClean>::Alignment)>0) // FIXME Alignment)>0 might not be enough
- ) && (blockSize*sizeof(Scalar)*2<EIGEN_STACK_ALLOCATION_LIMIT)
- && (EIGEN_MAX_STATIC_ALIGN_BYTES>0) // if we cannot allocate on the stack, then let's not bother about this optimization
- };
- typedef typename internal::conditional<CanAlign, Ref<const Matrix<Scalar,Dynamic,1,0,blockSize,1>, internal::evaluator<DerivedCopyClean>::Alignment>,
- typename DerivedCopyClean::ConstSegmentReturnType>::type SegmentWrapper;
- Index n = size();
-
- if(n==1)
- return abs(this->coeff(0));
-
- Index bi = internal::first_default_aligned(copy);
- if (bi>0)
- internal::stable_norm_kernel(copy.head(bi), ssq, scale, invScale);
- for (; bi<n; bi+=blockSize)
- internal::stable_norm_kernel(SegmentWrapper(copy.segment(bi,numext::mini(blockSize, n - bi))), ssq, scale, invScale);
- return scale * sqrt(ssq);
+ return internal::stable_norm_impl(derived());
}
/** \returns the \em l2 norm of \c *this using the Blue's algorithm.
@@ -213,7 +240,10 @@
inline typename NumTraits<typename internal::traits<Derived>::Scalar>::Real
MatrixBase<Derived>::hypotNorm() const
{
- return this->cwiseAbs().redux(internal::scalar_hypot_op<RealScalar>());
+ if(size()==1)
+ return numext::abs(coeff(0,0));
+ else
+ return this->cwiseAbs().redux(internal::scalar_hypot_op<RealScalar>());
}
} // end namespace Eigen
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/StlIterators.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/StlIterators.h
new file mode 100644
index 0000000..09041db
--- /dev/null
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/StlIterators.h
@@ -0,0 +1,463 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2018 Gael Guennebaud <gael.guennebaud@inria.fr>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_STLITERATORS_H
+#define EIGEN_STLITERATORS_H
+
+namespace Eigen {
+
+namespace internal {
+
+template<typename IteratorType>
+struct indexed_based_stl_iterator_traits;
+
+template<typename Derived>
+class indexed_based_stl_iterator_base
+{
+protected:
+ typedef indexed_based_stl_iterator_traits<Derived> traits;
+ typedef typename traits::XprType XprType;
+ typedef indexed_based_stl_iterator_base<typename traits::non_const_iterator> non_const_iterator;
+ typedef indexed_based_stl_iterator_base<typename traits::const_iterator> const_iterator;
+ typedef typename internal::conditional<internal::is_const<XprType>::value,non_const_iterator,const_iterator>::type other_iterator;
+ // NOTE: in C++03 we cannot declare friend classes through typedefs because we need to write friend class:
+ friend class indexed_based_stl_iterator_base<typename traits::const_iterator>;
+ friend class indexed_based_stl_iterator_base<typename traits::non_const_iterator>;
+public:
+ typedef Index difference_type;
+ typedef std::random_access_iterator_tag iterator_category;
+
+ indexed_based_stl_iterator_base() EIGEN_NO_THROW : mp_xpr(0), m_index(0) {}
+ indexed_based_stl_iterator_base(XprType& xpr, Index index) EIGEN_NO_THROW : mp_xpr(&xpr), m_index(index) {}
+
+ indexed_based_stl_iterator_base(const non_const_iterator& other) EIGEN_NO_THROW
+ : mp_xpr(other.mp_xpr), m_index(other.m_index)
+ {}
+
+ indexed_based_stl_iterator_base& operator=(const non_const_iterator& other)
+ {
+ mp_xpr = other.mp_xpr;
+ m_index = other.m_index;
+ return *this;
+ }
+
+ Derived& operator++() { ++m_index; return derived(); }
+ Derived& operator--() { --m_index; return derived(); }
+
+ Derived operator++(int) { Derived prev(derived()); operator++(); return prev;}
+ Derived operator--(int) { Derived prev(derived()); operator--(); return prev;}
+
+ friend Derived operator+(const indexed_based_stl_iterator_base& a, Index b) { Derived ret(a.derived()); ret += b; return ret; }
+ friend Derived operator-(const indexed_based_stl_iterator_base& a, Index b) { Derived ret(a.derived()); ret -= b; return ret; }
+ friend Derived operator+(Index a, const indexed_based_stl_iterator_base& b) { Derived ret(b.derived()); ret += a; return ret; }
+ friend Derived operator-(Index a, const indexed_based_stl_iterator_base& b) { Derived ret(b.derived()); ret -= a; return ret; }
+
+ Derived& operator+=(Index b) { m_index += b; return derived(); }
+ Derived& operator-=(Index b) { m_index -= b; return derived(); }
+
+ difference_type operator-(const indexed_based_stl_iterator_base& other) const
+ {
+ eigen_assert(mp_xpr == other.mp_xpr);
+ return m_index - other.m_index;
+ }
+
+ difference_type operator-(const other_iterator& other) const
+ {
+ eigen_assert(mp_xpr == other.mp_xpr);
+ return m_index - other.m_index;
+ }
+
+ bool operator==(const indexed_based_stl_iterator_base& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index == other.m_index; }
+ bool operator!=(const indexed_based_stl_iterator_base& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index != other.m_index; }
+ bool operator< (const indexed_based_stl_iterator_base& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index < other.m_index; }
+ bool operator<=(const indexed_based_stl_iterator_base& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index <= other.m_index; }
+ bool operator> (const indexed_based_stl_iterator_base& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index > other.m_index; }
+ bool operator>=(const indexed_based_stl_iterator_base& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index >= other.m_index; }
+
+ bool operator==(const other_iterator& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index == other.m_index; }
+ bool operator!=(const other_iterator& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index != other.m_index; }
+ bool operator< (const other_iterator& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index < other.m_index; }
+ bool operator<=(const other_iterator& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index <= other.m_index; }
+ bool operator> (const other_iterator& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index > other.m_index; }
+ bool operator>=(const other_iterator& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index >= other.m_index; }
+
+protected:
+
+ Derived& derived() { return static_cast<Derived&>(*this); }
+ const Derived& derived() const { return static_cast<const Derived&>(*this); }
+
+ XprType *mp_xpr;
+ Index m_index;
+};
+
+template<typename Derived>
+class indexed_based_stl_reverse_iterator_base
+{
+protected:
+ typedef indexed_based_stl_iterator_traits<Derived> traits;
+ typedef typename traits::XprType XprType;
+ typedef indexed_based_stl_reverse_iterator_base<typename traits::non_const_iterator> non_const_iterator;
+ typedef indexed_based_stl_reverse_iterator_base<typename traits::const_iterator> const_iterator;
+ typedef typename internal::conditional<internal::is_const<XprType>::value,non_const_iterator,const_iterator>::type other_iterator;
+ // NOTE: in C++03 we cannot declare friend classes through typedefs because we need to write friend class:
+ friend class indexed_based_stl_reverse_iterator_base<typename traits::const_iterator>;
+ friend class indexed_based_stl_reverse_iterator_base<typename traits::non_const_iterator>;
+public:
+ typedef Index difference_type;
+ typedef std::random_access_iterator_tag iterator_category;
+
+ indexed_based_stl_reverse_iterator_base() : mp_xpr(0), m_index(0) {}
+ indexed_based_stl_reverse_iterator_base(XprType& xpr, Index index) : mp_xpr(&xpr), m_index(index) {}
+
+ indexed_based_stl_reverse_iterator_base(const non_const_iterator& other)
+ : mp_xpr(other.mp_xpr), m_index(other.m_index)
+ {}
+
+ indexed_based_stl_reverse_iterator_base& operator=(const non_const_iterator& other)
+ {
+ mp_xpr = other.mp_xpr;
+ m_index = other.m_index;
+ return *this;
+ }
+
+ Derived& operator++() { --m_index; return derived(); }
+ Derived& operator--() { ++m_index; return derived(); }
+
+ Derived operator++(int) { Derived prev(derived()); operator++(); return prev;}
+ Derived operator--(int) { Derived prev(derived()); operator--(); return prev;}
+
+ friend Derived operator+(const indexed_based_stl_reverse_iterator_base& a, Index b) { Derived ret(a.derived()); ret += b; return ret; }
+ friend Derived operator-(const indexed_based_stl_reverse_iterator_base& a, Index b) { Derived ret(a.derived()); ret -= b; return ret; }
+ friend Derived operator+(Index a, const indexed_based_stl_reverse_iterator_base& b) { Derived ret(b.derived()); ret += a; return ret; }
+ friend Derived operator-(Index a, const indexed_based_stl_reverse_iterator_base& b) { Derived ret(b.derived()); ret -= a; return ret; }
+
+ Derived& operator+=(Index b) { m_index -= b; return derived(); }
+ Derived& operator-=(Index b) { m_index += b; return derived(); }
+
+ difference_type operator-(const indexed_based_stl_reverse_iterator_base& other) const
+ {
+ eigen_assert(mp_xpr == other.mp_xpr);
+ return other.m_index - m_index;
+ }
+
+ difference_type operator-(const other_iterator& other) const
+ {
+ eigen_assert(mp_xpr == other.mp_xpr);
+ return other.m_index - m_index;
+ }
+
+ bool operator==(const indexed_based_stl_reverse_iterator_base& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index == other.m_index; }
+ bool operator!=(const indexed_based_stl_reverse_iterator_base& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index != other.m_index; }
+ bool operator< (const indexed_based_stl_reverse_iterator_base& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index > other.m_index; }
+ bool operator<=(const indexed_based_stl_reverse_iterator_base& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index >= other.m_index; }
+ bool operator> (const indexed_based_stl_reverse_iterator_base& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index < other.m_index; }
+ bool operator>=(const indexed_based_stl_reverse_iterator_base& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index <= other.m_index; }
+
+ bool operator==(const other_iterator& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index == other.m_index; }
+ bool operator!=(const other_iterator& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index != other.m_index; }
+ bool operator< (const other_iterator& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index > other.m_index; }
+ bool operator<=(const other_iterator& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index >= other.m_index; }
+ bool operator> (const other_iterator& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index < other.m_index; }
+ bool operator>=(const other_iterator& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index <= other.m_index; }
+
+protected:
+
+ Derived& derived() { return static_cast<Derived&>(*this); }
+ const Derived& derived() const { return static_cast<const Derived&>(*this); }
+
+ XprType *mp_xpr;
+ Index m_index;
+};
+
+template<typename XprType>
+class pointer_based_stl_iterator
+{
+ enum { is_lvalue = internal::is_lvalue<XprType>::value };
+ typedef pointer_based_stl_iterator<typename internal::remove_const<XprType>::type> non_const_iterator;
+ typedef pointer_based_stl_iterator<typename internal::add_const<XprType>::type> const_iterator;
+ typedef typename internal::conditional<internal::is_const<XprType>::value,non_const_iterator,const_iterator>::type other_iterator;
+ // NOTE: in C++03 we cannot declare friend classes through typedefs because we need to write friend class:
+ friend class pointer_based_stl_iterator<typename internal::add_const<XprType>::type>;
+ friend class pointer_based_stl_iterator<typename internal::remove_const<XprType>::type>;
+public:
+ typedef Index difference_type;
+ typedef typename XprType::Scalar value_type;
+ typedef std::random_access_iterator_tag iterator_category;
+ typedef typename internal::conditional<bool(is_lvalue), value_type*, const value_type*>::type pointer;
+ typedef typename internal::conditional<bool(is_lvalue), value_type&, const value_type&>::type reference;
+
+
+ pointer_based_stl_iterator() EIGEN_NO_THROW : m_ptr(0) {}
+ pointer_based_stl_iterator(XprType& xpr, Index index) EIGEN_NO_THROW : m_incr(xpr.innerStride())
+ {
+ m_ptr = xpr.data() + index * m_incr.value();
+ }
+
+ pointer_based_stl_iterator(const non_const_iterator& other) EIGEN_NO_THROW
+ : m_ptr(other.m_ptr), m_incr(other.m_incr)
+ {}
+
+ pointer_based_stl_iterator& operator=(const non_const_iterator& other) EIGEN_NO_THROW
+ {
+ m_ptr = other.m_ptr;
+ m_incr.setValue(other.m_incr);
+ return *this;
+ }
+
+ reference operator*() const { return *m_ptr; }
+ reference operator[](Index i) const { return *(m_ptr+i*m_incr.value()); }
+ pointer operator->() const { return m_ptr; }
+
+ pointer_based_stl_iterator& operator++() { m_ptr += m_incr.value(); return *this; }
+ pointer_based_stl_iterator& operator--() { m_ptr -= m_incr.value(); return *this; }
+
+ pointer_based_stl_iterator operator++(int) { pointer_based_stl_iterator prev(*this); operator++(); return prev;}
+ pointer_based_stl_iterator operator--(int) { pointer_based_stl_iterator prev(*this); operator--(); return prev;}
+
+ friend pointer_based_stl_iterator operator+(const pointer_based_stl_iterator& a, Index b) { pointer_based_stl_iterator ret(a); ret += b; return ret; }
+ friend pointer_based_stl_iterator operator-(const pointer_based_stl_iterator& a, Index b) { pointer_based_stl_iterator ret(a); ret -= b; return ret; }
+ friend pointer_based_stl_iterator operator+(Index a, const pointer_based_stl_iterator& b) { pointer_based_stl_iterator ret(b); ret += a; return ret; }
+ friend pointer_based_stl_iterator operator-(Index a, const pointer_based_stl_iterator& b) { pointer_based_stl_iterator ret(b); ret -= a; return ret; }
+
+ pointer_based_stl_iterator& operator+=(Index b) { m_ptr += b*m_incr.value(); return *this; }
+ pointer_based_stl_iterator& operator-=(Index b) { m_ptr -= b*m_incr.value(); return *this; }
+
+ difference_type operator-(const pointer_based_stl_iterator& other) const {
+ return (m_ptr - other.m_ptr)/m_incr.value();
+ }
+
+ difference_type operator-(const other_iterator& other) const {
+ return (m_ptr - other.m_ptr)/m_incr.value();
+ }
+
+ bool operator==(const pointer_based_stl_iterator& other) const { return m_ptr == other.m_ptr; }
+ bool operator!=(const pointer_based_stl_iterator& other) const { return m_ptr != other.m_ptr; }
+ bool operator< (const pointer_based_stl_iterator& other) const { return m_ptr < other.m_ptr; }
+ bool operator<=(const pointer_based_stl_iterator& other) const { return m_ptr <= other.m_ptr; }
+ bool operator> (const pointer_based_stl_iterator& other) const { return m_ptr > other.m_ptr; }
+ bool operator>=(const pointer_based_stl_iterator& other) const { return m_ptr >= other.m_ptr; }
+
+ bool operator==(const other_iterator& other) const { return m_ptr == other.m_ptr; }
+ bool operator!=(const other_iterator& other) const { return m_ptr != other.m_ptr; }
+ bool operator< (const other_iterator& other) const { return m_ptr < other.m_ptr; }
+ bool operator<=(const other_iterator& other) const { return m_ptr <= other.m_ptr; }
+ bool operator> (const other_iterator& other) const { return m_ptr > other.m_ptr; }
+ bool operator>=(const other_iterator& other) const { return m_ptr >= other.m_ptr; }
+
+protected:
+
+ pointer m_ptr;
+ internal::variable_if_dynamic<Index, XprType::InnerStrideAtCompileTime> m_incr;
+};
+
+template<typename _XprType>
+struct indexed_based_stl_iterator_traits<generic_randaccess_stl_iterator<_XprType> >
+{
+ typedef _XprType XprType;
+ typedef generic_randaccess_stl_iterator<typename internal::remove_const<XprType>::type> non_const_iterator;
+ typedef generic_randaccess_stl_iterator<typename internal::add_const<XprType>::type> const_iterator;
+};
+
+template<typename XprType>
+class generic_randaccess_stl_iterator : public indexed_based_stl_iterator_base<generic_randaccess_stl_iterator<XprType> >
+{
+public:
+ typedef typename XprType::Scalar value_type;
+
+protected:
+
+ enum {
+ has_direct_access = (internal::traits<XprType>::Flags & DirectAccessBit) ? 1 : 0,
+ is_lvalue = internal::is_lvalue<XprType>::value
+ };
+
+ typedef indexed_based_stl_iterator_base<generic_randaccess_stl_iterator> Base;
+ using Base::m_index;
+ using Base::mp_xpr;
+
+ // TODO currently const Transpose/Reshape expressions never returns const references,
+ // so lets return by value too.
+ //typedef typename internal::conditional<bool(has_direct_access), const value_type&, const value_type>::type read_only_ref_t;
+ typedef const value_type read_only_ref_t;
+
+public:
+
+ typedef typename internal::conditional<bool(is_lvalue), value_type *, const value_type *>::type pointer;
+ typedef typename internal::conditional<bool(is_lvalue), value_type&, read_only_ref_t>::type reference;
+
+ generic_randaccess_stl_iterator() : Base() {}
+ generic_randaccess_stl_iterator(XprType& xpr, Index index) : Base(xpr,index) {}
+ generic_randaccess_stl_iterator(const typename Base::non_const_iterator& other) : Base(other) {}
+ using Base::operator=;
+
+ reference operator*() const { return (*mp_xpr)(m_index); }
+ reference operator[](Index i) const { return (*mp_xpr)(m_index+i); }
+ pointer operator->() const { return &((*mp_xpr)(m_index)); }
+};
+
+template<typename _XprType, DirectionType Direction>
+struct indexed_based_stl_iterator_traits<subvector_stl_iterator<_XprType,Direction> >
+{
+ typedef _XprType XprType;
+ typedef subvector_stl_iterator<typename internal::remove_const<XprType>::type, Direction> non_const_iterator;
+ typedef subvector_stl_iterator<typename internal::add_const<XprType>::type, Direction> const_iterator;
+};
+
+template<typename XprType, DirectionType Direction>
+class subvector_stl_iterator : public indexed_based_stl_iterator_base<subvector_stl_iterator<XprType,Direction> >
+{
+protected:
+
+ enum { is_lvalue = internal::is_lvalue<XprType>::value };
+
+ typedef indexed_based_stl_iterator_base<subvector_stl_iterator> Base;
+ using Base::m_index;
+ using Base::mp_xpr;
+
+ typedef typename internal::conditional<Direction==Vertical,typename XprType::ColXpr,typename XprType::RowXpr>::type SubVectorType;
+ typedef typename internal::conditional<Direction==Vertical,typename XprType::ConstColXpr,typename XprType::ConstRowXpr>::type ConstSubVectorType;
+
+
+public:
+ typedef typename internal::conditional<bool(is_lvalue), SubVectorType, ConstSubVectorType>::type reference;
+ typedef typename reference::PlainObject value_type;
+
+private:
+ class subvector_stl_iterator_ptr
+ {
+ public:
+ subvector_stl_iterator_ptr(const reference &subvector) : m_subvector(subvector) {}
+ reference* operator->() { return &m_subvector; }
+ private:
+ reference m_subvector;
+ };
+public:
+
+ typedef subvector_stl_iterator_ptr pointer;
+
+ subvector_stl_iterator() : Base() {}
+ subvector_stl_iterator(XprType& xpr, Index index) : Base(xpr,index) {}
+
+ reference operator*() const { return (*mp_xpr).template subVector<Direction>(m_index); }
+ reference operator[](Index i) const { return (*mp_xpr).template subVector<Direction>(m_index+i); }
+ pointer operator->() const { return (*mp_xpr).template subVector<Direction>(m_index); }
+};
+
+template<typename _XprType, DirectionType Direction>
+struct indexed_based_stl_iterator_traits<subvector_stl_reverse_iterator<_XprType,Direction> >
+{
+ typedef _XprType XprType;
+ typedef subvector_stl_reverse_iterator<typename internal::remove_const<XprType>::type, Direction> non_const_iterator;
+ typedef subvector_stl_reverse_iterator<typename internal::add_const<XprType>::type, Direction> const_iterator;
+};
+
+template<typename XprType, DirectionType Direction>
+class subvector_stl_reverse_iterator : public indexed_based_stl_reverse_iterator_base<subvector_stl_reverse_iterator<XprType,Direction> >
+{
+protected:
+
+ enum { is_lvalue = internal::is_lvalue<XprType>::value };
+
+ typedef indexed_based_stl_reverse_iterator_base<subvector_stl_reverse_iterator> Base;
+ using Base::m_index;
+ using Base::mp_xpr;
+
+ typedef typename internal::conditional<Direction==Vertical,typename XprType::ColXpr,typename XprType::RowXpr>::type SubVectorType;
+ typedef typename internal::conditional<Direction==Vertical,typename XprType::ConstColXpr,typename XprType::ConstRowXpr>::type ConstSubVectorType;
+
+
+public:
+ typedef typename internal::conditional<bool(is_lvalue), SubVectorType, ConstSubVectorType>::type reference;
+ typedef typename reference::PlainObject value_type;
+
+private:
+ class subvector_stl_reverse_iterator_ptr
+ {
+ public:
+ subvector_stl_reverse_iterator_ptr(const reference &subvector) : m_subvector(subvector) {}
+ reference* operator->() { return &m_subvector; }
+ private:
+ reference m_subvector;
+ };
+public:
+
+ typedef subvector_stl_reverse_iterator_ptr pointer;
+
+ subvector_stl_reverse_iterator() : Base() {}
+ subvector_stl_reverse_iterator(XprType& xpr, Index index) : Base(xpr,index) {}
+
+ reference operator*() const { return (*mp_xpr).template subVector<Direction>(m_index); }
+ reference operator[](Index i) const { return (*mp_xpr).template subVector<Direction>(m_index+i); }
+ pointer operator->() const { return (*mp_xpr).template subVector<Direction>(m_index); }
+};
+
+} // namespace internal
+
+
+/** returns an iterator to the first element of the 1D vector or array
+ * \only_for_vectors
+ * \sa end(), cbegin()
+ */
+template<typename Derived>
+inline typename DenseBase<Derived>::iterator DenseBase<Derived>::begin()
+{
+ EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived);
+ return iterator(derived(), 0);
+}
+
+/** const version of begin() */
+template<typename Derived>
+inline typename DenseBase<Derived>::const_iterator DenseBase<Derived>::begin() const
+{
+ return cbegin();
+}
+
+/** returns a read-only const_iterator to the first element of the 1D vector or array
+ * \only_for_vectors
+ * \sa cend(), begin()
+ */
+template<typename Derived>
+inline typename DenseBase<Derived>::const_iterator DenseBase<Derived>::cbegin() const
+{
+ EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived);
+ return const_iterator(derived(), 0);
+}
+
+/** returns an iterator to the element following the last element of the 1D vector or array
+ * \only_for_vectors
+ * \sa begin(), cend()
+ */
+template<typename Derived>
+inline typename DenseBase<Derived>::iterator DenseBase<Derived>::end()
+{
+ EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived);
+ return iterator(derived(), size());
+}
+
+/** const version of end() */
+template<typename Derived>
+inline typename DenseBase<Derived>::const_iterator DenseBase<Derived>::end() const
+{
+ return cend();
+}
+
+/** returns a read-only const_iterator to the element following the last element of the 1D vector or array
+ * \only_for_vectors
+ * \sa begin(), cend()
+ */
+template<typename Derived>
+inline typename DenseBase<Derived>::const_iterator DenseBase<Derived>::cend() const
+{
+ EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived);
+ return const_iterator(derived(), size());
+}
+
+} // namespace Eigen
+
+#endif // EIGEN_STLITERATORS_H
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Stride.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Stride.h
index 513742f..6494d51 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Stride.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Stride.h
@@ -10,7 +10,7 @@
#ifndef EIGEN_STRIDE_H
#define EIGEN_STRIDE_H
-namespace Eigen {
+namespace Eigen {
/** \class Stride
* \ingroup Core_Module
@@ -38,6 +38,10 @@
* \include Map_general_stride.cpp
* Output: \verbinclude Map_general_stride.out
*
+ * Both strides can be negative, however, a negative stride of -1 cannot be specified at compiletime
+ * because of the ambiguity with Dynamic which is defined to -1 (historically, negative strides were
+ * not allowed).
+ *
* \sa class InnerStride, class OuterStride, \ref TopicStorageOrders
*/
template<int _OuterStrideAtCompileTime, int _InnerStrideAtCompileTime>
@@ -55,6 +59,8 @@
Stride()
: m_outer(OuterStrideAtCompileTime), m_inner(InnerStrideAtCompileTime)
{
+ // FIXME: for Eigen 4 we should use DynamicIndex instead of Dynamic.
+ // FIXME: for Eigen 4 we should also unify this API with fix<>
eigen_assert(InnerStrideAtCompileTime != Dynamic && OuterStrideAtCompileTime != Dynamic);
}
@@ -63,7 +69,6 @@
Stride(Index outerStride, Index innerStride)
: m_outer(outerStride), m_inner(innerStride)
{
- eigen_assert(innerStride>=0 && outerStride>=0);
}
/** Copy constructor */
@@ -73,10 +78,10 @@
{}
/** \returns the outer stride */
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
inline Index outer() const { return m_outer.value(); }
/** \returns the inner stride */
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
inline Index inner() const { return m_inner.value(); }
protected:
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Swap.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Swap.h
index d702009..180a4e5 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Swap.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Swap.h
@@ -30,12 +30,13 @@
typedef typename Base::DstXprType DstXprType;
typedef swap_assign_op<Scalar> Functor;
- EIGEN_DEVICE_FUNC generic_dense_assignment_kernel(DstEvaluatorTypeT &dst, const SrcEvaluatorTypeT &src, const Functor &func, DstXprType& dstExpr)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ generic_dense_assignment_kernel(DstEvaluatorTypeT &dst, const SrcEvaluatorTypeT &src, const Functor &func, DstXprType& dstExpr)
: Base(dst, src, func, dstExpr)
{}
template<int StoreMode, int LoadMode, typename PacketType>
- void assignPacket(Index row, Index col)
+ EIGEN_STRONG_INLINE void assignPacket(Index row, Index col)
{
PacketType tmp = m_src.template packet<LoadMode,PacketType>(row,col);
const_cast<SrcEvaluatorTypeT&>(m_src).template writePacket<LoadMode>(row,col, m_dst.template packet<StoreMode,PacketType>(row,col));
@@ -43,7 +44,7 @@
}
template<int StoreMode, int LoadMode, typename PacketType>
- void assignPacket(Index index)
+ EIGEN_STRONG_INLINE void assignPacket(Index index)
{
PacketType tmp = m_src.template packet<LoadMode,PacketType>(index);
const_cast<SrcEvaluatorTypeT&>(m_src).template writePacket<LoadMode>(index, m_dst.template packet<StoreMode,PacketType>(index));
@@ -52,7 +53,7 @@
// TODO find a simple way not to have to copy/paste this function from generic_dense_assignment_kernel, by simple I mean no CRTP (Gael)
template<int StoreMode, int LoadMode, typename PacketType>
- void assignPacketByOuterInner(Index outer, Index inner)
+ EIGEN_STRONG_INLINE void assignPacketByOuterInner(Index outer, Index inner)
{
Index row = Base::rowIndexByOuterInner(outer, inner);
Index col = Base::colIndexByOuterInner(outer, inner);
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Transpose.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Transpose.h
index 960dc45..2bc658f 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Transpose.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Transpose.h
@@ -11,7 +11,7 @@
#ifndef EIGEN_TRANSPOSE_H
#define EIGEN_TRANSPOSE_H
-namespace Eigen {
+namespace Eigen {
namespace internal {
template<typename MatrixType>
@@ -61,24 +61,27 @@
typedef typename internal::remove_all<MatrixType>::type NestedExpression;
EIGEN_DEVICE_FUNC
- explicit inline Transpose(MatrixType& matrix) : m_matrix(matrix) {}
+ explicit EIGEN_STRONG_INLINE Transpose(MatrixType& matrix) : m_matrix(matrix) {}
EIGEN_INHERIT_ASSIGNMENT_OPERATORS(Transpose)
- EIGEN_DEVICE_FUNC inline Index rows() const { return m_matrix.cols(); }
- EIGEN_DEVICE_FUNC inline Index cols() const { return m_matrix.rows(); }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_CONSTEXPR
+ Index rows() const EIGEN_NOEXCEPT { return m_matrix.cols(); }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_CONSTEXPR
+ Index cols() const EIGEN_NOEXCEPT { return m_matrix.rows(); }
/** \returns the nested expression */
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const typename internal::remove_all<MatrixTypeNested>::type&
nestedExpression() const { return m_matrix; }
/** \returns the nested expression */
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
typename internal::remove_reference<MatrixTypeNested>::type&
nestedExpression() { return m_matrix; }
/** \internal */
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void resize(Index nrows, Index ncols) {
m_matrix.resize(ncols,nrows);
}
@@ -122,8 +125,10 @@
EIGEN_DENSE_PUBLIC_INTERFACE(Transpose<MatrixType>)
EIGEN_INHERIT_ASSIGNMENT_OPERATORS(TransposeImpl)
- EIGEN_DEVICE_FUNC inline Index innerStride() const { return derived().nestedExpression().innerStride(); }
- EIGEN_DEVICE_FUNC inline Index outerStride() const { return derived().nestedExpression().outerStride(); }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ Index innerStride() const { return derived().nestedExpression().innerStride(); }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ Index outerStride() const { return derived().nestedExpression().outerStride(); }
typedef typename internal::conditional<
internal::is_lvalue<MatrixType>::value,
@@ -131,18 +136,20 @@
const Scalar
>::type ScalarWithConstIfNotLvalue;
- EIGEN_DEVICE_FUNC inline ScalarWithConstIfNotLvalue* data() { return derived().nestedExpression().data(); }
- EIGEN_DEVICE_FUNC inline const Scalar* data() const { return derived().nestedExpression().data(); }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ ScalarWithConstIfNotLvalue* data() { return derived().nestedExpression().data(); }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ const Scalar* data() const { return derived().nestedExpression().data(); }
// FIXME: shall we keep the const version of coeffRef?
- EIGEN_DEVICE_FUNC
- inline const Scalar& coeffRef(Index rowId, Index colId) const
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ const Scalar& coeffRef(Index rowId, Index colId) const
{
return derived().nestedExpression().coeffRef(colId, rowId);
}
- EIGEN_DEVICE_FUNC
- inline const Scalar& coeffRef(Index index) const
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ const Scalar& coeffRef(Index index) const
{
return derived().nestedExpression().coeffRef(index);
}
@@ -170,7 +177,8 @@
*
* \sa transposeInPlace(), adjoint() */
template<typename Derived>
-inline Transpose<Derived>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+Transpose<Derived>
DenseBase<Derived>::transpose()
{
return TransposeReturnType(derived());
@@ -182,7 +190,8 @@
*
* \sa transposeInPlace(), adjoint() */
template<typename Derived>
-inline typename DenseBase<Derived>::ConstTransposeReturnType
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+typename DenseBase<Derived>::ConstTransposeReturnType
DenseBase<Derived>::transpose() const
{
return ConstTransposeReturnType(derived());
@@ -208,7 +217,7 @@
*
* \sa adjointInPlace(), transpose(), conjugate(), class Transpose, class internal::scalar_conjugate_op */
template<typename Derived>
-inline const typename MatrixBase<Derived>::AdjointReturnType
+EIGEN_DEVICE_FUNC inline const typename MatrixBase<Derived>::AdjointReturnType
MatrixBase<Derived>::adjoint() const
{
return AdjointReturnType(this->transpose());
@@ -230,11 +239,10 @@
template<typename MatrixType>
struct inplace_transpose_selector<MatrixType,true,false> { // square matrix
static void run(MatrixType& m) {
- m.matrix().template triangularView<StrictlyUpper>().swap(m.matrix().transpose());
+ m.matrix().template triangularView<StrictlyUpper>().swap(m.matrix().transpose().template triangularView<StrictlyUpper>());
}
};
-// TODO: vectorized path is currently limited to LargestPacketSize x LargestPacketSize cases only.
template<typename MatrixType>
struct inplace_transpose_selector<MatrixType,true,true> { // PacketSize x PacketSize
static void run(MatrixType& m) {
@@ -251,16 +259,66 @@
}
};
+
+template <typename MatrixType, Index Alignment>
+void BlockedInPlaceTranspose(MatrixType& m) {
+ typedef typename MatrixType::Scalar Scalar;
+ typedef typename internal::packet_traits<typename MatrixType::Scalar>::type Packet;
+ const Index PacketSize = internal::packet_traits<Scalar>::size;
+ eigen_assert(m.rows() == m.cols());
+ int row_start = 0;
+ for (; row_start + PacketSize <= m.rows(); row_start += PacketSize) {
+ for (int col_start = row_start; col_start + PacketSize <= m.cols(); col_start += PacketSize) {
+ PacketBlock<Packet> A;
+ if (row_start == col_start) {
+ for (Index i=0; i<PacketSize; ++i)
+ A.packet[i] = m.template packetByOuterInner<Alignment>(row_start + i,col_start);
+ internal::ptranspose(A);
+ for (Index i=0; i<PacketSize; ++i)
+ m.template writePacket<Alignment>(m.rowIndexByOuterInner(row_start + i, col_start), m.colIndexByOuterInner(row_start + i,col_start), A.packet[i]);
+ } else {
+ PacketBlock<Packet> B;
+ for (Index i=0; i<PacketSize; ++i) {
+ A.packet[i] = m.template packetByOuterInner<Alignment>(row_start + i,col_start);
+ B.packet[i] = m.template packetByOuterInner<Alignment>(col_start + i, row_start);
+ }
+ internal::ptranspose(A);
+ internal::ptranspose(B);
+ for (Index i=0; i<PacketSize; ++i) {
+ m.template writePacket<Alignment>(m.rowIndexByOuterInner(row_start + i, col_start), m.colIndexByOuterInner(row_start + i,col_start), B.packet[i]);
+ m.template writePacket<Alignment>(m.rowIndexByOuterInner(col_start + i, row_start), m.colIndexByOuterInner(col_start + i,row_start), A.packet[i]);
+ }
+ }
+ }
+ }
+ for (Index row = row_start; row < m.rows(); ++row) {
+ m.matrix().row(row).head(row).swap(
+ m.matrix().col(row).head(row).transpose());
+ }
+}
+
template<typename MatrixType,bool MatchPacketSize>
-struct inplace_transpose_selector<MatrixType,false,MatchPacketSize> { // non square matrix
+struct inplace_transpose_selector<MatrixType,false,MatchPacketSize> { // non square or dynamic matrix
static void run(MatrixType& m) {
- if (m.rows()==m.cols())
- m.matrix().template triangularView<StrictlyUpper>().swap(m.matrix().transpose());
- else
+ typedef typename MatrixType::Scalar Scalar;
+ if (m.rows() == m.cols()) {
+ const Index PacketSize = internal::packet_traits<Scalar>::size;
+ if (!NumTraits<Scalar>::IsComplex && m.rows() >= PacketSize) {
+ if ((m.rows() % PacketSize) == 0)
+ BlockedInPlaceTranspose<MatrixType,internal::evaluator<MatrixType>::Alignment>(m);
+ else
+ BlockedInPlaceTranspose<MatrixType,Unaligned>(m);
+ }
+ else {
+ m.matrix().template triangularView<StrictlyUpper>().swap(m.matrix().transpose().template triangularView<StrictlyUpper>());
+ }
+ } else {
m = m.transpose().eval();
+ }
}
};
+
} // end namespace internal
/** This is the "in place" version of transpose(): it replaces \c *this by its own transpose.
@@ -278,12 +336,12 @@
* Notice however that this method is only useful if you want to replace a matrix by its own transpose.
* If you just need the transpose of a matrix, use transpose().
*
- * \note if the matrix is not square, then \c *this must be a resizable matrix.
+ * \note if the matrix is not square, then \c *this must be a resizable matrix.
* This excludes (non-square) fixed-size matrices, block-expressions and maps.
*
* \sa transpose(), adjoint(), adjointInPlace() */
template<typename Derived>
-inline void DenseBase<Derived>::transposeInPlace()
+EIGEN_DEVICE_FUNC inline void DenseBase<Derived>::transposeInPlace()
{
eigen_assert((rows() == cols() || (RowsAtCompileTime == Dynamic && ColsAtCompileTime == Dynamic))
&& "transposeInPlace() called on a non-square non-resizable matrix");
@@ -314,7 +372,7 @@
*
* \sa transpose(), adjoint(), transposeInPlace() */
template<typename Derived>
-inline void MatrixBase<Derived>::adjointInPlace()
+EIGEN_DEVICE_FUNC inline void MatrixBase<Derived>::adjointInPlace()
{
derived() = adjoint().eval();
}
@@ -393,7 +451,8 @@
template<typename Dst, typename Src>
void check_for_aliasing(const Dst &dst, const Src &src)
{
- internal::checkTransposeAliasing_impl<Dst, Src>::run(dst, src);
+ if((!Dst::IsVectorAtCompileTime) && dst.rows()>1 && dst.cols()>1)
+ internal::checkTransposeAliasing_impl<Dst, Src>::run(dst, src);
}
} // end namespace internal
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Transpositions.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Transpositions.h
index 86da5af..38a7b01 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Transpositions.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Transpositions.h
@@ -10,20 +10,22 @@
#ifndef EIGEN_TRANSPOSITIONS_H
#define EIGEN_TRANSPOSITIONS_H
-namespace Eigen {
+namespace Eigen {
template<typename Derived>
class TranspositionsBase
{
typedef internal::traits<Derived> Traits;
-
+
public:
typedef typename Traits::IndicesType IndicesType;
typedef typename IndicesType::Scalar StorageIndex;
typedef Eigen::Index Index; ///< \deprecated since Eigen 3.3
+ EIGEN_DEVICE_FUNC
Derived& derived() { return *static_cast<Derived*>(this); }
+ EIGEN_DEVICE_FUNC
const Derived& derived() const { return *static_cast<const Derived*>(this); }
/** Copies the \a other transpositions into \c *this */
@@ -33,26 +35,19 @@
indices() = other.indices();
return derived();
}
-
- #ifndef EIGEN_PARSED_BY_DOXYGEN
- /** This is a special case of the templated operator=. Its purpose is to
- * prevent a default operator= from hiding the templated operator=.
- */
- Derived& operator=(const TranspositionsBase& other)
- {
- indices() = other.indices();
- return derived();
- }
- #endif
/** \returns the number of transpositions */
+ EIGEN_DEVICE_FUNC
Index size() const { return indices().size(); }
/** \returns the number of rows of the equivalent permutation matrix */
+ EIGEN_DEVICE_FUNC
Index rows() const { return indices().size(); }
/** \returns the number of columns of the equivalent permutation matrix */
+ EIGEN_DEVICE_FUNC
Index cols() const { return indices().size(); }
/** Direct access to the underlying index vector */
+ EIGEN_DEVICE_FUNC
inline const StorageIndex& coeff(Index i) const { return indices().coeff(i); }
/** Direct access to the underlying index vector */
inline StorageIndex& coeffRef(Index i) { return indices().coeffRef(i); }
@@ -66,8 +61,10 @@
inline StorageIndex& operator[](Index i) { return indices()(i); }
/** const version of indices(). */
+ EIGEN_DEVICE_FUNC
const IndicesType& indices() const { return derived().indices(); }
/** \returns a reference to the stored array representing the transpositions. */
+ EIGEN_DEVICE_FUNC
IndicesType& indices() { return derived().indices(); }
/** Resizes to given size. */
@@ -84,7 +81,7 @@
}
// FIXME: do we want such methods ?
- // might be usefull when the target matrix expression is complex, e.g.:
+ // might be useful when the target matrix expression is complex, e.g.:
// object.matrix().block(..,..,..,..) = trans * object.matrix().block(..,..,..,..);
/*
template<typename MatrixType>
@@ -171,12 +168,6 @@
inline Transpositions(const TranspositionsBase<OtherDerived>& other)
: m_indices(other.indices()) {}
- #ifndef EIGEN_PARSED_BY_DOXYGEN
- /** Standard copy constructor. Defined only to prevent a default copy constructor
- * from hiding the other templated constructor */
- inline Transpositions(const Transpositions& other) : m_indices(other.indices()) {}
- #endif
-
/** Generic constructor from expression of the transposition indices. */
template<typename Other>
explicit inline Transpositions(const MatrixBase<Other>& indices) : m_indices(indices)
@@ -189,25 +180,16 @@
return Base::operator=(other);
}
- #ifndef EIGEN_PARSED_BY_DOXYGEN
- /** This is a special case of the templated operator=. Its purpose is to
- * prevent a default operator= from hiding the templated operator=.
- */
- Transpositions& operator=(const Transpositions& other)
- {
- m_indices = other.m_indices;
- return *this;
- }
- #endif
-
/** Constructs an uninitialized permutation matrix of given size.
*/
inline Transpositions(Index size) : m_indices(size)
{}
/** const version of indices(). */
+ EIGEN_DEVICE_FUNC
const IndicesType& indices() const { return m_indices; }
/** \returns a reference to the stored array representing the transpositions. */
+ EIGEN_DEVICE_FUNC
IndicesType& indices() { return m_indices; }
protected:
@@ -265,9 +247,11 @@
#endif
/** const version of indices(). */
+ EIGEN_DEVICE_FUNC
const IndicesType& indices() const { return m_indices; }
-
+
/** \returns a reference to the stored array representing the transpositions. */
+ EIGEN_DEVICE_FUNC
IndicesType& indices() { return m_indices; }
protected:
@@ -306,21 +290,12 @@
return Base::operator=(other);
}
- #ifndef EIGEN_PARSED_BY_DOXYGEN
- /** This is a special case of the templated operator=. Its purpose is to
- * prevent a default operator= from hiding the templated operator=.
- */
- TranspositionsWrapper& operator=(const TranspositionsWrapper& other)
- {
- m_indices = other.m_indices;
- return *this;
- }
- #endif
-
/** const version of indices(). */
+ EIGEN_DEVICE_FUNC
const IndicesType& indices() const { return m_indices; }
/** \returns a reference to the stored array representing the transpositions. */
+ EIGEN_DEVICE_FUNC
IndicesType& indices() { return m_indices; }
protected:
@@ -374,9 +349,12 @@
explicit Transpose(const TranspositionType& t) : m_transpositions(t) {}
- Index size() const { return m_transpositions.size(); }
- Index rows() const { return m_transpositions.size(); }
- Index cols() const { return m_transpositions.size(); }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ Index size() const EIGEN_NOEXCEPT { return m_transpositions.size(); }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ Index rows() const EIGEN_NOEXCEPT { return m_transpositions.size(); }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ Index cols() const EIGEN_NOEXCEPT { return m_transpositions.size(); }
/** \returns the \a matrix with the inverse transpositions applied to the columns.
*/
@@ -395,7 +373,8 @@
{
return Product<Transpose, OtherDerived, AliasFreeProduct>(*this, matrix.derived());
}
-
+
+ EIGEN_DEVICE_FUNC
const TranspositionType& nestedExpression() const { return m_transpositions; }
protected:
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/TriangularMatrix.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/TriangularMatrix.h
index 9abb7e3..fdb8bc1 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/TriangularMatrix.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/TriangularMatrix.h
@@ -11,12 +11,12 @@
#ifndef EIGEN_TRIANGULARMATRIX_H
#define EIGEN_TRIANGULARMATRIX_H
-namespace Eigen {
+namespace Eigen {
namespace internal {
-
+
template<int Side, typename TriangularType, typename Rhs> struct triangular_solve_retval;
-
+
}
/** \class TriangularBase
@@ -34,16 +34,16 @@
ColsAtCompileTime = internal::traits<Derived>::ColsAtCompileTime,
MaxRowsAtCompileTime = internal::traits<Derived>::MaxRowsAtCompileTime,
MaxColsAtCompileTime = internal::traits<Derived>::MaxColsAtCompileTime,
-
+
SizeAtCompileTime = (internal::size_at_compile_time<internal::traits<Derived>::RowsAtCompileTime,
internal::traits<Derived>::ColsAtCompileTime>::ret),
/**< This is equal to the number of coefficients, i.e. the number of
* rows times the number of columns, or to \a Dynamic if this is not
* known at compile-time. \sa RowsAtCompileTime, ColsAtCompileTime */
-
+
MaxSizeAtCompileTime = (internal::size_at_compile_time<internal::traits<Derived>::MaxRowsAtCompileTime,
internal::traits<Derived>::MaxColsAtCompileTime>::ret)
-
+
};
typedef typename internal::traits<Derived>::Scalar Scalar;
typedef typename internal::traits<Derived>::StorageKind StorageKind;
@@ -53,18 +53,19 @@
typedef Derived const& Nested;
EIGEN_DEVICE_FUNC
- inline TriangularBase() { eigen_assert(!((Mode&UnitDiag) && (Mode&ZeroDiag))); }
+ inline TriangularBase() { eigen_assert(!((int(Mode) & int(UnitDiag)) && (int(Mode) & int(ZeroDiag)))); }
- EIGEN_DEVICE_FUNC
- inline Index rows() const { return derived().rows(); }
- EIGEN_DEVICE_FUNC
- inline Index cols() const { return derived().cols(); }
- EIGEN_DEVICE_FUNC
- inline Index outerStride() const { return derived().outerStride(); }
- EIGEN_DEVICE_FUNC
- inline Index innerStride() const { return derived().innerStride(); }
-
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ inline Index rows() const EIGEN_NOEXCEPT { return derived().rows(); }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ inline Index cols() const EIGEN_NOEXCEPT { return derived().cols(); }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ inline Index outerStride() const EIGEN_NOEXCEPT { return derived().outerStride(); }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ inline Index innerStride() const EIGEN_NOEXCEPT { return derived().innerStride(); }
+
// dummy resize function
+ EIGEN_DEVICE_FUNC
void resize(Index rows, Index cols)
{
EIGEN_UNUSED_VARIABLE(rows);
@@ -155,7 +156,7 @@
* \param MatrixType the type of the object in which we are taking the triangular part
* \param Mode the kind of triangular matrix expression to construct. Can be #Upper,
* #Lower, #UnitUpper, #UnitLower, #StrictlyUpper, or #StrictlyLower.
- * This is in fact a bit field; it must have either #Upper or #Lower,
+ * This is in fact a bit field; it must have either #Upper or #Lower,
* and additionally it may have #UnitDiag or #ZeroDiag or neither.
*
* This class represents a triangular part of a matrix, not necessarily square. Strictly speaking, for rectangular
@@ -197,7 +198,8 @@
typedef typename internal::traits<TriangularView>::MatrixTypeNestedNonRef MatrixTypeNestedNonRef;
typedef typename internal::remove_all<typename MatrixType::ConjugateReturnType>::type MatrixConjugateReturnType;
-
+ typedef TriangularView<typename internal::add_const<MatrixType>::type, _Mode> ConstTriangularView;
+
public:
typedef typename internal::traits<TriangularView>::StorageKind StorageKind;
@@ -216,15 +218,15 @@
EIGEN_DEVICE_FUNC
explicit inline TriangularView(MatrixType& matrix) : m_matrix(matrix)
{}
-
+
EIGEN_INHERIT_ASSIGNMENT_OPERATORS(TriangularView)
/** \copydoc EigenBase::rows() */
- EIGEN_DEVICE_FUNC
- inline Index rows() const { return m_matrix.rows(); }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ inline Index rows() const EIGEN_NOEXCEPT { return m_matrix.rows(); }
/** \copydoc EigenBase::cols() */
- EIGEN_DEVICE_FUNC
- inline Index cols() const { return m_matrix.cols(); }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ inline Index cols() const EIGEN_NOEXCEPT { return m_matrix.cols(); }
/** \returns a const reference to the nested expression */
EIGEN_DEVICE_FUNC
@@ -233,13 +235,25 @@
/** \returns a reference to the nested expression */
EIGEN_DEVICE_FUNC
NestedExpression& nestedExpression() { return m_matrix; }
-
+
typedef TriangularView<const MatrixConjugateReturnType,Mode> ConjugateReturnType;
/** \sa MatrixBase::conjugate() const */
EIGEN_DEVICE_FUNC
inline const ConjugateReturnType conjugate() const
{ return ConjugateReturnType(m_matrix.conjugate()); }
+ /** \returns an expression of the complex conjugate of \c *this if Cond==true,
+ * returns \c *this otherwise.
+ */
+ template<bool Cond>
+ EIGEN_DEVICE_FUNC
+ inline typename internal::conditional<Cond,ConjugateReturnType,ConstTriangularView>::type
+ conjugateIf() const
+ {
+ typedef typename internal::conditional<Cond,ConjugateReturnType,ConstTriangularView>::type ReturnType;
+ return ReturnType(m_matrix.template conjugateIf<Cond>());
+ }
+
typedef TriangularView<const typename MatrixType::AdjointReturnType,TransposeMode> AdjointReturnType;
/** \sa MatrixBase::adjoint() const */
EIGEN_DEVICE_FUNC
@@ -255,7 +269,7 @@
typename MatrixType::TransposeReturnType tmp(m_matrix);
return TransposeReturnType(tmp);
}
-
+
typedef TriangularView<const typename MatrixType::ConstTransposeReturnType,TransposeMode> ConstTransposeReturnType;
/** \sa MatrixBase::transpose() const */
EIGEN_DEVICE_FUNC
@@ -266,10 +280,10 @@
template<typename Other>
EIGEN_DEVICE_FUNC
- inline const Solve<TriangularView, Other>
+ inline const Solve<TriangularView, Other>
solve(const MatrixBase<Other>& other) const
{ return Solve<TriangularView, Other>(*this, other.derived()); }
-
+
// workaround MSVC ICE
#if EIGEN_COMP_MSVC
template<int Side, typename Other>
@@ -313,7 +327,7 @@
else
return m_matrix.diagonal().prod();
}
-
+
protected:
MatrixTypeNested m_matrix;
@@ -375,7 +389,7 @@
internal::call_assignment_no_alias(derived(), other.derived(), internal::sub_assign_op<Scalar,typename Other::Scalar>());
return derived();
}
-
+
/** \sa MatrixBase::operator*=() */
EIGEN_DEVICE_FUNC
TriangularViewType& operator*=(const typename internal::traits<MatrixType>::Scalar& other) { return *this = derived().nestedExpression() * other; }
@@ -433,14 +447,14 @@
TriangularViewType& operator=(const TriangularViewImpl& other)
{ return *this = other.derived().nestedExpression(); }
- /** \deprecated */
template<typename OtherDerived>
- EIGEN_DEVICE_FUNC
+ /** \deprecated */
+ EIGEN_DEPRECATED EIGEN_DEVICE_FUNC
void lazyAssign(const TriangularBase<OtherDerived>& other);
- /** \deprecated */
template<typename OtherDerived>
- EIGEN_DEVICE_FUNC
+ /** \deprecated */
+ EIGEN_DEPRECATED EIGEN_DEVICE_FUNC
void lazyAssign(const MatrixBase<OtherDerived>& other);
#endif
@@ -468,7 +482,7 @@
* \a Side==OnTheLeft (the default), or the right-inverse-multiply \a other * inverse(\c *this) if
* \a Side==OnTheRight.
*
- * Note that the template parameter \c Side can be ommitted, in which case \c Side==OnTheLeft
+ * Note that the template parameter \c Side can be omitted, in which case \c Side==OnTheLeft
*
* The matrix \c *this must be triangular and invertible (i.e., all the coefficients of the
* diagonal must be non zero). It works as a forward (resp. backward) substitution if \c *this
@@ -486,7 +500,6 @@
* \sa TriangularView::solveInPlace()
*/
template<int Side, typename Other>
- EIGEN_DEVICE_FUNC
inline const internal::triangular_solve_retval<Side,TriangularViewType, Other>
solve(const MatrixBase<Other>& other) const;
@@ -495,7 +508,7 @@
* \warning The parameter is only marked 'const' to make the C++ compiler accept a temporary expression here.
* This function will const_cast it, so constness isn't honored here.
*
- * Note that the template parameter \c Side can be ommitted, in which case \c Side==OnTheLeft
+ * Note that the template parameter \c Side can be omitted, in which case \c Side==OnTheLeft
*
* See TriangularView:solve() for the details.
*/
@@ -521,10 +534,10 @@
call_assignment(derived(), other.const_cast_derived(), internal::swap_assign_op<Scalar>());
}
- /** \deprecated
- * Shortcut for \code (*this).swap(other.triangularView<(*this)::Mode>()) \endcode */
+ /** Shortcut for \code (*this).swap(other.triangularView<(*this)::Mode>()) \endcode */
template<typename OtherDerived>
- EIGEN_DEVICE_FUNC
+ /** \deprecated */
+ EIGEN_DEPRECATED EIGEN_DEVICE_FUNC
void swap(MatrixBase<OtherDerived> const & other)
{
EIGEN_STATIC_ASSERT_LVALUE(OtherDerived);
@@ -556,7 +569,7 @@
// FIXME should we keep that possibility
template<typename MatrixType, unsigned int Mode>
template<typename OtherDerived>
-inline TriangularView<MatrixType, Mode>&
+EIGEN_DEVICE_FUNC inline TriangularView<MatrixType, Mode>&
TriangularViewImpl<MatrixType, Mode, Dense>::operator=(const MatrixBase<OtherDerived>& other)
{
internal::call_assignment_no_alias(derived(), other.derived(), internal::assign_op<Scalar,typename OtherDerived::Scalar>());
@@ -566,7 +579,7 @@
// FIXME should we keep that possibility
template<typename MatrixType, unsigned int Mode>
template<typename OtherDerived>
-void TriangularViewImpl<MatrixType, Mode, Dense>::lazyAssign(const MatrixBase<OtherDerived>& other)
+EIGEN_DEVICE_FUNC void TriangularViewImpl<MatrixType, Mode, Dense>::lazyAssign(const MatrixBase<OtherDerived>& other)
{
internal::call_assignment_no_alias(derived(), other.template triangularView<Mode>());
}
@@ -575,7 +588,7 @@
template<typename MatrixType, unsigned int Mode>
template<typename OtherDerived>
-inline TriangularView<MatrixType, Mode>&
+EIGEN_DEVICE_FUNC inline TriangularView<MatrixType, Mode>&
TriangularViewImpl<MatrixType, Mode, Dense>::operator=(const TriangularBase<OtherDerived>& other)
{
eigen_assert(Mode == int(OtherDerived::Mode));
@@ -585,7 +598,7 @@
template<typename MatrixType, unsigned int Mode>
template<typename OtherDerived>
-void TriangularViewImpl<MatrixType, Mode, Dense>::lazyAssign(const TriangularBase<OtherDerived>& other)
+EIGEN_DEVICE_FUNC void TriangularViewImpl<MatrixType, Mode, Dense>::lazyAssign(const TriangularBase<OtherDerived>& other)
{
eigen_assert(Mode == int(OtherDerived::Mode));
internal::call_assignment_no_alias(derived(), other.derived());
@@ -600,7 +613,7 @@
* If the matrix is triangular, the opposite part is set to zero. */
template<typename Derived>
template<typename DenseDerived>
-void TriangularBase<Derived>::evalTo(MatrixBase<DenseDerived> &other) const
+EIGEN_DEVICE_FUNC void TriangularBase<Derived>::evalTo(MatrixBase<DenseDerived> &other) const
{
evalToLazy(other.derived());
}
@@ -626,6 +639,7 @@
*/
template<typename Derived>
template<unsigned int Mode>
+EIGEN_DEVICE_FUNC
typename MatrixBase<Derived>::template TriangularViewReturnType<Mode>::Type
MatrixBase<Derived>::triangularView()
{
@@ -635,6 +649,7 @@
/** This is the const version of MatrixBase::triangularView() */
template<typename Derived>
template<unsigned int Mode>
+EIGEN_DEVICE_FUNC
typename MatrixBase<Derived>::template ConstTriangularViewReturnType<Mode>::Type
MatrixBase<Derived>::triangularView() const
{
@@ -700,7 +715,7 @@
namespace internal {
-
+
// TODO currently a triangular expression has the form TriangularView<.,.>
// in the future triangular-ness should be defined by the expression traits
// such that Transpose<TriangularView<.,.> > is valid. (currently TriangularBase::transpose() is overloaded to make it work)
@@ -717,6 +732,7 @@
{
typedef TriangularView<MatrixType,Mode> XprType;
typedef evaluator<typename internal::remove_all<MatrixType>::type> Base;
+ EIGEN_DEVICE_FUNC
unary_evaluator(const XprType &xpr) : Base(xpr.nestedExpression()) {}
};
@@ -728,7 +744,7 @@
template<typename Kernel, unsigned int Mode, int UnrollCount, bool ClearOpposite> struct triangular_assignment_loop;
-
+
/** \internal Specialization of the dense assignment kernel for triangular matrices.
* The main difference is that the triangular, diagonal, and opposite parts are processed through three different functions.
* \tparam UpLo must be either Lower or Upper
@@ -745,17 +761,17 @@
using Base::m_src;
using Base::m_functor;
public:
-
+
typedef typename Base::DstEvaluatorType DstEvaluatorType;
typedef typename Base::SrcEvaluatorType SrcEvaluatorType;
typedef typename Base::Scalar Scalar;
typedef typename Base::AssignmentTraits AssignmentTraits;
-
-
+
+
EIGEN_DEVICE_FUNC triangular_dense_assignment_kernel(DstEvaluatorType &dst, const SrcEvaluatorType &src, const Functor &func, DstXprType& dstExpr)
: Base(dst, src, func, dstExpr)
{}
-
+
#ifdef EIGEN_INTERNAL_DEBUGGING
EIGEN_DEVICE_FUNC void assignCoeff(Index row, Index col)
{
@@ -765,16 +781,16 @@
#else
using Base::assignCoeff;
#endif
-
+
EIGEN_DEVICE_FUNC void assignDiagonalCoeff(Index id)
{
if(Mode==UnitDiag && SetOpposite) m_functor.assignCoeff(m_dst.coeffRef(id,id), Scalar(1));
else if(Mode==ZeroDiag && SetOpposite) m_functor.assignCoeff(m_dst.coeffRef(id,id), Scalar(0));
else if(Mode==0) Base::assignCoeff(id,id);
}
-
+
EIGEN_DEVICE_FUNC void assignOppositeCoeff(Index row, Index col)
- {
+ {
eigen_internal_assert(row!=col);
if(SetOpposite)
m_functor.assignCoeff(m_dst.coeffRef(row,col), Scalar(0));
@@ -795,17 +811,17 @@
if((dst.rows()!=dstRows) || (dst.cols()!=dstCols))
dst.resize(dstRows, dstCols);
DstEvaluatorType dstEvaluator(dst);
-
+
typedef triangular_dense_assignment_kernel< Mode&(Lower|Upper),Mode&(UnitDiag|ZeroDiag|SelfAdjoint),SetOpposite,
DstEvaluatorType,SrcEvaluatorType,Functor> Kernel;
Kernel kernel(dstEvaluator, srcEvaluator, func, dst.const_cast_derived());
-
+
enum {
unroll = DstXprType::SizeAtCompileTime != Dynamic
&& SrcEvaluatorType::CoeffReadCost < HugeCost
- && DstXprType::SizeAtCompileTime * (DstEvaluatorType::CoeffReadCost+SrcEvaluatorType::CoeffReadCost) / 2 <= EIGEN_UNROLLING_LIMIT
+ && DstXprType::SizeAtCompileTime * (int(DstEvaluatorType::CoeffReadCost) + int(SrcEvaluatorType::CoeffReadCost)) / 2 <= EIGEN_UNROLLING_LIMIT
};
-
+
triangular_assignment_loop<Kernel, Mode, unroll ? int(DstXprType::SizeAtCompileTime) : Dynamic, SetOpposite>::run(kernel);
}
@@ -827,8 +843,8 @@
EIGEN_DEVICE_FUNC static void run(DstXprType &dst, const SrcXprType &src, const Functor &func)
{
eigen_assert(int(DstXprType::Mode) == int(SrcXprType::Mode));
-
- call_triangular_assignment_loop<DstXprType::Mode, false>(dst, src, func);
+
+ call_triangular_assignment_loop<DstXprType::Mode, false>(dst, src, func);
}
};
@@ -837,7 +853,7 @@
{
EIGEN_DEVICE_FUNC static void run(DstXprType &dst, const SrcXprType &src, const Functor &func)
{
- call_triangular_assignment_loop<SrcXprType::Mode, (SrcXprType::Mode&SelfAdjoint)==0>(dst, src, func);
+ call_triangular_assignment_loop<SrcXprType::Mode, (int(SrcXprType::Mode) & int(SelfAdjoint)) == 0>(dst, src, func);
}
};
@@ -846,7 +862,7 @@
{
EIGEN_DEVICE_FUNC static void run(DstXprType &dst, const SrcXprType &src, const Functor &func)
{
- call_triangular_assignment_loop<DstXprType::Mode, false>(dst, src, func);
+ call_triangular_assignment_loop<DstXprType::Mode, false>(dst, src, func);
}
};
@@ -857,19 +873,19 @@
// FIXME: this is not very clean, perhaps this information should be provided by the kernel?
typedef typename Kernel::DstEvaluatorType DstEvaluatorType;
typedef typename DstEvaluatorType::XprType DstXprType;
-
+
enum {
col = (UnrollCount-1) / DstXprType::RowsAtCompileTime,
row = (UnrollCount-1) % DstXprType::RowsAtCompileTime
};
-
+
typedef typename Kernel::Scalar Scalar;
EIGEN_DEVICE_FUNC
static inline void run(Kernel &kernel)
{
triangular_assignment_loop<Kernel, Mode, UnrollCount-1, SetOpposite>::run(kernel);
-
+
if(row==col)
kernel.assignDiagonalCoeff(row);
else if( ((Mode&Lower) && row>col) || ((Mode&Upper) && row<col) )
@@ -912,10 +928,10 @@
}
else
i = maxi;
-
+
if(i<kernel.rows()) // then i==j
kernel.assignDiagonalCoeff(i++);
-
+
if (((Mode&Upper) && SetOpposite) || (Mode&Lower))
{
for(; i < kernel.rows(); ++i)
@@ -932,14 +948,14 @@
* If the matrix is triangular, the opposite part is set to zero. */
template<typename Derived>
template<typename DenseDerived>
-void TriangularBase<Derived>::evalToLazy(MatrixBase<DenseDerived> &other) const
+EIGEN_DEVICE_FUNC void TriangularBase<Derived>::evalToLazy(MatrixBase<DenseDerived> &other) const
{
other.derived().resize(this->rows(), this->cols());
- internal::call_triangular_assignment_loop<Derived::Mode,(Derived::Mode&SelfAdjoint)==0 /* SetOpposite */>(other.derived(), derived().nestedExpression());
+ internal::call_triangular_assignment_loop<Derived::Mode, (int(Derived::Mode) & int(SelfAdjoint)) == 0 /* SetOpposite */>(other.derived(), derived().nestedExpression());
}
namespace internal {
-
+
// Triangular = Product
template< typename DstXprType, typename Lhs, typename Rhs, typename Scalar>
struct Assignment<DstXprType, Product<Lhs,Rhs,DefaultProduct>, internal::assign_op<Scalar,typename Product<Lhs,Rhs,DefaultProduct>::Scalar>, Dense2Triangular>
@@ -952,7 +968,7 @@
if((dst.rows()!=dstRows) || (dst.cols()!=dstCols))
dst.resize(dstRows, dstCols);
- dst._assignProduct(src, 1, 0);
+ dst._assignProduct(src, Scalar(1), false);
}
};
@@ -963,7 +979,7 @@
typedef Product<Lhs,Rhs,DefaultProduct> SrcXprType;
static void run(DstXprType &dst, const SrcXprType &src, const internal::add_assign_op<Scalar,typename SrcXprType::Scalar> &)
{
- dst._assignProduct(src, 1, 1);
+ dst._assignProduct(src, Scalar(1), true);
}
};
@@ -974,7 +990,7 @@
typedef Product<Lhs,Rhs,DefaultProduct> SrcXprType;
static void run(DstXprType &dst, const SrcXprType &src, const internal::sub_assign_op<Scalar,typename SrcXprType::Scalar> &)
{
- dst._assignProduct(src, -1, 1);
+ dst._assignProduct(src, Scalar(-1), true);
}
};
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/VectorBlock.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/VectorBlock.h
index d72fbf7..71c5b95 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/VectorBlock.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/VectorBlock.h
@@ -35,7 +35,7 @@
* It is the return type of DenseBase::segment(Index,Index) and DenseBase::segment<int>(Index) and
* most of the time this is the only way it is used.
*
- * However, if you want to directly maniputate sub-vector expressions,
+ * However, if you want to directly manipulate sub-vector expressions,
* for instance if you want to write a function returning such an expression, you
* will need to use this class.
*
@@ -71,8 +71,8 @@
/** Dynamic-size constructor
*/
- EIGEN_DEVICE_FUNC
- inline VectorBlock(VectorType& vector, Index start, Index size)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ VectorBlock(VectorType& vector, Index start, Index size)
: Base(vector,
IsColVector ? start : 0, IsColVector ? 0 : start,
IsColVector ? size : 1, IsColVector ? 1 : size)
@@ -82,8 +82,8 @@
/** Fixed-size constructor
*/
- EIGEN_DEVICE_FUNC
- inline VectorBlock(VectorType& vector, Index start)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ VectorBlock(VectorType& vector, Index start)
: Base(vector, IsColVector ? start : 0, IsColVector ? 0 : start)
{
EIGEN_STATIC_ASSERT_VECTOR_ONLY(VectorBlock);
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/VectorwiseOp.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/VectorwiseOp.h
index 4fe267e..870f4f1 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/VectorwiseOp.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/VectorwiseOp.h
@@ -1,7 +1,7 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
-// Copyright (C) 2008-2010 Gael Guennebaud <gael.guennebaud@inria.fr>
+// Copyright (C) 2008-2019 Gael Guennebaud <gael.guennebaud@inria.fr>
// Copyright (C) 2006-2008 Benoit Jacob <jacob.benoit.1@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
@@ -65,10 +65,10 @@
explicit PartialReduxExpr(const MatrixType& mat, const MemberOp& func = MemberOp())
: m_matrix(mat), m_functor(func) {}
- EIGEN_DEVICE_FUNC
- Index rows() const { return (Direction==Vertical ? 1 : m_matrix.rows()); }
- EIGEN_DEVICE_FUNC
- Index cols() const { return (Direction==Horizontal ? 1 : m_matrix.cols()); }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ Index rows() const EIGEN_NOEXCEPT { return (Direction==Vertical ? 1 : m_matrix.rows()); }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ Index cols() const EIGEN_NOEXCEPT { return (Direction==Horizontal ? 1 : m_matrix.cols()); }
EIGEN_DEVICE_FUNC
typename MatrixType::Nested nestedExpression() const { return m_matrix; }
@@ -81,39 +81,46 @@
const MemberOp m_functor;
};
-#define EIGEN_MEMBER_FUNCTOR(MEMBER,COST) \
- template <typename ResultType> \
- struct member_##MEMBER { \
- EIGEN_EMPTY_STRUCT_CTOR(member_##MEMBER) \
- typedef ResultType result_type; \
- template<typename Scalar, int Size> struct Cost \
- { enum { value = COST }; }; \
- template<typename XprType> \
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \
- ResultType operator()(const XprType& mat) const \
- { return mat.MEMBER(); } \
+template<typename A,typename B> struct partial_redux_dummy_func;
+
+#define EIGEN_MAKE_PARTIAL_REDUX_FUNCTOR(MEMBER,COST,VECTORIZABLE,BINARYOP) \
+ template <typename ResultType,typename Scalar> \
+ struct member_##MEMBER { \
+ EIGEN_EMPTY_STRUCT_CTOR(member_##MEMBER) \
+ typedef ResultType result_type; \
+ typedef BINARYOP<Scalar,Scalar> BinaryOp; \
+ template<int Size> struct Cost { enum { value = COST }; }; \
+ enum { Vectorizable = VECTORIZABLE }; \
+ template<typename XprType> \
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \
+ ResultType operator()(const XprType& mat) const \
+ { return mat.MEMBER(); } \
+ BinaryOp binaryFunc() const { return BinaryOp(); } \
}
+#define EIGEN_MEMBER_FUNCTOR(MEMBER,COST) \
+ EIGEN_MAKE_PARTIAL_REDUX_FUNCTOR(MEMBER,COST,0,partial_redux_dummy_func)
+
namespace internal {
-EIGEN_MEMBER_FUNCTOR(squaredNorm, Size * NumTraits<Scalar>::MulCost + (Size-1)*NumTraits<Scalar>::AddCost);
EIGEN_MEMBER_FUNCTOR(norm, (Size+5) * NumTraits<Scalar>::MulCost + (Size-1)*NumTraits<Scalar>::AddCost);
EIGEN_MEMBER_FUNCTOR(stableNorm, (Size+5) * NumTraits<Scalar>::MulCost + (Size-1)*NumTraits<Scalar>::AddCost);
EIGEN_MEMBER_FUNCTOR(blueNorm, (Size+5) * NumTraits<Scalar>::MulCost + (Size-1)*NumTraits<Scalar>::AddCost);
EIGEN_MEMBER_FUNCTOR(hypotNorm, (Size-1) * functor_traits<scalar_hypot_op<Scalar> >::Cost );
-EIGEN_MEMBER_FUNCTOR(sum, (Size-1)*NumTraits<Scalar>::AddCost);
-EIGEN_MEMBER_FUNCTOR(mean, (Size-1)*NumTraits<Scalar>::AddCost + NumTraits<Scalar>::MulCost);
-EIGEN_MEMBER_FUNCTOR(minCoeff, (Size-1)*NumTraits<Scalar>::AddCost);
-EIGEN_MEMBER_FUNCTOR(maxCoeff, (Size-1)*NumTraits<Scalar>::AddCost);
EIGEN_MEMBER_FUNCTOR(all, (Size-1)*NumTraits<Scalar>::AddCost);
EIGEN_MEMBER_FUNCTOR(any, (Size-1)*NumTraits<Scalar>::AddCost);
EIGEN_MEMBER_FUNCTOR(count, (Size-1)*NumTraits<Scalar>::AddCost);
-EIGEN_MEMBER_FUNCTOR(prod, (Size-1)*NumTraits<Scalar>::MulCost);
-template <int p, typename ResultType>
+EIGEN_MAKE_PARTIAL_REDUX_FUNCTOR(sum, (Size-1)*NumTraits<Scalar>::AddCost, 1, internal::scalar_sum_op);
+EIGEN_MAKE_PARTIAL_REDUX_FUNCTOR(minCoeff, (Size-1)*NumTraits<Scalar>::AddCost, 1, internal::scalar_min_op);
+EIGEN_MAKE_PARTIAL_REDUX_FUNCTOR(maxCoeff, (Size-1)*NumTraits<Scalar>::AddCost, 1, internal::scalar_max_op);
+EIGEN_MAKE_PARTIAL_REDUX_FUNCTOR(prod, (Size-1)*NumTraits<Scalar>::MulCost, 1, internal::scalar_product_op);
+
+template <int p, typename ResultType,typename Scalar>
struct member_lpnorm {
typedef ResultType result_type;
- template<typename Scalar, int Size> struct Cost
+ enum { Vectorizable = 0 };
+ template<int Size> struct Cost
{ enum { value = (Size+5) * NumTraits<Scalar>::MulCost + (Size-1)*NumTraits<Scalar>::AddCost }; };
EIGEN_DEVICE_FUNC member_lpnorm() {}
template<typename XprType>
@@ -121,17 +128,20 @@
{ return mat.template lpNorm<p>(); }
};
-template <typename BinaryOp, typename Scalar>
+template <typename BinaryOpT, typename Scalar>
struct member_redux {
+ typedef BinaryOpT BinaryOp;
typedef typename result_of<
BinaryOp(const Scalar&,const Scalar&)
>::type result_type;
- template<typename _Scalar, int Size> struct Cost
- { enum { value = (Size-1) * functor_traits<BinaryOp>::Cost }; };
+
+ enum { Vectorizable = functor_traits<BinaryOp>::PacketAccess };
+ template<int Size> struct Cost { enum { value = (Size-1) * functor_traits<BinaryOp>::Cost }; };
EIGEN_DEVICE_FUNC explicit member_redux(const BinaryOp func) : m_functor(func) {}
template<typename Derived>
EIGEN_DEVICE_FUNC inline result_type operator()(const DenseBase<Derived>& mat) const
{ return mat.redux(m_functor); }
+ const BinaryOp& binaryFunc() const { return m_functor; }
const BinaryOp m_functor;
};
}
@@ -139,18 +149,38 @@
/** \class VectorwiseOp
* \ingroup Core_Module
*
- * \brief Pseudo expression providing partial reduction operations
+ * \brief Pseudo expression providing broadcasting and partial reduction operations
*
* \tparam ExpressionType the type of the object on which to do partial reductions
- * \tparam Direction indicates the direction of the redux (#Vertical or #Horizontal)
+ * \tparam Direction indicates whether to operate on columns (#Vertical) or rows (#Horizontal)
*
- * This class represents a pseudo expression with partial reduction features.
+ * This class represents a pseudo expression with broadcasting and partial reduction features.
* It is the return type of DenseBase::colwise() and DenseBase::rowwise()
- * and most of the time this is the only way it is used.
+ * and most of the time this is the only way it is explicitly used.
+ *
+ * To understand the logic of rowwise/colwise expressions, consider the generic case `A.colwise().foo()`
+ * where `foo` is any method of `VectorwiseOp`. This expression is equivalent to applying `foo()` to each
+ * column of `A` and then reassembling the outputs into a matrix expression:
+ * \code [A.col(0).foo(), A.col(1).foo(), ..., A.col(A.cols()-1).foo()] \endcode
*
* Example: \include MatrixBase_colwise.cpp
* Output: \verbinclude MatrixBase_colwise.out
*
+ * The begin() and end() methods are obviously exceptions to the previous rule as they
+ * return STL-compatible begin/end iterators to the rows or columns of the nested expression.
+ * Typical use cases include for-range-loop and calls to STL algorithms:
+ *
+ * Example: \include MatrixBase_colwise_iterator_cxx11.cpp
+ * Output: \verbinclude MatrixBase_colwise_iterator_cxx11.out
+ *
+ * For a partial reduction on an empty input, some rules apply.
+ * For the sake of clarity, let's consider a vertical reduction:
+ * - If the number of columns is zero, then a 1x0 row-major vector expression is returned.
+ * - Otherwise, if the number of rows is zero, then
+ * - a row vector of zeros is returned for sum-like reductions (sum, squaredNorm, norm, etc.)
+ * - a row vector of ones is returned for a product reduction (e.g., <code>MatrixXd(n,0).colwise().prod()</code>)
+ * - an assert is triggered for all other reductions (minCoeff,maxCoeff,redux(bin_op))
+ *
* \sa DenseBase::colwise(), DenseBase::rowwise(), class PartialReduxExpr
*/
template<typename ExpressionType, int Direction> class VectorwiseOp
@@ -163,11 +193,11 @@
typedef typename internal::ref_selector<ExpressionType>::non_const_type ExpressionTypeNested;
typedef typename internal::remove_all<ExpressionTypeNested>::type ExpressionTypeNestedCleaned;
- template<template<typename _Scalar> class Functor,
- typename Scalar_=Scalar> struct ReturnType
+ template<template<typename OutScalar,typename InputScalar> class Functor,
+ typename ReturnScalar=Scalar> struct ReturnType
{
typedef PartialReduxExpr<ExpressionType,
- Functor<Scalar_>,
+ Functor<ReturnScalar,Scalar>,
Direction
> Type;
};
@@ -187,23 +217,6 @@
protected:
- typedef typename internal::conditional<isVertical,
- typename ExpressionType::ColXpr,
- typename ExpressionType::RowXpr>::type SubVector;
- /** \internal
- * \returns the i-th subvector according to the \c Direction */
- EIGEN_DEVICE_FUNC
- SubVector subVector(Index i)
- {
- return SubVector(m_matrix.derived(),i);
- }
-
- /** \internal
- * \returns the number of subvectors in the direction \c Direction */
- EIGEN_DEVICE_FUNC
- Index subVectors() const
- { return isVertical?m_matrix.cols():m_matrix.rows(); }
-
template<typename OtherDerived> struct ExtendedType {
typedef Replicate<OtherDerived,
isVertical ? 1 : ExpressionType::RowsAtCompileTime,
@@ -258,42 +271,101 @@
EIGEN_DEVICE_FUNC
inline const ExpressionType& _expression() const { return m_matrix; }
+ #ifdef EIGEN_PARSED_BY_DOXYGEN
+ /** STL-like <a href="https://en.cppreference.com/w/cpp/named_req/RandomAccessIterator">RandomAccessIterator</a>
+ * iterator type over the columns or rows as returned by the begin() and end() methods.
+ */
+ random_access_iterator_type iterator;
+ /** This is the const version of iterator (aka read-only) */
+ random_access_iterator_type const_iterator;
+ #else
+ typedef internal::subvector_stl_iterator<ExpressionType, DirectionType(Direction)> iterator;
+ typedef internal::subvector_stl_iterator<const ExpressionType, DirectionType(Direction)> const_iterator;
+ typedef internal::subvector_stl_reverse_iterator<ExpressionType, DirectionType(Direction)> reverse_iterator;
+ typedef internal::subvector_stl_reverse_iterator<const ExpressionType, DirectionType(Direction)> const_reverse_iterator;
+ #endif
+
+ /** returns an iterator to the first row (rowwise) or column (colwise) of the nested expression.
+ * \sa end(), cbegin()
+ */
+ iterator begin() { return iterator (m_matrix, 0); }
+ /** const version of begin() */
+ const_iterator begin() const { return const_iterator(m_matrix, 0); }
+ /** always returns a const_iterator to the first row (resp. column), even on a non-const object */
+ const_iterator cbegin() const { return const_iterator(m_matrix, 0); }
+
+ /** returns a reverse iterator to the last row (rowwise) or column (colwise) of the nested expression.
+ * \sa rend(), crbegin()
+ */
+ reverse_iterator rbegin() { return reverse_iterator (m_matrix, m_matrix.template subVectors<DirectionType(Direction)>()-1); }
+ /** const version of rbegin() */
+ const_reverse_iterator rbegin() const { return const_reverse_iterator (m_matrix, m_matrix.template subVectors<DirectionType(Direction)>()-1); }
+ /** always returns a const_reverse_iterator to the last row (resp. column), even on a non-const object */
+ const_reverse_iterator crbegin() const { return const_reverse_iterator (m_matrix, m_matrix.template subVectors<DirectionType(Direction)>()-1); }
+
+ /** returns an iterator to the row (resp. column) following the last row (resp. column) of the nested expression.
+ * \sa begin(), cend()
+ */
+ iterator end() { return iterator (m_matrix, m_matrix.template subVectors<DirectionType(Direction)>()); }
+ /** const version of end() */
+ const_iterator end() const { return const_iterator(m_matrix, m_matrix.template subVectors<DirectionType(Direction)>()); }
+ /** always returns the past-the-end const_iterator, even on a non-const object */
+ const_iterator cend() const { return const_iterator(m_matrix, m_matrix.template subVectors<DirectionType(Direction)>()); }
+
+ /** returns a reverse iterator to the position before the first row (resp. column) of the nested expression (index -1).
+ * \sa rbegin(), crend()
+ */
+ reverse_iterator rend() { return reverse_iterator (m_matrix, -1); }
+ /** const version of rend() */
+ const_reverse_iterator rend() const { return const_reverse_iterator (m_matrix, -1); }
+ /** always returns the reverse past-the-end const_reverse_iterator, even on a non-const object */
+ const_reverse_iterator crend() const { return const_reverse_iterator (m_matrix, -1); }
+
/** \returns a row or column vector expression of \c *this reduxed by \a func
*
* The template parameter \a BinaryOp is the type of the functor
* of the custom redux operator. Note that func must be an associative operator.
*
+ * \warning the size along the reduction direction must be strictly positive,
+ * otherwise an assertion is triggered.
+ *
* \sa class VectorwiseOp, DenseBase::colwise(), DenseBase::rowwise()
*/
template<typename BinaryOp>
EIGEN_DEVICE_FUNC
const typename ReduxReturnType<BinaryOp>::Type
redux(const BinaryOp& func = BinaryOp()) const
- { return typename ReduxReturnType<BinaryOp>::Type(_expression(), internal::member_redux<BinaryOp,Scalar>(func)); }
+ {
+ eigen_assert(redux_length()>0 && "you are using an empty matrix");
+ return typename ReduxReturnType<BinaryOp>::Type(_expression(), internal::member_redux<BinaryOp,Scalar>(func));
+ }
typedef typename ReturnType<internal::member_minCoeff>::Type MinCoeffReturnType;
typedef typename ReturnType<internal::member_maxCoeff>::Type MaxCoeffReturnType;
- typedef typename ReturnType<internal::member_squaredNorm,RealScalar>::Type SquaredNormReturnType;
- typedef typename ReturnType<internal::member_norm,RealScalar>::Type NormReturnType;
+ typedef PartialReduxExpr<const CwiseUnaryOp<internal::scalar_abs2_op<Scalar>, const ExpressionTypeNestedCleaned>,internal::member_sum<RealScalar,RealScalar>,Direction> SquaredNormReturnType;
+ typedef CwiseUnaryOp<internal::scalar_sqrt_op<RealScalar>, const SquaredNormReturnType> NormReturnType;
typedef typename ReturnType<internal::member_blueNorm,RealScalar>::Type BlueNormReturnType;
typedef typename ReturnType<internal::member_stableNorm,RealScalar>::Type StableNormReturnType;
typedef typename ReturnType<internal::member_hypotNorm,RealScalar>::Type HypotNormReturnType;
typedef typename ReturnType<internal::member_sum>::Type SumReturnType;
- typedef typename ReturnType<internal::member_mean>::Type MeanReturnType;
+ typedef EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(SumReturnType,Scalar,quotient) MeanReturnType;
typedef typename ReturnType<internal::member_all>::Type AllReturnType;
typedef typename ReturnType<internal::member_any>::Type AnyReturnType;
- typedef PartialReduxExpr<ExpressionType, internal::member_count<Index>, Direction> CountReturnType;
+ typedef PartialReduxExpr<ExpressionType, internal::member_count<Index,Scalar>, Direction> CountReturnType;
typedef typename ReturnType<internal::member_prod>::Type ProdReturnType;
typedef Reverse<const ExpressionType, Direction> ConstReverseReturnType;
typedef Reverse<ExpressionType, Direction> ReverseReturnType;
template<int p> struct LpNormReturnType {
- typedef PartialReduxExpr<ExpressionType, internal::member_lpnorm<p,RealScalar>,Direction> Type;
+ typedef PartialReduxExpr<ExpressionType, internal::member_lpnorm<p,RealScalar,Scalar>,Direction> Type;
};
/** \returns a row (or column) vector expression of the smallest coefficient
* of each column (or row) of the referenced expression.
*
+ * \warning the size along the reduction direction must be strictly positive,
+ * otherwise an assertion is triggered.
+ *
* \warning the result is undefined if \c *this contains NaN.
*
* Example: \include PartialRedux_minCoeff.cpp
@@ -302,11 +374,17 @@
* \sa DenseBase::minCoeff() */
EIGEN_DEVICE_FUNC
const MinCoeffReturnType minCoeff() const
- { return MinCoeffReturnType(_expression()); }
+ {
+ eigen_assert(redux_length()>0 && "you are using an empty matrix");
+ return MinCoeffReturnType(_expression());
+ }
/** \returns a row (or column) vector expression of the largest coefficient
* of each column (or row) of the referenced expression.
*
+ * \warning the size along the reduction direction must be strictly positive,
+ * otherwise an assertion is triggered.
+ *
* \warning the result is undefined if \c *this contains NaN.
*
* Example: \include PartialRedux_maxCoeff.cpp
@@ -315,7 +393,10 @@
* \sa DenseBase::maxCoeff() */
EIGEN_DEVICE_FUNC
const MaxCoeffReturnType maxCoeff() const
- { return MaxCoeffReturnType(_expression()); }
+ {
+ eigen_assert(redux_length()>0 && "you are using an empty matrix");
+ return MaxCoeffReturnType(_expression());
+ }
/** \returns a row (or column) vector expression of the squared norm
* of each column (or row) of the referenced expression.
@@ -327,7 +408,7 @@
* \sa DenseBase::squaredNorm() */
EIGEN_DEVICE_FUNC
const SquaredNormReturnType squaredNorm() const
- { return SquaredNormReturnType(_expression()); }
+ { return SquaredNormReturnType(m_matrix.cwiseAbs2()); }
/** \returns a row (or column) vector expression of the norm
* of each column (or row) of the referenced expression.
@@ -339,7 +420,7 @@
* \sa DenseBase::norm() */
EIGEN_DEVICE_FUNC
const NormReturnType norm() const
- { return NormReturnType(_expression()); }
+ { return NormReturnType(squaredNorm()); }
/** \returns a row (or column) vector expression of the norm
* of each column (or row) of the referenced expression.
@@ -404,7 +485,7 @@
* \sa DenseBase::mean() */
EIGEN_DEVICE_FUNC
const MeanReturnType mean() const
- { return MeanReturnType(_expression()); }
+ { return sum() / Scalar(Direction==Vertical?m_matrix.rows():m_matrix.cols()); }
/** \returns a row (or column) vector expression representing
* whether \b all coefficients of each respective column (or row) are \c true.
@@ -500,7 +581,7 @@
EIGEN_STATIC_ASSERT_VECTOR_ONLY(OtherDerived)
EIGEN_STATIC_ASSERT_SAME_XPR_KIND(ExpressionType, OtherDerived)
//eigen_assert((m_matrix.isNull()) == (other.isNull())); FIXME
- return const_cast<ExpressionType&>(m_matrix = extendedTo(other.derived()));
+ return m_matrix = extendedTo(other.derived());
}
/** Adds the vector \a other to each subvector of \c *this */
@@ -510,7 +591,7 @@
{
EIGEN_STATIC_ASSERT_VECTOR_ONLY(OtherDerived)
EIGEN_STATIC_ASSERT_SAME_XPR_KIND(ExpressionType, OtherDerived)
- return const_cast<ExpressionType&>(m_matrix += extendedTo(other.derived()));
+ return m_matrix += extendedTo(other.derived());
}
/** Substracts the vector \a other to each subvector of \c *this */
@@ -520,7 +601,7 @@
{
EIGEN_STATIC_ASSERT_VECTOR_ONLY(OtherDerived)
EIGEN_STATIC_ASSERT_SAME_XPR_KIND(ExpressionType, OtherDerived)
- return const_cast<ExpressionType&>(m_matrix -= extendedTo(other.derived()));
+ return m_matrix -= extendedTo(other.derived());
}
/** Multiples each subvector of \c *this by the vector \a other */
@@ -532,7 +613,7 @@
EIGEN_STATIC_ASSERT_ARRAYXPR(ExpressionType)
EIGEN_STATIC_ASSERT_SAME_XPR_KIND(ExpressionType, OtherDerived)
m_matrix *= extendedTo(other.derived());
- return const_cast<ExpressionType&>(m_matrix);
+ return m_matrix;
}
/** Divides each subvector of \c *this by the vector \a other */
@@ -544,7 +625,7 @@
EIGEN_STATIC_ASSERT_ARRAYXPR(ExpressionType)
EIGEN_STATIC_ASSERT_SAME_XPR_KIND(ExpressionType, OtherDerived)
m_matrix /= extendedTo(other.derived());
- return const_cast<ExpressionType&>(m_matrix);
+ return m_matrix;
}
/** Returns the expression of the sum of the vector \a other to each subvector of \c *this */
@@ -609,7 +690,7 @@
EIGEN_DEVICE_FUNC
CwiseBinaryOp<internal::scalar_quotient_op<Scalar>,
const ExpressionTypeNestedCleaned,
- const typename OppositeExtendedType<typename ReturnType<internal::member_norm,RealScalar>::Type>::Type>
+ const typename OppositeExtendedType<NormReturnType>::Type>
normalized() const { return m_matrix.cwiseQuotient(extendedToOpposite(this->norm())); }
@@ -658,7 +739,15 @@
EIGEN_DEVICE_FUNC
const HNormalizedReturnType hnormalized() const;
+# ifdef EIGEN_VECTORWISEOP_PLUGIN
+# include EIGEN_VECTORWISEOP_PLUGIN
+# endif
+
protected:
+ Index redux_length() const // number of coefficients folded into each output entry: rows() for Vertical (colwise), cols() otherwise
+ {
+ return Direction==Vertical ? m_matrix.rows() : m_matrix.cols();
+ }
ExpressionTypeNested m_matrix;
};
@@ -670,7 +759,7 @@
* \sa rowwise(), class VectorwiseOp, \ref TutorialReductionsVisitorsBroadcasting
*/
template<typename Derived>
-inline typename DenseBase<Derived>::ColwiseReturnType
+EIGEN_DEVICE_FUNC inline typename DenseBase<Derived>::ColwiseReturnType
DenseBase<Derived>::colwise()
{
return ColwiseReturnType(derived());
@@ -684,7 +773,7 @@
* \sa colwise(), class VectorwiseOp, \ref TutorialReductionsVisitorsBroadcasting
*/
template<typename Derived>
-inline typename DenseBase<Derived>::RowwiseReturnType
+EIGEN_DEVICE_FUNC inline typename DenseBase<Derived>::RowwiseReturnType
DenseBase<Derived>::rowwise()
{
return RowwiseReturnType(derived());
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Visitor.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Visitor.h
index 54c1883..00bcca8 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Visitor.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/Visitor.h
@@ -10,7 +10,7 @@
#ifndef EIGEN_VISITOR_H
#define EIGEN_VISITOR_H
-namespace Eigen {
+namespace Eigen {
namespace internal {
@@ -40,6 +40,14 @@
}
};
+// Specialization for a compile-time size of 0: there are no coefficients, so run() does nothing and the visitor is left untouched.
+template<typename Visitor, typename Derived>
+struct visitor_impl<Visitor, Derived, 0> {
+ EIGEN_DEVICE_FUNC
+ static inline void run(const Derived &/*mat*/, Visitor& /*visitor*/)
+ {} // intentionally empty: nothing to visit
+};
+
template<typename Visitor, typename Derived>
struct visitor_impl<Visitor, Derived, Dynamic>
{
@@ -62,22 +70,22 @@
public:
EIGEN_DEVICE_FUNC
explicit visitor_evaluator(const XprType &xpr) : m_evaluator(xpr), m_xpr(xpr) {}
-
+
typedef typename XprType::Scalar Scalar;
typedef typename XprType::CoeffReturnType CoeffReturnType;
-
+
enum {
RowsAtCompileTime = XprType::RowsAtCompileTime,
CoeffReadCost = internal::evaluator<XprType>::CoeffReadCost
};
-
- EIGEN_DEVICE_FUNC Index rows() const { return m_xpr.rows(); }
- EIGEN_DEVICE_FUNC Index cols() const { return m_xpr.cols(); }
- EIGEN_DEVICE_FUNC Index size() const { return m_xpr.size(); }
+
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR Index rows() const EIGEN_NOEXCEPT { return m_xpr.rows(); }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR Index cols() const EIGEN_NOEXCEPT { return m_xpr.cols(); }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR Index size() const EIGEN_NOEXCEPT { return m_xpr.size(); }
EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index row, Index col) const
{ return m_evaluator.coeff(row, col); }
-
+
protected:
internal::evaluator<XprType> m_evaluator;
const XprType &m_xpr;
@@ -99,6 +107,8 @@
* \note compared to one or two \em for \em loops, visitors offer automatic
* unrolling for small fixed size matrix.
*
+ * \note if the matrix is empty, then the visitor is left unchanged.
+ *
* \sa minCoeff(Index*,Index*), maxCoeff(Index*,Index*), DenseBase::redux()
*/
template<typename Derived>
@@ -106,12 +116,15 @@
EIGEN_DEVICE_FUNC
void DenseBase<Derived>::visit(Visitor& visitor) const
{
+ if(size()==0)
+ return;
+
typedef typename internal::visitor_evaluator<Derived> ThisEvaluator;
ThisEvaluator thisEval(derived());
-
+
enum {
unroll = SizeAtCompileTime != Dynamic
- && SizeAtCompileTime * ThisEvaluator::CoeffReadCost + (SizeAtCompileTime-1) * internal::functor_traits<Visitor>::Cost <= EIGEN_UNROLLING_LIMIT
+ && SizeAtCompileTime * int(ThisEvaluator::CoeffReadCost) + (SizeAtCompileTime-1) * int(internal::functor_traits<Visitor>::Cost) <= EIGEN_UNROLLING_LIMIT
};
return internal::visitor_impl<Visitor, ThisEvaluator, unroll ? int(SizeAtCompileTime) : Dynamic>::run(thisEval, visitor);
}
@@ -124,6 +137,9 @@
template <typename Derived>
struct coeff_visitor
{
+ // Default-initialize all members to silence gcc's many spurious -Wmaybe-uninitialized warnings
+ EIGEN_DEVICE_FUNC
+ coeff_visitor() : row(-1), col(-1), res(0) {}
typedef typename Derived::Scalar Scalar;
Index row, col;
Scalar res;
@@ -141,7 +157,7 @@
*
* \sa DenseBase::minCoeff(Index*, Index*)
*/
-template <typename Derived>
+template <typename Derived, int NaNPropagation>
struct min_coeff_visitor : coeff_visitor<Derived>
{
typedef typename Derived::Scalar Scalar;
@@ -157,8 +173,40 @@
}
};
-template<typename Scalar>
-struct functor_traits<min_coeff_visitor<Scalar> > {
+template <typename Derived>
+struct min_coeff_visitor<Derived, PropagateNumbers> : coeff_visitor<Derived> // min visitor that skips NaN coefficients
+{
+ typedef typename Derived::Scalar Scalar;
+ EIGEN_DEVICE_FUNC
+ void operator() (const Scalar& value, Index i, Index j)
+ {
+ if((numext::isnan)(this->res) || (!(numext::isnan)(value) && value < this->res)) // a NaN res is always replaced; otherwise only a strictly smaller non-NaN value wins (ties keep the earlier location)
+ {
+ this->res = value;
+ this->row = i;
+ this->col = j;
+ }
+ }
+};
+
+template <typename Derived>
+struct min_coeff_visitor<Derived, PropagateNaN> : coeff_visitor<Derived> // min visitor whose result becomes NaN once a NaN is seen
+{
+ typedef typename Derived::Scalar Scalar;
+ EIGEN_DEVICE_FUNC
+ void operator() (const Scalar& value, Index i, Index j)
+ {
+ if((numext::isnan)(value) || value < this->res) // every NaN value is taken (the last NaN visited sets row/col); once res is NaN, `value < res` is false, so NaN sticks
+ {
+ this->res = value;
+ this->row = i;
+ this->col = j;
+ }
+ }
+};
+
+template<typename Scalar, int NaNPropagation>
+ struct functor_traits<min_coeff_visitor<Scalar, NaNPropagation> > {
enum {
Cost = NumTraits<Scalar>::AddCost
};
@@ -169,10 +217,10 @@
*
* \sa DenseBase::maxCoeff(Index*, Index*)
*/
-template <typename Derived>
+template <typename Derived, int NaNPropagation>
struct max_coeff_visitor : coeff_visitor<Derived>
{
- typedef typename Derived::Scalar Scalar;
+ typedef typename Derived::Scalar Scalar;
EIGEN_DEVICE_FUNC
void operator() (const Scalar& value, Index i, Index j)
{
@@ -185,8 +233,40 @@
}
};
-template<typename Scalar>
-struct functor_traits<max_coeff_visitor<Scalar> > {
+template <typename Derived>
+struct max_coeff_visitor<Derived, PropagateNumbers> : coeff_visitor<Derived> // max visitor that skips NaN coefficients
+{
+ typedef typename Derived::Scalar Scalar;
+ EIGEN_DEVICE_FUNC
+ void operator() (const Scalar& value, Index i, Index j)
+ {
+ if((numext::isnan)(this->res) || (!(numext::isnan)(value) && value > this->res)) // a NaN res is always replaced; otherwise only a strictly larger non-NaN value wins (ties keep the earlier location)
+ {
+ this->res = value;
+ this->row = i;
+ this->col = j;
+ }
+ }
+};
+
+template <typename Derived>
+struct max_coeff_visitor<Derived, PropagateNaN> : coeff_visitor<Derived> // max visitor whose result becomes NaN once a NaN is seen
+{
+ typedef typename Derived::Scalar Scalar;
+ EIGEN_DEVICE_FUNC
+ void operator() (const Scalar& value, Index i, Index j)
+ {
+ if((numext::isnan)(value) || value > this->res) // every NaN value is taken (the last NaN visited sets row/col); once res is NaN, `value > res` is false, so NaN sticks
+ {
+ this->res = value;
+ this->row = i;
+ this->col = j;
+ }
+ }
+};
+
+template<typename Scalar, int NaNPropagation>
+struct functor_traits<max_coeff_visitor<Scalar, NaNPropagation> > {
enum {
Cost = NumTraits<Scalar>::AddCost
};
@@ -196,17 +276,24 @@
/** \fn DenseBase<Derived>::minCoeff(IndexType* rowId, IndexType* colId) const
* \returns the minimum of all coefficients of *this and puts in *row and *col its location.
- * \warning the result is undefined if \c *this contains NaN.
+ *
+ * In case \c *this contains NaN, NaNPropagation determines the behavior:
+ * NaNPropagation == PropagateFast : undefined
+ * NaNPropagation == PropagateNaN : result is NaN
+ * NaNPropagation == PropagateNumbers : result is minimum of elements that are not NaN
+ * \warning the matrix must not be empty, otherwise an assertion is triggered.
*
* \sa DenseBase::minCoeff(Index*), DenseBase::maxCoeff(Index*,Index*), DenseBase::visit(), DenseBase::minCoeff()
*/
template<typename Derived>
-template<typename IndexType>
+template<int NaNPropagation, typename IndexType>
EIGEN_DEVICE_FUNC
typename internal::traits<Derived>::Scalar
DenseBase<Derived>::minCoeff(IndexType* rowId, IndexType* colId) const
{
- internal::min_coeff_visitor<Derived> minVisitor;
+ eigen_assert(this->rows()>0 && this->cols()>0 && "you are using an empty matrix");
+
+ internal::min_coeff_visitor<Derived, NaNPropagation> minVisitor;
this->visit(minVisitor);
*rowId = minVisitor.row;
if (colId) *colId = minVisitor.col;
@@ -214,18 +301,25 @@
}
/** \returns the minimum of all coefficients of *this and puts in *index its location.
- * \warning the result is undefined if \c *this contains NaN.
+ *
+ * In case \c *this contains NaN, NaNPropagation determines the behavior:
+ * NaNPropagation == PropagateFast : undefined
+ * NaNPropagation == PropagateNaN : result is NaN
+ * NaNPropagation == PropagateNumbers : result is minimum of elements that are not NaN
+ * \warning the matrix must not be empty, otherwise an assertion is triggered.
*
* \sa DenseBase::minCoeff(IndexType*,IndexType*), DenseBase::maxCoeff(IndexType*,IndexType*), DenseBase::visit(), DenseBase::minCoeff()
*/
template<typename Derived>
-template<typename IndexType>
+template<int NaNPropagation, typename IndexType>
EIGEN_DEVICE_FUNC
typename internal::traits<Derived>::Scalar
DenseBase<Derived>::minCoeff(IndexType* index) const
{
+ eigen_assert(this->rows()>0 && this->cols()>0 && "you are using an empty matrix");
+
EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
- internal::min_coeff_visitor<Derived> minVisitor;
+ internal::min_coeff_visitor<Derived, NaNPropagation> minVisitor;
this->visit(minVisitor);
*index = IndexType((RowsAtCompileTime==1) ? minVisitor.col : minVisitor.row);
return minVisitor.res;
@@ -233,17 +327,24 @@
/** \fn DenseBase<Derived>::maxCoeff(IndexType* rowId, IndexType* colId) const
* \returns the maximum of all coefficients of *this and puts in *row and *col its location.
- * \warning the result is undefined if \c *this contains NaN.
+ *
+ * In case \c *this contains NaN, NaNPropagation determines the behavior:
+ * NaNPropagation == PropagateFast : undefined
+ * NaNPropagation == PropagateNaN : result is NaN
+ * NaNPropagation == PropagateNumbers : result is maximum of elements that are not NaN
+ * \warning the matrix must not be empty, otherwise an assertion is triggered.
*
* \sa DenseBase::minCoeff(IndexType*,IndexType*), DenseBase::visit(), DenseBase::maxCoeff()
*/
template<typename Derived>
-template<typename IndexType>
+template<int NaNPropagation, typename IndexType>
EIGEN_DEVICE_FUNC
typename internal::traits<Derived>::Scalar
DenseBase<Derived>::maxCoeff(IndexType* rowPtr, IndexType* colPtr) const
{
- internal::max_coeff_visitor<Derived> maxVisitor;
+ eigen_assert(this->rows()>0 && this->cols()>0 && "you are using an empty matrix");
+
+ internal::max_coeff_visitor<Derived, NaNPropagation> maxVisitor;
this->visit(maxVisitor);
*rowPtr = maxVisitor.row;
if (colPtr) *colPtr = maxVisitor.col;
@@ -251,18 +352,25 @@
}
/** \returns the maximum of all coefficients of *this and puts in *index its location.
- * \warning the result is undefined if \c *this contains NaN.
+ *
+ * In case \c *this contains NaN, NaNPropagation determines the behavior:
+ * NaNPropagation == PropagateFast : undefined
+ * NaNPropagation == PropagateNaN : result is NaN
+ * NaNPropagation == PropagateNumbers : result is maximum of elements that are not NaN
+ * \warning the matrix must not be empty, otherwise an assertion is triggered.
*
* \sa DenseBase::maxCoeff(IndexType*,IndexType*), DenseBase::minCoeff(IndexType*,IndexType*), DenseBase::visitor(), DenseBase::maxCoeff()
*/
template<typename Derived>
-template<typename IndexType>
+template<int NaNPropagation, typename IndexType>
EIGEN_DEVICE_FUNC
typename internal::traits<Derived>::Scalar
DenseBase<Derived>::maxCoeff(IndexType* index) const
{
+ eigen_assert(this->rows()>0 && this->cols()>0 && "you are using an empty matrix");
+
EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
- internal::max_coeff_visitor<Derived> maxVisitor;
+ internal::max_coeff_visitor<Derived, NaNPropagation> maxVisitor;
this->visit(maxVisitor);
*index = (RowsAtCompileTime==1) ? maxVisitor.col : maxVisitor.row;
return maxVisitor.res;
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/AVX/Complex.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/AVX/Complex.h
index 7fa6196..ab7bd6c 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/AVX/Complex.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/AVX/Complex.h
@@ -22,6 +22,7 @@
__m256 v;
};
+#ifndef EIGEN_VECTORIZE_AVX512
template<> struct packet_traits<std::complex<float> > : default_packet_traits
{
typedef Packet4cf type;
@@ -37,6 +38,7 @@
HasMul = 1,
HasDiv = 1,
HasNegate = 1,
+ HasSqrt = 1,
HasAbs = 0,
HasAbs2 = 0,
HasMin = 0,
@@ -44,8 +46,20 @@
HasSetLinear = 0
};
};
+#endif
-template<> struct unpacket_traits<Packet4cf> { typedef std::complex<float> type; enum {size=4, alignment=Aligned32}; typedef Packet2cf half; };
+template<> struct unpacket_traits<Packet4cf> {
+ typedef std::complex<float> type; // scalar type of each packet entry
+ typedef Packet2cf half; // packet type for one 128-bit half of a Packet4cf
+ typedef Packet8f as_real; // the same __m256 register viewed as 8 floats
+ enum {
+ size=4, // std::complex<float> entries per packet
+ alignment=Aligned32,
+ vectorizable=true,
+ masked_load_available=false,
+ masked_store_available=false
+ };
+};
template<> EIGEN_STRONG_INLINE Packet4cf padd<Packet4cf>(const Packet4cf& a, const Packet4cf& b) { return Packet4cf(_mm256_add_ps(a.v,b.v)); }
template<> EIGEN_STRONG_INLINE Packet4cf psub<Packet4cf>(const Packet4cf& a, const Packet4cf& b) { return Packet4cf(_mm256_sub_ps(a.v,b.v)); }
@@ -67,10 +81,17 @@
return Packet4cf(result);
}
+template <>
+EIGEN_STRONG_INLINE Packet4cf pcmp_eq(const Packet4cf& a, const Packet4cf& b) {
+ __m256 eq = _mm256_cmp_ps(a.v, b.v, _CMP_EQ_OQ); // per-float equality of real and imaginary parts (ordered: NaN compares unequal)
+ return Packet4cf(_mm256_and_ps(eq, _mm256_permute_ps(eq, 0xb1))); // 0xb1 swaps each (re, im) float pair, so a complex entry is all-ones only if both parts are equal
+}
+
+template<> EIGEN_STRONG_INLINE Packet4cf ptrue<Packet4cf>(const Packet4cf& a) { return Packet4cf(ptrue(Packet8f(a.v))); } // all-ones mask via the Packet8f overload
template<> EIGEN_STRONG_INLINE Packet4cf pand <Packet4cf>(const Packet4cf& a, const Packet4cf& b) { return Packet4cf(_mm256_and_ps(a.v,b.v)); }
template<> EIGEN_STRONG_INLINE Packet4cf por <Packet4cf>(const Packet4cf& a, const Packet4cf& b) { return Packet4cf(_mm256_or_ps(a.v,b.v)); }
template<> EIGEN_STRONG_INLINE Packet4cf pxor <Packet4cf>(const Packet4cf& a, const Packet4cf& b) { return Packet4cf(_mm256_xor_ps(a.v,b.v)); }
-template<> EIGEN_STRONG_INLINE Packet4cf pandnot<Packet4cf>(const Packet4cf& a, const Packet4cf& b) { return Packet4cf(_mm256_andnot_ps(a.v,b.v)); }
+template<> EIGEN_STRONG_INLINE Packet4cf pandnot<Packet4cf>(const Packet4cf& a, const Packet4cf& b) { return Packet4cf(_mm256_andnot_ps(b.v,a.v)); }
template<> EIGEN_STRONG_INLINE Packet4cf pload <Packet4cf>(const std::complex<float>* from) { EIGEN_DEBUG_ALIGNED_LOAD return Packet4cf(pload<Packet8f>(&numext::real_ref(*from))); }
template<> EIGEN_STRONG_INLINE Packet4cf ploadu<Packet4cf>(const std::complex<float>* from) { EIGEN_DEBUG_UNALIGNED_LOAD return Packet4cf(ploadu<Packet8f>(&numext::real_ref(*from))); }
@@ -140,70 +161,12 @@
Packet2cf(_mm256_extractf128_ps(a.v,1))));
}
-template<> EIGEN_STRONG_INLINE Packet4cf preduxp<Packet4cf>(const Packet4cf* vecs)
-{
- Packet8f t0 = _mm256_shuffle_ps(vecs[0].v, vecs[0].v, _MM_SHUFFLE(3, 1, 2 ,0));
- Packet8f t1 = _mm256_shuffle_ps(vecs[1].v, vecs[1].v, _MM_SHUFFLE(3, 1, 2 ,0));
- t0 = _mm256_hadd_ps(t0,t1);
- Packet8f t2 = _mm256_shuffle_ps(vecs[2].v, vecs[2].v, _MM_SHUFFLE(3, 1, 2 ,0));
- Packet8f t3 = _mm256_shuffle_ps(vecs[3].v, vecs[3].v, _MM_SHUFFLE(3, 1, 2 ,0));
- t2 = _mm256_hadd_ps(t2,t3);
-
- t1 = _mm256_permute2f128_ps(t0,t2, 0 + (2<<4));
- t3 = _mm256_permute2f128_ps(t0,t2, 1 + (3<<4));
-
- return Packet4cf(_mm256_add_ps(t1,t3));
-}
-
template<> EIGEN_STRONG_INLINE std::complex<float> predux_mul<Packet4cf>(const Packet4cf& a)
{
return predux_mul(pmul(Packet2cf(_mm256_extractf128_ps(a.v, 0)),
Packet2cf(_mm256_extractf128_ps(a.v, 1))));
}
-template<int Offset>
-struct palign_impl<Offset,Packet4cf>
-{
- static EIGEN_STRONG_INLINE void run(Packet4cf& first, const Packet4cf& second)
- {
- if (Offset==0) return;
- palign_impl<Offset*2,Packet8f>::run(first.v, second.v);
- }
-};
-
-template<> struct conj_helper<Packet4cf, Packet4cf, false,true>
-{
- EIGEN_STRONG_INLINE Packet4cf pmadd(const Packet4cf& x, const Packet4cf& y, const Packet4cf& c) const
- { return padd(pmul(x,y),c); }
-
- EIGEN_STRONG_INLINE Packet4cf pmul(const Packet4cf& a, const Packet4cf& b) const
- {
- return internal::pmul(a, pconj(b));
- }
-};
-
-template<> struct conj_helper<Packet4cf, Packet4cf, true,false>
-{
- EIGEN_STRONG_INLINE Packet4cf pmadd(const Packet4cf& x, const Packet4cf& y, const Packet4cf& c) const
- { return padd(pmul(x,y),c); }
-
- EIGEN_STRONG_INLINE Packet4cf pmul(const Packet4cf& a, const Packet4cf& b) const
- {
- return internal::pmul(pconj(a), b);
- }
-};
-
-template<> struct conj_helper<Packet4cf, Packet4cf, true,true>
-{
- EIGEN_STRONG_INLINE Packet4cf pmadd(const Packet4cf& x, const Packet4cf& y, const Packet4cf& c) const
- { return padd(pmul(x,y),c); }
-
- EIGEN_STRONG_INLINE Packet4cf pmul(const Packet4cf& a, const Packet4cf& b) const
- {
- return pconj(internal::pmul(a, b));
- }
-};
-
EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet4cf,Packet8f)
template<> EIGEN_STRONG_INLINE Packet4cf pdiv<Packet4cf>(const Packet4cf& a, const Packet4cf& b)
@@ -228,6 +191,7 @@
__m256d v;
};
+#ifndef EIGEN_VECTORIZE_AVX512
template<> struct packet_traits<std::complex<double> > : default_packet_traits
{
typedef Packet2cd type;
@@ -243,6 +207,7 @@
HasMul = 1,
HasDiv = 1,
HasNegate = 1,
+ HasSqrt = 1,
HasAbs = 0,
HasAbs2 = 0,
HasMin = 0,
@@ -250,8 +215,20 @@
HasSetLinear = 0
};
};
+#endif
-template<> struct unpacket_traits<Packet2cd> { typedef std::complex<double> type; enum {size=2, alignment=Aligned32}; typedef Packet1cd half; };
+template<> struct unpacket_traits<Packet2cd> {
+ typedef std::complex<double> type;
+ typedef Packet1cd half;
+ typedef Packet4d as_real;
+ enum {
+ size=2,
+ alignment=Aligned32,
+ vectorizable=true,
+ masked_load_available=false,
+ masked_store_available=false
+ };
+};
template<> EIGEN_STRONG_INLINE Packet2cd padd<Packet2cd>(const Packet2cd& a, const Packet2cd& b) { return Packet2cd(_mm256_add_pd(a.v,b.v)); }
template<> EIGEN_STRONG_INLINE Packet2cd psub<Packet2cd>(const Packet2cd& a, const Packet2cd& b) { return Packet2cd(_mm256_sub_pd(a.v,b.v)); }
@@ -272,10 +249,17 @@
return Packet2cd(_mm256_addsub_pd(even, odd));
}
+template <>
+EIGEN_STRONG_INLINE Packet2cd pcmp_eq(const Packet2cd& a, const Packet2cd& b) {
+ __m256d eq = _mm256_cmp_pd(a.v, b.v, _CMP_EQ_OQ);
+ return Packet2cd(pand(eq, _mm256_permute_pd(eq, 0x5)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet2cd ptrue<Packet2cd>(const Packet2cd& a) { return Packet2cd(ptrue(Packet4d(a.v))); }
template<> EIGEN_STRONG_INLINE Packet2cd pand <Packet2cd>(const Packet2cd& a, const Packet2cd& b) { return Packet2cd(_mm256_and_pd(a.v,b.v)); }
template<> EIGEN_STRONG_INLINE Packet2cd por <Packet2cd>(const Packet2cd& a, const Packet2cd& b) { return Packet2cd(_mm256_or_pd(a.v,b.v)); }
template<> EIGEN_STRONG_INLINE Packet2cd pxor <Packet2cd>(const Packet2cd& a, const Packet2cd& b) { return Packet2cd(_mm256_xor_pd(a.v,b.v)); }
-template<> EIGEN_STRONG_INLINE Packet2cd pandnot<Packet2cd>(const Packet2cd& a, const Packet2cd& b) { return Packet2cd(_mm256_andnot_pd(a.v,b.v)); }
+template<> EIGEN_STRONG_INLINE Packet2cd pandnot<Packet2cd>(const Packet2cd& a, const Packet2cd& b) { return Packet2cd(_mm256_andnot_pd(b.v,a.v)); }
template<> EIGEN_STRONG_INLINE Packet2cd pload <Packet2cd>(const std::complex<double>* from)
{ EIGEN_DEBUG_ALIGNED_LOAD return Packet2cd(pload<Packet4d>((const double*)from)); }
@@ -327,63 +311,12 @@
Packet1cd(_mm256_extractf128_pd(a.v,1))));
}
-template<> EIGEN_STRONG_INLINE Packet2cd preduxp<Packet2cd>(const Packet2cd* vecs)
-{
- Packet4d t0 = _mm256_permute2f128_pd(vecs[0].v,vecs[1].v, 0 + (2<<4));
- Packet4d t1 = _mm256_permute2f128_pd(vecs[0].v,vecs[1].v, 1 + (3<<4));
-
- return Packet2cd(_mm256_add_pd(t0,t1));
-}
-
template<> EIGEN_STRONG_INLINE std::complex<double> predux_mul<Packet2cd>(const Packet2cd& a)
{
return predux(pmul(Packet1cd(_mm256_extractf128_pd(a.v,0)),
Packet1cd(_mm256_extractf128_pd(a.v,1))));
}
-template<int Offset>
-struct palign_impl<Offset,Packet2cd>
-{
- static EIGEN_STRONG_INLINE void run(Packet2cd& first, const Packet2cd& second)
- {
- if (Offset==0) return;
- palign_impl<Offset*2,Packet4d>::run(first.v, second.v);
- }
-};
-
-template<> struct conj_helper<Packet2cd, Packet2cd, false,true>
-{
- EIGEN_STRONG_INLINE Packet2cd pmadd(const Packet2cd& x, const Packet2cd& y, const Packet2cd& c) const
- { return padd(pmul(x,y),c); }
-
- EIGEN_STRONG_INLINE Packet2cd pmul(const Packet2cd& a, const Packet2cd& b) const
- {
- return internal::pmul(a, pconj(b));
- }
-};
-
-template<> struct conj_helper<Packet2cd, Packet2cd, true,false>
-{
- EIGEN_STRONG_INLINE Packet2cd pmadd(const Packet2cd& x, const Packet2cd& y, const Packet2cd& c) const
- { return padd(pmul(x,y),c); }
-
- EIGEN_STRONG_INLINE Packet2cd pmul(const Packet2cd& a, const Packet2cd& b) const
- {
- return internal::pmul(pconj(a), b);
- }
-};
-
-template<> struct conj_helper<Packet2cd, Packet2cd, true,true>
-{
- EIGEN_STRONG_INLINE Packet2cd pmadd(const Packet2cd& x, const Packet2cd& y, const Packet2cd& c) const
- { return padd(pmul(x,y),c); }
-
- EIGEN_STRONG_INLINE Packet2cd pmul(const Packet2cd& a, const Packet2cd& b) const
- {
- return pconj(internal::pmul(a, b));
- }
-};
-
EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet2cd,Packet4d)
template<> EIGEN_STRONG_INLINE Packet2cd pdiv<Packet2cd>(const Packet2cd& a, const Packet2cd& b)
@@ -424,24 +357,12 @@
kernel.packet[0].v = tmp;
}
-template<> EIGEN_STRONG_INLINE Packet4cf pinsertfirst(const Packet4cf& a, std::complex<float> b)
-{
- return Packet4cf(_mm256_blend_ps(a.v,pset1<Packet4cf>(b).v,1|2));
+template<> EIGEN_STRONG_INLINE Packet2cd psqrt<Packet2cd>(const Packet2cd& a) {
+ return psqrt_complex<Packet2cd>(a);
}
-template<> EIGEN_STRONG_INLINE Packet2cd pinsertfirst(const Packet2cd& a, std::complex<double> b)
-{
- return Packet2cd(_mm256_blend_pd(a.v,pset1<Packet2cd>(b).v,1|2));
-}
-
-template<> EIGEN_STRONG_INLINE Packet4cf pinsertlast(const Packet4cf& a, std::complex<float> b)
-{
- return Packet4cf(_mm256_blend_ps(a.v,pset1<Packet4cf>(b).v,(1<<7)|(1<<6)));
-}
-
-template<> EIGEN_STRONG_INLINE Packet2cd pinsertlast(const Packet2cd& a, std::complex<double> b)
-{
- return Packet2cd(_mm256_blend_pd(a.v,pset1<Packet2cd>(b).v,(1<<3)|(1<<2)));
+template<> EIGEN_STRONG_INLINE Packet4cf psqrt<Packet4cf>(const Packet4cf& a) {
+ return psqrt_complex<Packet4cf>(a);
}
} // end namespace internal
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/AVX/MathFunctions.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/AVX/MathFunctions.h
index 6af67ce..67041c8 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/AVX/MathFunctions.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/AVX/MathFunctions.h
@@ -10,7 +10,7 @@
#ifndef EIGEN_MATH_FUNCTIONS_AVX_H
#define EIGEN_MATH_FUNCTIONS_AVX_H
-/* The sin, cos, exp, and log functions of this file are loosely derived from
+/* The sin and cos functions of this file are loosely derived from
* Julien Pommier's sse math library: http://gruntthepeon.free.fr/ssemath/
*/
@@ -18,187 +18,50 @@
namespace internal {
-inline Packet8i pshiftleft(Packet8i v, int n)
-{
-#ifdef EIGEN_VECTORIZE_AVX2
- return _mm256_slli_epi32(v, n);
-#else
- __m128i lo = _mm_slli_epi32(_mm256_extractf128_si256(v, 0), n);
- __m128i hi = _mm_slli_epi32(_mm256_extractf128_si256(v, 1), n);
- return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1);
-#endif
-}
-
-inline Packet8f pshiftright(Packet8f v, int n)
-{
-#ifdef EIGEN_VECTORIZE_AVX2
- return _mm256_cvtepi32_ps(_mm256_srli_epi32(_mm256_castps_si256(v), n));
-#else
- __m128i lo = _mm_srli_epi32(_mm256_extractf128_si256(_mm256_castps_si256(v), 0), n);
- __m128i hi = _mm_srli_epi32(_mm256_extractf128_si256(_mm256_castps_si256(v), 1), n);
- return _mm256_cvtepi32_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1));
-#endif
-}
-
-// Sine function
-// Computes sin(x) by wrapping x to the interval [-Pi/4,3*Pi/4] and
-// evaluating interpolants in [-Pi/4,Pi/4] or [Pi/4,3*Pi/4]. The interpolants
-// are (anti-)symmetric and thus have only odd/even coefficients
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet8f
psin<Packet8f>(const Packet8f& _x) {
- Packet8f x = _x;
-
- // Some useful values.
- _EIGEN_DECLARE_CONST_Packet8i(one, 1);
- _EIGEN_DECLARE_CONST_Packet8f(one, 1.0f);
- _EIGEN_DECLARE_CONST_Packet8f(two, 2.0f);
- _EIGEN_DECLARE_CONST_Packet8f(one_over_four, 0.25f);
- _EIGEN_DECLARE_CONST_Packet8f(one_over_pi, 3.183098861837907e-01f);
- _EIGEN_DECLARE_CONST_Packet8f(neg_pi_first, -3.140625000000000e+00f);
- _EIGEN_DECLARE_CONST_Packet8f(neg_pi_second, -9.670257568359375e-04f);
- _EIGEN_DECLARE_CONST_Packet8f(neg_pi_third, -6.278329571784980e-07f);
- _EIGEN_DECLARE_CONST_Packet8f(four_over_pi, 1.273239544735163e+00f);
-
- // Map x from [-Pi/4,3*Pi/4] to z in [-1,3] and subtract the shifted period.
- Packet8f z = pmul(x, p8f_one_over_pi);
- Packet8f shift = _mm256_floor_ps(padd(z, p8f_one_over_four));
- x = pmadd(shift, p8f_neg_pi_first, x);
- x = pmadd(shift, p8f_neg_pi_second, x);
- x = pmadd(shift, p8f_neg_pi_third, x);
- z = pmul(x, p8f_four_over_pi);
-
- // Make a mask for the entries that need flipping, i.e. wherever the shift
- // is odd.
- Packet8i shift_ints = _mm256_cvtps_epi32(shift);
- Packet8i shift_isodd = _mm256_castps_si256(_mm256_and_ps(_mm256_castsi256_ps(shift_ints), _mm256_castsi256_ps(p8i_one)));
- Packet8i sign_flip_mask = pshiftleft(shift_isodd, 31);
-
- // Create a mask for which interpolant to use, i.e. if z > 1, then the mask
- // is set to ones for that entry.
- Packet8f ival_mask = _mm256_cmp_ps(z, p8f_one, _CMP_GT_OQ);
-
- // Evaluate the polynomial for the interval [1,3] in z.
- _EIGEN_DECLARE_CONST_Packet8f(coeff_right_0, 9.999999724233232e-01f);
- _EIGEN_DECLARE_CONST_Packet8f(coeff_right_2, -3.084242535619928e-01f);
- _EIGEN_DECLARE_CONST_Packet8f(coeff_right_4, 1.584991525700324e-02f);
- _EIGEN_DECLARE_CONST_Packet8f(coeff_right_6, -3.188805084631342e-04f);
- Packet8f z_minus_two = psub(z, p8f_two);
- Packet8f z_minus_two2 = pmul(z_minus_two, z_minus_two);
- Packet8f right = pmadd(p8f_coeff_right_6, z_minus_two2, p8f_coeff_right_4);
- right = pmadd(right, z_minus_two2, p8f_coeff_right_2);
- right = pmadd(right, z_minus_two2, p8f_coeff_right_0);
-
- // Evaluate the polynomial for the interval [-1,1] in z.
- _EIGEN_DECLARE_CONST_Packet8f(coeff_left_1, 7.853981525427295e-01f);
- _EIGEN_DECLARE_CONST_Packet8f(coeff_left_3, -8.074536727092352e-02f);
- _EIGEN_DECLARE_CONST_Packet8f(coeff_left_5, 2.489871967827018e-03f);
- _EIGEN_DECLARE_CONST_Packet8f(coeff_left_7, -3.587725841214251e-05f);
- Packet8f z2 = pmul(z, z);
- Packet8f left = pmadd(p8f_coeff_left_7, z2, p8f_coeff_left_5);
- left = pmadd(left, z2, p8f_coeff_left_3);
- left = pmadd(left, z2, p8f_coeff_left_1);
- left = pmul(left, z);
-
- // Assemble the results, i.e. select the left and right polynomials.
- left = _mm256_andnot_ps(ival_mask, left);
- right = _mm256_and_ps(ival_mask, right);
- Packet8f res = _mm256_or_ps(left, right);
-
- // Flip the sign on the odd intervals and return the result.
- res = _mm256_xor_ps(res, _mm256_castsi256_ps(sign_flip_mask));
- return res;
+ return psin_float(_x);
}
-// Natural logarithm
-// Computes log(x) as log(2^e * m) = C*e + log(m), where the constant C =log(2)
-// and m is in the range [sqrt(1/2),sqrt(2)). In this range, the logarithm can
-// be easily approximated by a polynomial centered on m=1 for stability.
-// TODO(gonnet): Further reduce the interval allowing for lower-degree
-// polynomial interpolants -> ... -> profit!
+template <>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet8f
+pcos<Packet8f>(const Packet8f& _x) {
+ return pcos_float(_x);
+}
+
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet8f
plog<Packet8f>(const Packet8f& _x) {
- Packet8f x = _x;
- _EIGEN_DECLARE_CONST_Packet8f(1, 1.0f);
- _EIGEN_DECLARE_CONST_Packet8f(half, 0.5f);
- _EIGEN_DECLARE_CONST_Packet8f(126f, 126.0f);
+ return plog_float(_x);
+}
- _EIGEN_DECLARE_CONST_Packet8f_FROM_INT(inv_mant_mask, ~0x7f800000);
+template <>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4d
+plog<Packet4d>(const Packet4d& _x) {
+ return plog_double(_x);
+}
- // The smallest non denormalized float number.
- _EIGEN_DECLARE_CONST_Packet8f_FROM_INT(min_norm_pos, 0x00800000);
- _EIGEN_DECLARE_CONST_Packet8f_FROM_INT(minus_inf, 0xff800000);
+template <>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet8f
+plog2<Packet8f>(const Packet8f& _x) {
+ return plog2_float(_x);
+}
- // Polynomial coefficients.
- _EIGEN_DECLARE_CONST_Packet8f(cephes_SQRTHF, 0.707106781186547524f);
- _EIGEN_DECLARE_CONST_Packet8f(cephes_log_p0, 7.0376836292E-2f);
- _EIGEN_DECLARE_CONST_Packet8f(cephes_log_p1, -1.1514610310E-1f);
- _EIGEN_DECLARE_CONST_Packet8f(cephes_log_p2, 1.1676998740E-1f);
- _EIGEN_DECLARE_CONST_Packet8f(cephes_log_p3, -1.2420140846E-1f);
- _EIGEN_DECLARE_CONST_Packet8f(cephes_log_p4, +1.4249322787E-1f);
- _EIGEN_DECLARE_CONST_Packet8f(cephes_log_p5, -1.6668057665E-1f);
- _EIGEN_DECLARE_CONST_Packet8f(cephes_log_p6, +2.0000714765E-1f);
- _EIGEN_DECLARE_CONST_Packet8f(cephes_log_p7, -2.4999993993E-1f);
- _EIGEN_DECLARE_CONST_Packet8f(cephes_log_p8, +3.3333331174E-1f);
- _EIGEN_DECLARE_CONST_Packet8f(cephes_log_q1, -2.12194440e-4f);
- _EIGEN_DECLARE_CONST_Packet8f(cephes_log_q2, 0.693359375f);
+template <>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4d
+plog2<Packet4d>(const Packet4d& _x) {
+ return plog2_double(_x);
+}
- Packet8f invalid_mask = _mm256_cmp_ps(x, _mm256_setzero_ps(), _CMP_NGE_UQ); // not greater equal is true if x is NaN
- Packet8f iszero_mask = _mm256_cmp_ps(x, _mm256_setzero_ps(), _CMP_EQ_OQ);
+template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
+Packet8f plog1p<Packet8f>(const Packet8f& _x) {
+ return generic_plog1p(_x);
+}
- // Truncate input values to the minimum positive normal.
- x = pmax(x, p8f_min_norm_pos);
-
- Packet8f emm0 = pshiftright(x,23);
- Packet8f e = _mm256_sub_ps(emm0, p8f_126f);
-
- // Set the exponents to -1, i.e. x are in the range [0.5,1).
- x = _mm256_and_ps(x, p8f_inv_mant_mask);
- x = _mm256_or_ps(x, p8f_half);
-
- // part2: Shift the inputs from the range [0.5,1) to [sqrt(1/2),sqrt(2))
- // and shift by -1. The values are then centered around 0, which improves
- // the stability of the polynomial evaluation.
- // if( x < SQRTHF ) {
- // e -= 1;
- // x = x + x - 1.0;
- // } else { x = x - 1.0; }
- Packet8f mask = _mm256_cmp_ps(x, p8f_cephes_SQRTHF, _CMP_LT_OQ);
- Packet8f tmp = _mm256_and_ps(x, mask);
- x = psub(x, p8f_1);
- e = psub(e, _mm256_and_ps(p8f_1, mask));
- x = padd(x, tmp);
-
- Packet8f x2 = pmul(x, x);
- Packet8f x3 = pmul(x2, x);
-
- // Evaluate the polynomial approximant of degree 8 in three parts, probably
- // to improve instruction-level parallelism.
- Packet8f y, y1, y2;
- y = pmadd(p8f_cephes_log_p0, x, p8f_cephes_log_p1);
- y1 = pmadd(p8f_cephes_log_p3, x, p8f_cephes_log_p4);
- y2 = pmadd(p8f_cephes_log_p6, x, p8f_cephes_log_p7);
- y = pmadd(y, x, p8f_cephes_log_p2);
- y1 = pmadd(y1, x, p8f_cephes_log_p5);
- y2 = pmadd(y2, x, p8f_cephes_log_p8);
- y = pmadd(y, x3, y1);
- y = pmadd(y, x3, y2);
- y = pmul(y, x3);
-
- // Add the logarithm of the exponent back to the result of the interpolation.
- y1 = pmul(e, p8f_cephes_log_q1);
- tmp = pmul(x2, p8f_half);
- y = padd(y, y1);
- x = psub(x, tmp);
- y2 = pmul(e, p8f_cephes_log_q2);
- x = padd(x, y);
- x = padd(x, y2);
-
- // Filter out invalid inputs, i.e. negative arg will be NAN, 0 will be -INF.
- return _mm256_or_ps(
- _mm256_andnot_ps(iszero_mask, _mm256_or_ps(x, invalid_mask)),
- _mm256_and_ps(iszero_mask, p8f_minus_inf));
+template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
+Packet8f pexpm1<Packet8f>(const Packet8f& _x) {
+ return generic_expm1(_x);
}
// Exponential function. Works by writing "x = m*log(2) + r" where
@@ -207,149 +70,21 @@
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet8f
pexp<Packet8f>(const Packet8f& _x) {
- _EIGEN_DECLARE_CONST_Packet8f(1, 1.0f);
- _EIGEN_DECLARE_CONST_Packet8f(half, 0.5f);
- _EIGEN_DECLARE_CONST_Packet8f(127, 127.0f);
-
- _EIGEN_DECLARE_CONST_Packet8f(exp_hi, 88.3762626647950f);
- _EIGEN_DECLARE_CONST_Packet8f(exp_lo, -88.3762626647949f);
-
- _EIGEN_DECLARE_CONST_Packet8f(cephes_LOG2EF, 1.44269504088896341f);
-
- _EIGEN_DECLARE_CONST_Packet8f(cephes_exp_p0, 1.9875691500E-4f);
- _EIGEN_DECLARE_CONST_Packet8f(cephes_exp_p1, 1.3981999507E-3f);
- _EIGEN_DECLARE_CONST_Packet8f(cephes_exp_p2, 8.3334519073E-3f);
- _EIGEN_DECLARE_CONST_Packet8f(cephes_exp_p3, 4.1665795894E-2f);
- _EIGEN_DECLARE_CONST_Packet8f(cephes_exp_p4, 1.6666665459E-1f);
- _EIGEN_DECLARE_CONST_Packet8f(cephes_exp_p5, 5.0000001201E-1f);
-
- // Clamp x.
- Packet8f x = pmax(pmin(_x, p8f_exp_hi), p8f_exp_lo);
-
- // Express exp(x) as exp(m*ln(2) + r), start by extracting
- // m = floor(x/ln(2) + 0.5).
- Packet8f m = _mm256_floor_ps(pmadd(x, p8f_cephes_LOG2EF, p8f_half));
-
-// Get r = x - m*ln(2). If no FMA instructions are available, m*ln(2) is
-// subtracted out in two parts, m*C1+m*C2 = m*ln(2), to avoid accumulating
-// truncation errors. Note that we don't use the "pmadd" function here to
-// ensure that a precision-preserving FMA instruction is used.
-#ifdef EIGEN_VECTORIZE_FMA
- _EIGEN_DECLARE_CONST_Packet8f(nln2, -0.6931471805599453f);
- Packet8f r = _mm256_fmadd_ps(m, p8f_nln2, x);
-#else
- _EIGEN_DECLARE_CONST_Packet8f(cephes_exp_C1, 0.693359375f);
- _EIGEN_DECLARE_CONST_Packet8f(cephes_exp_C2, -2.12194440e-4f);
- Packet8f r = psub(x, pmul(m, p8f_cephes_exp_C1));
- r = psub(r, pmul(m, p8f_cephes_exp_C2));
-#endif
-
- Packet8f r2 = pmul(r, r);
-
- // TODO(gonnet): Split into odd/even polynomials and try to exploit
- // instruction-level parallelism.
- Packet8f y = p8f_cephes_exp_p0;
- y = pmadd(y, r, p8f_cephes_exp_p1);
- y = pmadd(y, r, p8f_cephes_exp_p2);
- y = pmadd(y, r, p8f_cephes_exp_p3);
- y = pmadd(y, r, p8f_cephes_exp_p4);
- y = pmadd(y, r, p8f_cephes_exp_p5);
- y = pmadd(y, r2, r);
- y = padd(y, p8f_1);
-
- // Build emm0 = 2^m.
- Packet8i emm0 = _mm256_cvttps_epi32(padd(m, p8f_127));
- emm0 = pshiftleft(emm0, 23);
-
- // Return 2^m * exp(r).
- return pmax(pmul(y, _mm256_castsi256_ps(emm0)), _x);
+ return pexp_float(_x);
}
// Hyperbolic Tangent function.
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet8f
-ptanh<Packet8f>(const Packet8f& x) {
- return internal::generic_fast_tanh_float(x);
+ptanh<Packet8f>(const Packet8f& _x) {
+ return internal::generic_fast_tanh_float(_x);
}
+// Exponential function for doubles.
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4d
pexp<Packet4d>(const Packet4d& _x) {
- Packet4d x = _x;
-
- _EIGEN_DECLARE_CONST_Packet4d(1, 1.0);
- _EIGEN_DECLARE_CONST_Packet4d(2, 2.0);
- _EIGEN_DECLARE_CONST_Packet4d(half, 0.5);
-
- _EIGEN_DECLARE_CONST_Packet4d(exp_hi, 709.437);
- _EIGEN_DECLARE_CONST_Packet4d(exp_lo, -709.436139303);
-
- _EIGEN_DECLARE_CONST_Packet4d(cephes_LOG2EF, 1.4426950408889634073599);
-
- _EIGEN_DECLARE_CONST_Packet4d(cephes_exp_p0, 1.26177193074810590878e-4);
- _EIGEN_DECLARE_CONST_Packet4d(cephes_exp_p1, 3.02994407707441961300e-2);
- _EIGEN_DECLARE_CONST_Packet4d(cephes_exp_p2, 9.99999999999999999910e-1);
-
- _EIGEN_DECLARE_CONST_Packet4d(cephes_exp_q0, 3.00198505138664455042e-6);
- _EIGEN_DECLARE_CONST_Packet4d(cephes_exp_q1, 2.52448340349684104192e-3);
- _EIGEN_DECLARE_CONST_Packet4d(cephes_exp_q2, 2.27265548208155028766e-1);
- _EIGEN_DECLARE_CONST_Packet4d(cephes_exp_q3, 2.00000000000000000009e0);
-
- _EIGEN_DECLARE_CONST_Packet4d(cephes_exp_C1, 0.693145751953125);
- _EIGEN_DECLARE_CONST_Packet4d(cephes_exp_C2, 1.42860682030941723212e-6);
- _EIGEN_DECLARE_CONST_Packet4i(1023, 1023);
-
- Packet4d tmp, fx;
-
- // clamp x
- x = pmax(pmin(x, p4d_exp_hi), p4d_exp_lo);
- // Express exp(x) as exp(g + n*log(2)).
- fx = pmadd(p4d_cephes_LOG2EF, x, p4d_half);
-
- // Get the integer modulus of log(2), i.e. the "n" described above.
- fx = _mm256_floor_pd(fx);
-
- // Get the remainder modulo log(2), i.e. the "g" described above. Subtract
- // n*log(2) out in two steps, i.e. n*C1 + n*C2, C1+C2=log2 to get the last
- // digits right.
- tmp = pmul(fx, p4d_cephes_exp_C1);
- Packet4d z = pmul(fx, p4d_cephes_exp_C2);
- x = psub(x, tmp);
- x = psub(x, z);
-
- Packet4d x2 = pmul(x, x);
-
- // Evaluate the numerator polynomial of the rational interpolant.
- Packet4d px = p4d_cephes_exp_p0;
- px = pmadd(px, x2, p4d_cephes_exp_p1);
- px = pmadd(px, x2, p4d_cephes_exp_p2);
- px = pmul(px, x);
-
- // Evaluate the denominator polynomial of the rational interpolant.
- Packet4d qx = p4d_cephes_exp_q0;
- qx = pmadd(qx, x2, p4d_cephes_exp_q1);
- qx = pmadd(qx, x2, p4d_cephes_exp_q2);
- qx = pmadd(qx, x2, p4d_cephes_exp_q3);
-
- // I don't really get this bit, copied from the SSE2 routines, so...
- // TODO(gonnet): Figure out what is going on here, perhaps find a better
- // rational interpolant?
- x = _mm256_div_pd(px, psub(qx, px));
- x = pmadd(p4d_2, x, p4d_1);
-
- // Build e=2^n by constructing the exponents in a 128-bit vector and
- // shifting them to where they belong in double-precision values.
- __m128i emm0 = _mm256_cvtpd_epi32(fx);
- emm0 = _mm_add_epi32(emm0, p4i_1023);
- emm0 = _mm_shuffle_epi32(emm0, _MM_SHUFFLE(3, 1, 2, 0));
- __m128i lo = _mm_slli_epi64(emm0, 52);
- __m128i hi = _mm_slli_epi64(_mm_srli_epi64(emm0, 32), 52);
- __m256i e = _mm256_insertf128_si256(_mm256_setzero_si256(), lo, 0);
- e = _mm256_insertf128_si256(e, hi, 1);
-
- // Construct the result 2^n * exp(g) = e * x. The max is used to catch
- // non-finite values in the input.
- return pmax(pmul(x, _mm256_castsi256_pd(e)), _x);
+ return pexp_double(_x);
}
// Functions for sqrt.
@@ -362,37 +97,39 @@
// For detail see here: http://www.beyond3d.com/content/articles/8/
#if EIGEN_FAST_MATH
template <>
-EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet8f
-psqrt<Packet8f>(const Packet8f& _x) {
- Packet8f half = pmul(_x, pset1<Packet8f>(.5f));
- Packet8f denormal_mask = _mm256_and_ps(
- _mm256_cmp_ps(_x, pset1<Packet8f>((std::numeric_limits<float>::min)()),
- _CMP_LT_OQ),
- _mm256_cmp_ps(_x, _mm256_setzero_ps(), _CMP_GE_OQ));
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
+Packet8f psqrt<Packet8f>(const Packet8f& _x) {
+ Packet8f minus_half_x = pmul(_x, pset1<Packet8f>(-0.5f));
+ Packet8f denormal_mask = pandnot(
+ pcmp_lt(_x, pset1<Packet8f>((std::numeric_limits<float>::min)())),
+ pcmp_lt(_x, pzero(_x)));
// Compute approximate reciprocal sqrt.
Packet8f x = _mm256_rsqrt_ps(_x);
// Do a single step of Newton's iteration.
- x = pmul(x, psub(pset1<Packet8f>(1.5f), pmul(half, pmul(x,x))));
+ x = pmul(x, pmadd(minus_half_x, pmul(x,x), pset1<Packet8f>(1.5f)));
// Flush results for denormals to zero.
- return _mm256_andnot_ps(denormal_mask, pmul(_x,x));
+ return pandnot(pmul(_x,x), denormal_mask);
}
-#else
-template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
-Packet8f psqrt<Packet8f>(const Packet8f& x) {
- return _mm256_sqrt_ps(x);
-}
-#endif
-template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
-Packet4d psqrt<Packet4d>(const Packet4d& x) {
- return _mm256_sqrt_pd(x);
-}
-#if EIGEN_FAST_MATH
+#else
+
+template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
+Packet8f psqrt<Packet8f>(const Packet8f& _x) {
+ return _mm256_sqrt_ps(_x);
+}
+
+#endif
+
+template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
+Packet4d psqrt<Packet4d>(const Packet4d& _x) {
+ return _mm256_sqrt_pd(_x);
+}
+
+#if EIGEN_FAST_MATH
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet8f prsqrt<Packet8f>(const Packet8f& _x) {
_EIGEN_DECLARE_CONST_Packet8f_FROM_INT(inf, 0x7f800000);
- _EIGEN_DECLARE_CONST_Packet8f_FROM_INT(nan, 0x7fc00000);
_EIGEN_DECLARE_CONST_Packet8f(one_point_five, 1.5f);
_EIGEN_DECLARE_CONST_Packet8f(minus_half, -0.5f);
_EIGEN_DECLARE_CONST_Packet8f_FROM_INT(flt_min, 0x00800000);
@@ -401,36 +138,88 @@
// select only the inverse sqrt of positive normal inputs (denormals are
// flushed to zero and cause infs as well).
- Packet8f le_zero_mask = _mm256_cmp_ps(_x, p8f_flt_min, _CMP_LT_OQ);
- Packet8f x = _mm256_andnot_ps(le_zero_mask, _mm256_rsqrt_ps(_x));
+ Packet8f lt_min_mask = _mm256_cmp_ps(_x, p8f_flt_min, _CMP_LT_OQ);
+ Packet8f inf_mask = _mm256_cmp_ps(_x, p8f_inf, _CMP_EQ_OQ);
+ Packet8f not_normal_finite_mask = _mm256_or_ps(lt_min_mask, inf_mask);
- // Fill in NaNs and Infs for the negative/zero entries.
- Packet8f neg_mask = _mm256_cmp_ps(_x, _mm256_setzero_ps(), _CMP_LT_OQ);
- Packet8f zero_mask = _mm256_andnot_ps(neg_mask, le_zero_mask);
- Packet8f infs_and_nans = _mm256_or_ps(_mm256_and_ps(neg_mask, p8f_nan),
- _mm256_and_ps(zero_mask, p8f_inf));
+ // Compute an approximate result using the rsqrt intrinsic.
+ Packet8f y_approx = _mm256_rsqrt_ps(_x);
- // Do a single step of Newton's iteration.
- x = pmul(x, pmadd(neg_half, pmul(x, x), p8f_one_point_five));
+ // Do a single step of Newton-Raphson iteration to improve the approximation.
+ // This uses the formula y_{n+1} = y_n * (1.5 - y_n * (0.5 * x) * y_n).
+ // It is essential to evaluate the inner term like this because forming
+ // y_n^2 may over- or underflow.
+ Packet8f y_newton = pmul(y_approx, pmadd(y_approx, pmul(neg_half, y_approx), p8f_one_point_five));
- // Insert NaNs and Infs in all the right places.
- return _mm256_or_ps(x, infs_and_nans);
+ // Select the result of the Newton-Raphson step for positive normal arguments.
+ // For other arguments, choose the output of the intrinsic. This will
+ // return rsqrt(+inf) = 0, rsqrt(x) = NaN if x < 0, and rsqrt(x) = +inf if
+ // x is zero or a positive denormalized float (equivalent to flushing positive
+ // denormalized inputs to zero).
+ return pselect<Packet8f>(not_normal_finite_mask, y_approx, y_newton);
}
#else
template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
-Packet8f prsqrt<Packet8f>(const Packet8f& x) {
+Packet8f prsqrt<Packet8f>(const Packet8f& _x) {
_EIGEN_DECLARE_CONST_Packet8f(one, 1.0f);
- return _mm256_div_ps(p8f_one, _mm256_sqrt_ps(x));
+ return _mm256_div_ps(p8f_one, _mm256_sqrt_ps(_x));
}
#endif
template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
-Packet4d prsqrt<Packet4d>(const Packet4d& x) {
+Packet4d prsqrt<Packet4d>(const Packet4d& _x) {
_EIGEN_DECLARE_CONST_Packet4d(one, 1.0);
- return _mm256_div_pd(p4d_one, _mm256_sqrt_pd(x));
+ return _mm256_div_pd(p4d_one, _mm256_sqrt_pd(_x));
}
+F16_PACKET_FUNCTION(Packet8f, Packet8h, psin)
+F16_PACKET_FUNCTION(Packet8f, Packet8h, pcos)
+F16_PACKET_FUNCTION(Packet8f, Packet8h, plog)
+F16_PACKET_FUNCTION(Packet8f, Packet8h, plog2)
+F16_PACKET_FUNCTION(Packet8f, Packet8h, plog1p)
+F16_PACKET_FUNCTION(Packet8f, Packet8h, pexpm1)
+F16_PACKET_FUNCTION(Packet8f, Packet8h, pexp)
+F16_PACKET_FUNCTION(Packet8f, Packet8h, ptanh)
+F16_PACKET_FUNCTION(Packet8f, Packet8h, psqrt)
+F16_PACKET_FUNCTION(Packet8f, Packet8h, prsqrt)
+
+template <>
+EIGEN_STRONG_INLINE Packet8h pfrexp(const Packet8h& a, Packet8h& exponent) {
+ Packet8f fexponent;
+ const Packet8h out = float2half(pfrexp<Packet8f>(half2float(a), fexponent));
+ exponent = float2half(fexponent);
+ return out;
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8h pldexp(const Packet8h& a, const Packet8h& exponent) {
+ return float2half(pldexp<Packet8f>(half2float(a), half2float(exponent)));
+}
+
+BF16_PACKET_FUNCTION(Packet8f, Packet8bf, psin)
+BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pcos)
+BF16_PACKET_FUNCTION(Packet8f, Packet8bf, plog)
+BF16_PACKET_FUNCTION(Packet8f, Packet8bf, plog2)
+BF16_PACKET_FUNCTION(Packet8f, Packet8bf, plog1p)
+BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pexpm1)
+BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pexp)
+BF16_PACKET_FUNCTION(Packet8f, Packet8bf, ptanh)
+BF16_PACKET_FUNCTION(Packet8f, Packet8bf, psqrt)
+BF16_PACKET_FUNCTION(Packet8f, Packet8bf, prsqrt)
+
+template <>
+EIGEN_STRONG_INLINE Packet8bf pfrexp(const Packet8bf& a, Packet8bf& exponent) {
+ Packet8f fexponent;
+ const Packet8bf out = F32ToBf16(pfrexp<Packet8f>(Bf16ToF32(a), fexponent));
+ exponent = F32ToBf16(fexponent);
+ return out;
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8bf pldexp(const Packet8bf& a, const Packet8bf& exponent) {
+ return F32ToBf16(pldexp<Packet8f>(Bf16ToF32(a), Bf16ToF32(exponent)));
+}
} // end namespace internal
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/AVX/PacketMath.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/AVX/PacketMath.h
index 923a124..7fc32fd 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/AVX/PacketMath.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/AVX/PacketMath.h
@@ -18,11 +18,11 @@
#define EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD 8
#endif
-#ifndef EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS
-#define EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS (2*sizeof(void*))
+#if !defined(EIGEN_VECTORIZE_AVX512) && !defined(EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS)
+#define EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS 16
#endif
-#ifdef __FMA__
+#ifdef EIGEN_VECTORIZE_FMA
#ifndef EIGEN_HAS_SINGLE_INSTRUCTION_MADD
#define EIGEN_HAS_SINGLE_INSTRUCTION_MADD
#endif
@@ -31,10 +31,14 @@
typedef __m256 Packet8f;
typedef __m256i Packet8i;
typedef __m256d Packet4d;
+typedef eigen_packet_wrapper<__m128i, 2> Packet8h;
+typedef eigen_packet_wrapper<__m128i, 3> Packet8bf;
template<> struct is_arithmetic<__m256> { enum { value = true }; };
template<> struct is_arithmetic<__m256i> { enum { value = true }; };
template<> struct is_arithmetic<__m256d> { enum { value = true }; };
+template<> struct is_arithmetic<Packet8h> { enum { value = true }; };
+template<> struct is_arithmetic<Packet8bf> { enum { value = true }; };
#define _EIGEN_DECLARE_CONST_Packet8f(NAME,X) \
const Packet8f p8f_##NAME = pset1<Packet8f>(X)
@@ -58,21 +62,28 @@
enum {
Vectorizable = 1,
AlignedOnScalar = 1,
- size=8,
+ size = 8,
HasHalfPacket = 1,
- HasDiv = 1,
- HasSin = EIGEN_FAST_MATH,
- HasCos = 0,
- HasLog = 1,
- HasExp = 1,
+ HasCmp = 1,
+ HasDiv = 1,
+ HasSin = EIGEN_FAST_MATH,
+ HasCos = EIGEN_FAST_MATH,
+ HasLog = 1,
+ HasLog1p = 1,
+ HasExpm1 = 1,
+ HasExp = 1,
+ HasNdtri = 1,
+ HasBessel = 1,
HasSqrt = 1,
HasRsqrt = 1,
- HasTanh = EIGEN_FAST_MATH,
+ HasTanh = EIGEN_FAST_MATH,
+ HasErf = EIGEN_FAST_MATH,
HasBlend = 1,
HasRound = 1,
HasFloor = 1,
- HasCeil = 1
+ HasCeil = 1,
+ HasRint = 1
};
};
template<> struct packet_traits<double> : default_packet_traits
@@ -85,14 +96,104 @@
size=4,
HasHalfPacket = 1,
+ HasCmp = 1,
HasDiv = 1,
+ HasLog = 1,
HasExp = 1,
HasSqrt = 1,
HasRsqrt = 1,
HasBlend = 1,
HasRound = 1,
HasFloor = 1,
- HasCeil = 1
+ HasCeil = 1,
+ HasRint = 1
+ };
+};
+
+template <>
+struct packet_traits<Eigen::half> : default_packet_traits {
+ typedef Packet8h type;
+ // There is no half-size packet for Packet8h.
+ typedef Packet8h half;
+ enum {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = 8,
+ HasHalfPacket = 0,
+
+ HasCmp = 1,
+ HasAdd = 1,
+ HasSub = 1,
+ HasMul = 1,
+ HasDiv = 1,
+ HasSin = EIGEN_FAST_MATH,
+ HasCos = EIGEN_FAST_MATH,
+ HasNegate = 1,
+ HasAbs = 1,
+ HasAbs2 = 0,
+ HasMin = 1,
+ HasMax = 1,
+ HasConj = 1,
+ HasSetLinear = 0,
+ HasLog = 1,
+ HasLog1p = 1,
+ HasExpm1 = 1,
+ HasExp = 1,
+ HasSqrt = 1,
+ HasRsqrt = 1,
+ HasTanh = EIGEN_FAST_MATH,
+ HasErf = EIGEN_FAST_MATH,
+ HasBlend = 0,
+ HasRound = 1,
+ HasFloor = 1,
+ HasCeil = 1,
+ HasRint = 1,
+ HasBessel = 1,
+ HasNdtri = 1
+ };
+};
+
+template <>
+struct packet_traits<bfloat16> : default_packet_traits {
+ typedef Packet8bf type;
+ // There is no half-size packet for current Packet8bf.
+ // TODO: support as SSE path.
+ typedef Packet8bf half;
+ enum {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = 8,
+ HasHalfPacket = 0,
+
+ HasCmp = 1,
+ HasAdd = 1,
+ HasSub = 1,
+ HasMul = 1,
+ HasDiv = 1,
+ HasSin = EIGEN_FAST_MATH,
+ HasCos = EIGEN_FAST_MATH,
+ HasNegate = 1,
+ HasAbs = 1,
+ HasAbs2 = 0,
+ HasMin = 1,
+ HasMax = 1,
+ HasConj = 1,
+ HasSetLinear = 0,
+ HasLog = 1,
+ HasLog1p = 1,
+ HasExpm1 = 1,
+ HasExp = 1,
+ HasSqrt = 1,
+ HasRsqrt = 1,
+ HasTanh = EIGEN_FAST_MATH,
+ HasErf = EIGEN_FAST_MATH,
+ HasBlend = 0,
+ HasRound = 1,
+ HasFloor = 1,
+ HasCeil = 1,
+ HasRint = 1,
+ HasBessel = 1,
+ HasNdtri = 1
};
};
#endif
@@ -113,14 +214,45 @@
};
*/
-template<> struct unpacket_traits<Packet8f> { typedef float type; typedef Packet4f half; enum {size=8, alignment=Aligned32}; };
-template<> struct unpacket_traits<Packet4d> { typedef double type; typedef Packet2d half; enum {size=4, alignment=Aligned32}; };
-template<> struct unpacket_traits<Packet8i> { typedef int type; typedef Packet4i half; enum {size=8, alignment=Aligned32}; };
+template<> struct unpacket_traits<Packet8f> {
+ typedef float type;
+ typedef Packet4f half;
+ typedef Packet8i integer_packet;
+ typedef uint8_t mask_t;
+ enum {size=8, alignment=Aligned32, vectorizable=true, masked_load_available=true, masked_store_available=true};
+};
+template<> struct unpacket_traits<Packet4d> {
+ typedef double type;
+ typedef Packet2d half;
+ enum {size=4, alignment=Aligned32, vectorizable=true, masked_load_available=false, masked_store_available=false};
+};
+template<> struct unpacket_traits<Packet8i> { typedef int type; typedef Packet4i half; enum {size=8, alignment=Aligned32, vectorizable=false, masked_load_available=false, masked_store_available=false}; };
+template<> struct unpacket_traits<Packet8bf> { typedef bfloat16 type; typedef Packet8bf half; enum {size=8, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; };
+
+// Helper function for bit packing snippet of low precision comparison.
+// It packs the flags from 16x16 to 8x16.
+EIGEN_STRONG_INLINE __m128i Pack16To8(Packet8f rf) {
+ return _mm_packs_epi32(_mm256_extractf128_si256(_mm256_castps_si256(rf), 0),
+ _mm256_extractf128_si256(_mm256_castps_si256(rf), 1));
+}
+
template<> EIGEN_STRONG_INLINE Packet8f pset1<Packet8f>(const float& from) { return _mm256_set1_ps(from); }
template<> EIGEN_STRONG_INLINE Packet4d pset1<Packet4d>(const double& from) { return _mm256_set1_pd(from); }
template<> EIGEN_STRONG_INLINE Packet8i pset1<Packet8i>(const int& from) { return _mm256_set1_epi32(from); }
+template<> EIGEN_STRONG_INLINE Packet8f pset1frombits<Packet8f>(unsigned int from) { return _mm256_castsi256_ps(pset1<Packet8i>(from)); }
+template<> EIGEN_STRONG_INLINE Packet4d pset1frombits<Packet4d>(uint64_t from) { return _mm256_castsi256_pd(_mm256_set1_epi64x(from)); }
+
+template<> EIGEN_STRONG_INLINE Packet8f pzero(const Packet8f& /*a*/) { return _mm256_setzero_ps(); }
+template<> EIGEN_STRONG_INLINE Packet4d pzero(const Packet4d& /*a*/) { return _mm256_setzero_pd(); }
+template<> EIGEN_STRONG_INLINE Packet8i pzero(const Packet8i& /*a*/) { return _mm256_setzero_si256(); }
+
+
+template<> EIGEN_STRONG_INLINE Packet8f peven_mask(const Packet8f& /*a*/) { return _mm256_castsi256_ps(_mm256_set_epi32(0, -1, 0, -1, 0, -1, 0, -1)); }
+template<> EIGEN_STRONG_INLINE Packet8i peven_mask(const Packet8i& /*a*/) { return _mm256_set_epi32(0, -1, 0, -1, 0, -1, 0, -1); }
+template<> EIGEN_STRONG_INLINE Packet4d peven_mask(const Packet4d& /*a*/) { return _mm256_castsi256_pd(_mm256_set_epi32(0, 0, -1, -1, 0, 0, -1, -1)); }
+
template<> EIGEN_STRONG_INLINE Packet8f pload1<Packet8f>(const float* from) { return _mm256_broadcast_ss(from); }
template<> EIGEN_STRONG_INLINE Packet4d pload1<Packet4d>(const double* from) { return _mm256_broadcast_sd(from); }
@@ -129,9 +261,27 @@
template<> EIGEN_STRONG_INLINE Packet8f padd<Packet8f>(const Packet8f& a, const Packet8f& b) { return _mm256_add_ps(a,b); }
template<> EIGEN_STRONG_INLINE Packet4d padd<Packet4d>(const Packet4d& a, const Packet4d& b) { return _mm256_add_pd(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8i padd<Packet8i>(const Packet8i& a, const Packet8i& b) {
+#ifdef EIGEN_VECTORIZE_AVX2
+ return _mm256_add_epi32(a,b);
+#else
+ __m128i lo = _mm_add_epi32(_mm256_extractf128_si256(a, 0), _mm256_extractf128_si256(b, 0));
+ __m128i hi = _mm_add_epi32(_mm256_extractf128_si256(a, 1), _mm256_extractf128_si256(b, 1));
+ return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1);
+#endif
+}
template<> EIGEN_STRONG_INLINE Packet8f psub<Packet8f>(const Packet8f& a, const Packet8f& b) { return _mm256_sub_ps(a,b); }
template<> EIGEN_STRONG_INLINE Packet4d psub<Packet4d>(const Packet4d& a, const Packet4d& b) { return _mm256_sub_pd(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8i psub<Packet8i>(const Packet8i& a, const Packet8i& b) {
+#ifdef EIGEN_VECTORIZE_AVX2
+ return _mm256_sub_epi32(a,b);
+#else
+ __m128i lo = _mm_sub_epi32(_mm256_extractf128_si256(a, 0), _mm256_extractf128_si256(b, 0));
+ __m128i hi = _mm_sub_epi32(_mm256_extractf128_si256(a, 1), _mm256_extractf128_si256(b, 1));
+ return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1);
+#endif
+}
template<> EIGEN_STRONG_INLINE Packet8f pnegate(const Packet8f& a)
{
@@ -148,7 +298,15 @@
template<> EIGEN_STRONG_INLINE Packet8f pmul<Packet8f>(const Packet8f& a, const Packet8f& b) { return _mm256_mul_ps(a,b); }
template<> EIGEN_STRONG_INLINE Packet4d pmul<Packet4d>(const Packet4d& a, const Packet4d& b) { return _mm256_mul_pd(a,b); }
-
+template<> EIGEN_STRONG_INLINE Packet8i pmul<Packet8i>(const Packet8i& a, const Packet8i& b) {
+#ifdef EIGEN_VECTORIZE_AVX2
+ return _mm256_mullo_epi32(a,b);
+#else
+ const __m128i lo = _mm_mullo_epi32(_mm256_extractf128_si256(a, 0), _mm256_extractf128_si256(b, 0));
+ const __m128i hi = _mm_mullo_epi32(_mm256_extractf128_si256(a, 1), _mm256_extractf128_si256(b, 1));
+ return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1);
+#endif
+}
template<> EIGEN_STRONG_INLINE Packet8f pdiv<Packet8f>(const Packet8f& a, const Packet8f& b) { return _mm256_div_ps(a,b); }
template<> EIGEN_STRONG_INLINE Packet4d pdiv<Packet4d>(const Packet4d& a, const Packet4d& b) { return _mm256_div_pd(a,b); }
@@ -157,7 +315,7 @@
return pset1<Packet8i>(0);
}
-#ifdef __FMA__
+#ifdef EIGEN_VECTORIZE_FMA
template<> EIGEN_STRONG_INLINE Packet8f pmadd(const Packet8f& a, const Packet8f& b, const Packet8f& c) {
#if ( (EIGEN_COMP_GNUC_STRICT && EIGEN_COMP_GNUC<80) || (EIGEN_COMP_CLANG) )
// Clang stupidly generates a vfmadd213ps instruction plus some vmovaps on registers,
@@ -184,14 +342,112 @@
}
#endif
-template<> EIGEN_STRONG_INLINE Packet8f pmin<Packet8f>(const Packet8f& a, const Packet8f& b) { return _mm256_min_ps(a,b); }
-template<> EIGEN_STRONG_INLINE Packet4d pmin<Packet4d>(const Packet4d& a, const Packet4d& b) { return _mm256_min_pd(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8f pcmp_le(const Packet8f& a, const Packet8f& b) { return _mm256_cmp_ps(a,b,_CMP_LE_OQ); }
+template<> EIGEN_STRONG_INLINE Packet8f pcmp_lt(const Packet8f& a, const Packet8f& b) { return _mm256_cmp_ps(a,b,_CMP_LT_OQ); }
+template<> EIGEN_STRONG_INLINE Packet8f pcmp_lt_or_nan(const Packet8f& a, const Packet8f& b) { return _mm256_cmp_ps(a, b, _CMP_NGE_UQ); }
+template<> EIGEN_STRONG_INLINE Packet8f pcmp_eq(const Packet8f& a, const Packet8f& b) { return _mm256_cmp_ps(a,b,_CMP_EQ_OQ); }
-template<> EIGEN_STRONG_INLINE Packet8f pmax<Packet8f>(const Packet8f& a, const Packet8f& b) { return _mm256_max_ps(a,b); }
-template<> EIGEN_STRONG_INLINE Packet4d pmax<Packet4d>(const Packet4d& a, const Packet4d& b) { return _mm256_max_pd(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4d pcmp_le(const Packet4d& a, const Packet4d& b) { return _mm256_cmp_pd(a,b,_CMP_LE_OQ); }
+template<> EIGEN_STRONG_INLINE Packet4d pcmp_lt(const Packet4d& a, const Packet4d& b) { return _mm256_cmp_pd(a,b,_CMP_LT_OQ); }
+template<> EIGEN_STRONG_INLINE Packet4d pcmp_lt_or_nan(const Packet4d& a, const Packet4d& b) { return _mm256_cmp_pd(a, b, _CMP_NGE_UQ); }
+template<> EIGEN_STRONG_INLINE Packet4d pcmp_eq(const Packet4d& a, const Packet4d& b) { return _mm256_cmp_pd(a,b,_CMP_EQ_OQ); }
-template<> EIGEN_STRONG_INLINE Packet8f pround<Packet8f>(const Packet8f& a) { return _mm256_round_ps(a, _MM_FROUND_CUR_DIRECTION); }
-template<> EIGEN_STRONG_INLINE Packet4d pround<Packet4d>(const Packet4d& a) { return _mm256_round_pd(a, _MM_FROUND_CUR_DIRECTION); }
+
+template<> EIGEN_STRONG_INLINE Packet8i pcmp_eq(const Packet8i& a, const Packet8i& b) {
+#ifdef EIGEN_VECTORIZE_AVX2
+ return _mm256_cmpeq_epi32(a,b);
+#else
+ __m128i lo = _mm_cmpeq_epi32(_mm256_extractf128_si256(a, 0), _mm256_extractf128_si256(b, 0));
+ __m128i hi = _mm_cmpeq_epi32(_mm256_extractf128_si256(a, 1), _mm256_extractf128_si256(b, 1));
+ return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1);
+#endif
+}
+
+template<> EIGEN_STRONG_INLINE Packet8f pmin<Packet8f>(const Packet8f& a, const Packet8f& b) {
+#if EIGEN_COMP_GNUC && EIGEN_COMP_GNUC < 63
+ // There appears to be a bug in GCC, by which the optimizer may flip
+ // the argument order in calls to _mm_min_ps/_mm_max_ps, so we have to
+ // resort to inline ASM here. This is supposed to be fixed in gcc6.3,
+ // see also: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=72867
+ Packet8f res;
+ asm("vminps %[a], %[b], %[res]" : [res] "=x" (res) : [a] "x" (a), [b] "x" (b));
+ return res;
+#else
+ // Arguments are swapped to match NaN propagation behavior of std::min.
+ return _mm256_min_ps(b,a);
+#endif
+}
+template<> EIGEN_STRONG_INLINE Packet4d pmin<Packet4d>(const Packet4d& a, const Packet4d& b) {
+#if EIGEN_COMP_GNUC && EIGEN_COMP_GNUC < 63
+ // See pmin above
+ Packet4d res;
+ asm("vminpd %[a], %[b], %[res]" : [res] "=x" (res) : [a] "x" (a), [b] "x" (b));
+ return res;
+#else
+ // Arguments are swapped to match NaN propagation behavior of std::min.
+ return _mm256_min_pd(b,a);
+#endif
+}
+
+template<> EIGEN_STRONG_INLINE Packet8f pmax<Packet8f>(const Packet8f& a, const Packet8f& b) {
+#if EIGEN_COMP_GNUC && EIGEN_COMP_GNUC < 63
+ // See pmin above
+ Packet8f res;
+ asm("vmaxps %[a], %[b], %[res]" : [res] "=x" (res) : [a] "x" (a), [b] "x" (b));
+ return res;
+#else
+ // Arguments are swapped to match NaN propagation behavior of std::max.
+ return _mm256_max_ps(b,a);
+#endif
+}
+template<> EIGEN_STRONG_INLINE Packet4d pmax<Packet4d>(const Packet4d& a, const Packet4d& b) {
+#if EIGEN_COMP_GNUC && EIGEN_COMP_GNUC < 63
+ // See pmin above
+ Packet4d res;
+ asm("vmaxpd %[a], %[b], %[res]" : [res] "=x" (res) : [a] "x" (a), [b] "x" (b));
+ return res;
+#else
+ // Arguments are swapped to match NaN propagation behavior of std::max.
+ return _mm256_max_pd(b,a);
+#endif
+}
+
+// Add specializations for min/max with prescribed NaN progation.
+template<>
+EIGEN_STRONG_INLINE Packet8f pmin<PropagateNumbers, Packet8f>(const Packet8f& a, const Packet8f& b) {
+ return pminmax_propagate_numbers(a, b, pmin<Packet8f>);
+}
+template<>
+EIGEN_STRONG_INLINE Packet4d pmin<PropagateNumbers, Packet4d>(const Packet4d& a, const Packet4d& b) {
+ return pminmax_propagate_numbers(a, b, pmin<Packet4d>);
+}
+template<>
+EIGEN_STRONG_INLINE Packet8f pmax<PropagateNumbers, Packet8f>(const Packet8f& a, const Packet8f& b) {
+ return pminmax_propagate_numbers(a, b, pmax<Packet8f>);
+}
+template<>
+EIGEN_STRONG_INLINE Packet4d pmax<PropagateNumbers, Packet4d>(const Packet4d& a, const Packet4d& b) {
+ return pminmax_propagate_numbers(a, b, pmax<Packet4d>);
+}
+template<>
+EIGEN_STRONG_INLINE Packet8f pmin<PropagateNaN, Packet8f>(const Packet8f& a, const Packet8f& b) {
+ return pminmax_propagate_nan(a, b, pmin<Packet8f>);
+}
+template<>
+EIGEN_STRONG_INLINE Packet4d pmin<PropagateNaN, Packet4d>(const Packet4d& a, const Packet4d& b) {
+ return pminmax_propagate_nan(a, b, pmin<Packet4d>);
+}
+template<>
+EIGEN_STRONG_INLINE Packet8f pmax<PropagateNaN, Packet8f>(const Packet8f& a, const Packet8f& b) {
+ return pminmax_propagate_nan(a, b, pmax<Packet8f>);
+}
+template<>
+EIGEN_STRONG_INLINE Packet4d pmax<PropagateNaN, Packet4d>(const Packet4d& a, const Packet4d& b) {
+ return pminmax_propagate_nan(a, b, pmax<Packet4d>);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8f print<Packet8f>(const Packet8f& a) { return _mm256_round_ps(a, _MM_FROUND_CUR_DIRECTION); }
+template<> EIGEN_STRONG_INLINE Packet4d print<Packet4d>(const Packet4d& a) { return _mm256_round_pd(a, _MM_FROUND_CUR_DIRECTION); }
template<> EIGEN_STRONG_INLINE Packet8f pceil<Packet8f>(const Packet8f& a) { return _mm256_ceil_ps(a); }
template<> EIGEN_STRONG_INLINE Packet4d pceil<Packet4d>(const Packet4d& a) { return _mm256_ceil_pd(a); }
@@ -199,17 +455,124 @@
template<> EIGEN_STRONG_INLINE Packet8f pfloor<Packet8f>(const Packet8f& a) { return _mm256_floor_ps(a); }
template<> EIGEN_STRONG_INLINE Packet4d pfloor<Packet4d>(const Packet4d& a) { return _mm256_floor_pd(a); }
+
+template<> EIGEN_STRONG_INLINE Packet8i ptrue<Packet8i>(const Packet8i& a) {
+#ifdef EIGEN_VECTORIZE_AVX2
+ // vpcmpeqd has lower latency than the more general vcmpps
+ return _mm256_cmpeq_epi32(a,a);
+#else
+ const __m256 b = _mm256_castsi256_ps(a);
+ return _mm256_castps_si256(_mm256_cmp_ps(b,b,_CMP_TRUE_UQ));
+#endif
+}
+
+template<> EIGEN_STRONG_INLINE Packet8f ptrue<Packet8f>(const Packet8f& a) {
+#ifdef EIGEN_VECTORIZE_AVX2
+ // vpcmpeqd has lower latency than the more general vcmpps
+ const __m256i b = _mm256_castps_si256(a);
+ return _mm256_castsi256_ps(_mm256_cmpeq_epi32(b,b));
+#else
+ return _mm256_cmp_ps(a,a,_CMP_TRUE_UQ);
+#endif
+}
+
+template<> EIGEN_STRONG_INLINE Packet4d ptrue<Packet4d>(const Packet4d& a) {
+#ifdef EIGEN_VECTORIZE_AVX2
+ // vpcmpeqq has lower latency than the more general vcmppd
+ const __m256i b = _mm256_castpd_si256(a);
+ return _mm256_castsi256_pd(_mm256_cmpeq_epi64(b,b));
+#else
+ return _mm256_cmp_pd(a,a,_CMP_TRUE_UQ);
+#endif
+}
+
template<> EIGEN_STRONG_INLINE Packet8f pand<Packet8f>(const Packet8f& a, const Packet8f& b) { return _mm256_and_ps(a,b); }
template<> EIGEN_STRONG_INLINE Packet4d pand<Packet4d>(const Packet4d& a, const Packet4d& b) { return _mm256_and_pd(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8i pand<Packet8i>(const Packet8i& a, const Packet8i& b) {
+#ifdef EIGEN_VECTORIZE_AVX2
+ return _mm256_and_si256(a,b);
+#else
+ return _mm256_castps_si256(_mm256_and_ps(_mm256_castsi256_ps(a),_mm256_castsi256_ps(b)));
+#endif
+}
template<> EIGEN_STRONG_INLINE Packet8f por<Packet8f>(const Packet8f& a, const Packet8f& b) { return _mm256_or_ps(a,b); }
template<> EIGEN_STRONG_INLINE Packet4d por<Packet4d>(const Packet4d& a, const Packet4d& b) { return _mm256_or_pd(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8i por<Packet8i>(const Packet8i& a, const Packet8i& b) {
+#ifdef EIGEN_VECTORIZE_AVX2
+ return _mm256_or_si256(a,b);
+#else
+ return _mm256_castps_si256(_mm256_or_ps(_mm256_castsi256_ps(a),_mm256_castsi256_ps(b)));
+#endif
+}
template<> EIGEN_STRONG_INLINE Packet8f pxor<Packet8f>(const Packet8f& a, const Packet8f& b) { return _mm256_xor_ps(a,b); }
template<> EIGEN_STRONG_INLINE Packet4d pxor<Packet4d>(const Packet4d& a, const Packet4d& b) { return _mm256_xor_pd(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8i pxor<Packet8i>(const Packet8i& a, const Packet8i& b) {
+#ifdef EIGEN_VECTORIZE_AVX2
+ return _mm256_xor_si256(a,b);
+#else
+ return _mm256_castps_si256(_mm256_xor_ps(_mm256_castsi256_ps(a),_mm256_castsi256_ps(b)));
+#endif
+}
-template<> EIGEN_STRONG_INLINE Packet8f pandnot<Packet8f>(const Packet8f& a, const Packet8f& b) { return _mm256_andnot_ps(a,b); }
-template<> EIGEN_STRONG_INLINE Packet4d pandnot<Packet4d>(const Packet4d& a, const Packet4d& b) { return _mm256_andnot_pd(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8f pandnot<Packet8f>(const Packet8f& a, const Packet8f& b) { return _mm256_andnot_ps(b,a); }
+template<> EIGEN_STRONG_INLINE Packet4d pandnot<Packet4d>(const Packet4d& a, const Packet4d& b) { return _mm256_andnot_pd(b,a); }
+template<> EIGEN_STRONG_INLINE Packet8i pandnot<Packet8i>(const Packet8i& a, const Packet8i& b) {
+#ifdef EIGEN_VECTORIZE_AVX2
+ return _mm256_andnot_si256(b,a);
+#else
+ return _mm256_castps_si256(_mm256_andnot_ps(_mm256_castsi256_ps(b),_mm256_castsi256_ps(a)));
+#endif
+}
+
+template<> EIGEN_STRONG_INLINE Packet8f pround<Packet8f>(const Packet8f& a)
+{
+ const Packet8f mask = pset1frombits<Packet8f>(static_cast<numext::uint32_t>(0x80000000u));
+ const Packet8f prev0dot5 = pset1frombits<Packet8f>(static_cast<numext::uint32_t>(0x3EFFFFFFu));
+ return _mm256_round_ps(padd(por(pand(a, mask), prev0dot5), a), _MM_FROUND_TO_ZERO);
+}
+template<> EIGEN_STRONG_INLINE Packet4d pround<Packet4d>(const Packet4d& a)
+{
+ const Packet4d mask = pset1frombits<Packet4d>(static_cast<numext::uint64_t>(0x8000000000000000ull));
+ const Packet4d prev0dot5 = pset1frombits<Packet4d>(static_cast<numext::uint64_t>(0x3FDFFFFFFFFFFFFFull));
+ return _mm256_round_pd(padd(por(pand(a, mask), prev0dot5), a), _MM_FROUND_TO_ZERO);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8f pselect<Packet8f>(const Packet8f& mask, const Packet8f& a, const Packet8f& b)
+{ return _mm256_blendv_ps(b,a,mask); }
+template<> EIGEN_STRONG_INLINE Packet4d pselect<Packet4d>(const Packet4d& mask, const Packet4d& a, const Packet4d& b)
+{ return _mm256_blendv_pd(b,a,mask); }
+
+template<int N> EIGEN_STRONG_INLINE Packet8i parithmetic_shift_right(Packet8i a) {
+#ifdef EIGEN_VECTORIZE_AVX2
+ return _mm256_srai_epi32(a, N);
+#else
+ __m128i lo = _mm_srai_epi32(_mm256_extractf128_si256(a, 0), N);
+ __m128i hi = _mm_srai_epi32(_mm256_extractf128_si256(a, 1), N);
+ return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1);
+#endif
+}
+
+template<int N> EIGEN_STRONG_INLINE Packet8i plogical_shift_right(Packet8i a) {
+#ifdef EIGEN_VECTORIZE_AVX2
+ return _mm256_srli_epi32(a, N);
+#else
+ __m128i lo = _mm_srli_epi32(_mm256_extractf128_si256(a, 0), N);
+ __m128i hi = _mm_srli_epi32(_mm256_extractf128_si256(a, 1), N);
+ return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1);
+#endif
+}
+
+template<int N> EIGEN_STRONG_INLINE Packet8i plogical_shift_left(Packet8i a) {
+#ifdef EIGEN_VECTORIZE_AVX2
+ return _mm256_slli_epi32(a, N);
+#else
+ __m128i lo = _mm_slli_epi32(_mm256_extractf128_si256(a, 0), N);
+ __m128i hi = _mm_slli_epi32(_mm256_extractf128_si256(a, 1), N);
+ return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1);
+#endif
+}
template<> EIGEN_STRONG_INLINE Packet8f pload<Packet8f>(const float* from) { EIGEN_DEBUG_ALIGNED_LOAD return _mm256_load_ps(from); }
template<> EIGEN_STRONG_INLINE Packet4d pload<Packet4d>(const double* from) { EIGEN_DEBUG_ALIGNED_LOAD return _mm256_load_pd(from); }
@@ -219,6 +582,14 @@
template<> EIGEN_STRONG_INLINE Packet4d ploadu<Packet4d>(const double* from) { EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_loadu_pd(from); }
template<> EIGEN_STRONG_INLINE Packet8i ploadu<Packet8i>(const int* from) { EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(from)); }
+template<> EIGEN_STRONG_INLINE Packet8f ploadu<Packet8f>(const float* from, uint8_t umask) {
+ Packet8i mask = _mm256_set1_epi8(static_cast<char>(umask));
+ const Packet8i bit_mask = _mm256_set_epi32(0xffffff7f, 0xffffffbf, 0xffffffdf, 0xffffffef, 0xfffffff7, 0xfffffffb, 0xfffffffd, 0xfffffffe);
+ mask = por<Packet8i>(mask, bit_mask);
+ mask = pcmp_eq<Packet8i>(mask, _mm256_set1_epi32(0xffffffff));
+ EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_maskload_ps(from, mask);
+}
+
// Loads 4 floats from memory a returns the packet {a0, a0 a1, a1, a2, a2, a3, a3}
template<> EIGEN_STRONG_INLINE Packet8f ploaddup<Packet8f>(const float* from)
{
@@ -226,7 +597,7 @@
// Packet8f tmp = _mm256_castps128_ps256(_mm_loadu_ps(from));
// tmp = _mm256_insertf128_ps(tmp, _mm_movehl_ps(_mm256_castps256_ps128(tmp),_mm256_castps256_ps128(tmp)), 1);
// return _mm256_unpacklo_ps(tmp,tmp);
-
+
// _mm256_insertf128_ps is very slow on Haswell, thus:
Packet8f tmp = _mm256_broadcast_ps((const __m128*)(const void*)from);
// mimic an "inplace" permutation of the lower 128bits using a blend
@@ -256,6 +627,14 @@
template<> EIGEN_STRONG_INLINE void pstoreu<double>(double* to, const Packet4d& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm256_storeu_pd(to, from); }
template<> EIGEN_STRONG_INLINE void pstoreu<int>(int* to, const Packet8i& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm256_storeu_si256(reinterpret_cast<__m256i*>(to), from); }
+template<> EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const Packet8f& from, uint8_t umask) {
+ Packet8i mask = _mm256_set1_epi8(static_cast<char>(umask));
+ const Packet8i bit_mask = _mm256_set_epi32(0xffffff7f, 0xffffffbf, 0xffffffdf, 0xffffffef, 0xfffffff7, 0xfffffffb, 0xfffffffd, 0xfffffffe);
+ mask = por<Packet8i>(mask, bit_mask);
+ mask = pcmp_eq<Packet8i>(mask, _mm256_set1_epi32(0xffffffff));
+ EIGEN_DEBUG_UNALIGNED_STORE return _mm256_maskstore_ps(to, mask, from);
+}
+
// NOTE: leverage _mm256_i32gather_ps and _mm256_i32gather_pd if AVX2 instructions are available
// NOTE: for the record the following seems to be slower: return _mm256_i32gather_ps(from, _mm256_set1_epi32(stride), 4);
template<> EIGEN_DEVICE_FUNC inline Packet8f pgather<float, Packet8f>(const float* from, Index stride)
@@ -354,47 +733,66 @@
return _mm256_and_pd(a,mask);
}
-// preduxp should be ok
-// FIXME: why is this ok? why isn't the simply implementation working as expected?
-template<> EIGEN_STRONG_INLINE Packet8f preduxp<Packet8f>(const Packet8f* vecs)
-{
- __m256 hsum1 = _mm256_hadd_ps(vecs[0], vecs[1]);
- __m256 hsum2 = _mm256_hadd_ps(vecs[2], vecs[3]);
- __m256 hsum3 = _mm256_hadd_ps(vecs[4], vecs[5]);
- __m256 hsum4 = _mm256_hadd_ps(vecs[6], vecs[7]);
-
- __m256 hsum5 = _mm256_hadd_ps(hsum1, hsum1);
- __m256 hsum6 = _mm256_hadd_ps(hsum2, hsum2);
- __m256 hsum7 = _mm256_hadd_ps(hsum3, hsum3);
- __m256 hsum8 = _mm256_hadd_ps(hsum4, hsum4);
-
- __m256 perm1 = _mm256_permute2f128_ps(hsum5, hsum5, 0x23);
- __m256 perm2 = _mm256_permute2f128_ps(hsum6, hsum6, 0x23);
- __m256 perm3 = _mm256_permute2f128_ps(hsum7, hsum7, 0x23);
- __m256 perm4 = _mm256_permute2f128_ps(hsum8, hsum8, 0x23);
-
- __m256 sum1 = _mm256_add_ps(perm1, hsum5);
- __m256 sum2 = _mm256_add_ps(perm2, hsum6);
- __m256 sum3 = _mm256_add_ps(perm3, hsum7);
- __m256 sum4 = _mm256_add_ps(perm4, hsum8);
-
- __m256 blend1 = _mm256_blend_ps(sum1, sum2, 0xcc);
- __m256 blend2 = _mm256_blend_ps(sum3, sum4, 0xcc);
-
- __m256 final = _mm256_blend_ps(blend1, blend2, 0xf0);
- return final;
+template<> EIGEN_STRONG_INLINE Packet8f pfrexp<Packet8f>(const Packet8f& a, Packet8f& exponent) {
+ return pfrexp_generic(a,exponent);
}
-template<> EIGEN_STRONG_INLINE Packet4d preduxp<Packet4d>(const Packet4d* vecs)
-{
- Packet4d tmp0, tmp1;
- tmp0 = _mm256_hadd_pd(vecs[0], vecs[1]);
- tmp0 = _mm256_add_pd(tmp0, _mm256_permute2f128_pd(tmp0, tmp0, 1));
+// Extract exponent without existence of Packet4l.
+template<>
+EIGEN_STRONG_INLINE
+Packet4d pfrexp_generic_get_biased_exponent(const Packet4d& a) {
+ const Packet4d cst_exp_mask = pset1frombits<Packet4d>(static_cast<uint64_t>(0x7ff0000000000000ull));
+ __m256i a_expo = _mm256_castpd_si256(pand(a, cst_exp_mask));
+#ifdef EIGEN_VECTORIZE_AVX2
+ a_expo = _mm256_srli_epi64(a_expo, 52);
+ __m128i lo = _mm256_extractf128_si256(a_expo, 0);
+ __m128i hi = _mm256_extractf128_si256(a_expo, 1);
+#else
+ __m128i lo = _mm256_extractf128_si256(a_expo, 0);
+ __m128i hi = _mm256_extractf128_si256(a_expo, 1);
+ lo = _mm_srli_epi64(lo, 52);
+ hi = _mm_srli_epi64(hi, 52);
+#endif
+ Packet2d exponent_lo = _mm_cvtepi32_pd(vec4i_swizzle1(lo, 0, 2, 1, 3));
+ Packet2d exponent_hi = _mm_cvtepi32_pd(vec4i_swizzle1(hi, 0, 2, 1, 3));
+ Packet4d exponent = _mm256_insertf128_pd(_mm256_setzero_pd(), exponent_lo, 0);
+ exponent = _mm256_insertf128_pd(exponent, exponent_hi, 1);
+ return exponent;
+}
- tmp1 = _mm256_hadd_pd(vecs[2], vecs[3]);
- tmp1 = _mm256_add_pd(tmp1, _mm256_permute2f128_pd(tmp1, tmp1, 1));
- return _mm256_blend_pd(tmp0, tmp1, 0xC);
+template<> EIGEN_STRONG_INLINE Packet4d pfrexp<Packet4d>(const Packet4d& a, Packet4d& exponent) {
+ return pfrexp_generic(a, exponent);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8f pldexp<Packet8f>(const Packet8f& a, const Packet8f& exponent) {
+ return pldexp_generic(a, exponent);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4d pldexp<Packet4d>(const Packet4d& a, const Packet4d& exponent) {
+ // Clamp exponent to [-2099, 2099]
+ const Packet4d max_exponent = pset1<Packet4d>(2099.0);
+ const Packet4i e = _mm256_cvtpd_epi32(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent));
+
+ // Split 2^e into four factors and multiply.
+ const Packet4i bias = pset1<Packet4i>(1023);
+ Packet4i b = parithmetic_shift_right<2>(e); // floor(e/4)
+
+ // 2^b
+ Packet4i hi = vec4i_swizzle1(padd(b, bias), 0, 2, 1, 3);
+ Packet4i lo = _mm_slli_epi64(hi, 52);
+ hi = _mm_slli_epi64(_mm_srli_epi64(hi, 32), 52);
+ Packet4d c = _mm256_castsi256_pd(_mm256_insertf128_si256(_mm256_castsi128_si256(lo), hi, 1));
+ Packet4d out = pmul(pmul(pmul(a, c), c), c); // a * 2^(3b)
+
+ // 2^(e - 3b)
+ b = psub(psub(psub(e, b), b), b); // e - 3b
+ hi = vec4i_swizzle1(padd(b, bias), 0, 2, 1, 3);
+ lo = _mm_slli_epi64(hi, 52);
+ hi = _mm_slli_epi64(_mm_srli_epi64(hi, 32), 52);
+ c = _mm256_castsi256_pd(_mm256_insertf128_si256(_mm256_castsi128_si256(lo), hi, 1));
+ out = pmul(out, c); // a * 2^e
+ return out;
}
template<> EIGEN_STRONG_INLINE float predux<Packet8f>(const Packet8f& a)
@@ -406,7 +804,7 @@
return predux(Packet2d(_mm_add_pd(_mm256_castpd256_pd128(a),_mm256_extractf128_pd(a,1))));
}
-template<> EIGEN_STRONG_INLINE Packet4f predux_downto4<Packet8f>(const Packet8f& a)
+template<> EIGEN_STRONG_INLINE Packet4f predux_half_dowto4<Packet8f>(const Packet8f& a)
{
return _mm_add_ps(_mm256_castps256_ps128(a),_mm256_extractf128_ps(a,1));
}
@@ -450,93 +848,16 @@
return pfirst(_mm256_max_pd(tmp, _mm256_shuffle_pd(tmp, tmp, 1)));
}
+// not needed yet
+// template<> EIGEN_STRONG_INLINE bool predux_all(const Packet8f& x)
+// {
+// return _mm256_movemask_ps(x)==0xFF;
+// }
-template<int Offset>
-struct palign_impl<Offset,Packet8f>
+template<> EIGEN_STRONG_INLINE bool predux_any(const Packet8f& x)
{
- static EIGEN_STRONG_INLINE void run(Packet8f& first, const Packet8f& second)
- {
- if (Offset==1)
- {
- first = _mm256_blend_ps(first, second, 1);
- Packet8f tmp1 = _mm256_permute_ps (first, _MM_SHUFFLE(0,3,2,1));
- Packet8f tmp2 = _mm256_permute2f128_ps (tmp1, tmp1, 1);
- first = _mm256_blend_ps(tmp1, tmp2, 0x88);
- }
- else if (Offset==2)
- {
- first = _mm256_blend_ps(first, second, 3);
- Packet8f tmp1 = _mm256_permute_ps (first, _MM_SHUFFLE(1,0,3,2));
- Packet8f tmp2 = _mm256_permute2f128_ps (tmp1, tmp1, 1);
- first = _mm256_blend_ps(tmp1, tmp2, 0xcc);
- }
- else if (Offset==3)
- {
- first = _mm256_blend_ps(first, second, 7);
- Packet8f tmp1 = _mm256_permute_ps (first, _MM_SHUFFLE(2,1,0,3));
- Packet8f tmp2 = _mm256_permute2f128_ps (tmp1, tmp1, 1);
- first = _mm256_blend_ps(tmp1, tmp2, 0xee);
- }
- else if (Offset==4)
- {
- first = _mm256_blend_ps(first, second, 15);
- Packet8f tmp1 = _mm256_permute_ps (first, _MM_SHUFFLE(3,2,1,0));
- Packet8f tmp2 = _mm256_permute2f128_ps (tmp1, tmp1, 1);
- first = _mm256_permute_ps(tmp2, _MM_SHUFFLE(3,2,1,0));
- }
- else if (Offset==5)
- {
- first = _mm256_blend_ps(first, second, 31);
- first = _mm256_permute2f128_ps(first, first, 1);
- Packet8f tmp = _mm256_permute_ps (first, _MM_SHUFFLE(0,3,2,1));
- first = _mm256_permute2f128_ps(tmp, tmp, 1);
- first = _mm256_blend_ps(tmp, first, 0x88);
- }
- else if (Offset==6)
- {
- first = _mm256_blend_ps(first, second, 63);
- first = _mm256_permute2f128_ps(first, first, 1);
- Packet8f tmp = _mm256_permute_ps (first, _MM_SHUFFLE(1,0,3,2));
- first = _mm256_permute2f128_ps(tmp, tmp, 1);
- first = _mm256_blend_ps(tmp, first, 0xcc);
- }
- else if (Offset==7)
- {
- first = _mm256_blend_ps(first, second, 127);
- first = _mm256_permute2f128_ps(first, first, 1);
- Packet8f tmp = _mm256_permute_ps (first, _MM_SHUFFLE(2,1,0,3));
- first = _mm256_permute2f128_ps(tmp, tmp, 1);
- first = _mm256_blend_ps(tmp, first, 0xee);
- }
- }
-};
-
-template<int Offset>
-struct palign_impl<Offset,Packet4d>
-{
- static EIGEN_STRONG_INLINE void run(Packet4d& first, const Packet4d& second)
- {
- if (Offset==1)
- {
- first = _mm256_blend_pd(first, second, 1);
- __m256d tmp = _mm256_permute_pd(first, 5);
- first = _mm256_permute2f128_pd(tmp, tmp, 1);
- first = _mm256_blend_pd(tmp, first, 0xA);
- }
- else if (Offset==2)
- {
- first = _mm256_blend_pd(first, second, 3);
- first = _mm256_permute2f128_pd(first, first, 1);
- }
- else if (Offset==3)
- {
- first = _mm256_blend_pd(first, second, 7);
- __m256d tmp = _mm256_permute_pd(first, 5);
- first = _mm256_permute2f128_pd(tmp, tmp, 1);
- first = _mm256_blend_pd(tmp, first, 5);
- }
- }
-};
+ return _mm256_movemask_ps(x)!=0;
+}
EIGEN_DEVICE_FUNC inline void
ptranspose(PacketBlock<Packet8f,8>& kernel) {
@@ -610,24 +931,640 @@
return _mm256_blendv_pd(thenPacket, elsePacket, false_mask);
}
-template<> EIGEN_STRONG_INLINE Packet8f pinsertfirst(const Packet8f& a, float b)
-{
- return _mm256_blend_ps(a,pset1<Packet8f>(b),1);
+// Packet math for Eigen::half
+
+template<> struct unpacket_traits<Packet8h> { typedef Eigen::half type; enum {size=8, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; typedef Packet8h half; };
+
+template<> EIGEN_STRONG_INLINE Packet8h pset1<Packet8h>(const Eigen::half& from) {
+ return _mm_set1_epi16(numext::bit_cast<numext::uint16_t>(from));
}
-template<> EIGEN_STRONG_INLINE Packet4d pinsertfirst(const Packet4d& a, double b)
-{
- return _mm256_blend_pd(a,pset1<Packet4d>(b),1);
+template<> EIGEN_STRONG_INLINE Eigen::half pfirst<Packet8h>(const Packet8h& from) {
+ return numext::bit_cast<Eigen::half>(static_cast<numext::uint16_t>(_mm_extract_epi16(from, 0)));
}
-template<> EIGEN_STRONG_INLINE Packet8f pinsertlast(const Packet8f& a, float b)
-{
- return _mm256_blend_ps(a,pset1<Packet8f>(b),(1<<7));
+template<> EIGEN_STRONG_INLINE Packet8h pload<Packet8h>(const Eigen::half* from) {
+ return _mm_load_si128(reinterpret_cast<const __m128i*>(from));
}
-template<> EIGEN_STRONG_INLINE Packet4d pinsertlast(const Packet4d& a, double b)
+template<> EIGEN_STRONG_INLINE Packet8h ploadu<Packet8h>(const Eigen::half* from) {
+ return _mm_loadu_si128(reinterpret_cast<const __m128i*>(from));
+}
+
+template<> EIGEN_STRONG_INLINE void pstore<Eigen::half>(Eigen::half* to, const Packet8h& from) {
+ _mm_store_si128(reinterpret_cast<__m128i*>(to), from);
+}
+
+template<> EIGEN_STRONG_INLINE void pstoreu<Eigen::half>(Eigen::half* to, const Packet8h& from) {
+ _mm_storeu_si128(reinterpret_cast<__m128i*>(to), from);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8h
+ploaddup<Packet8h>(const Eigen::half* from) {
+ const numext::uint16_t a = numext::bit_cast<numext::uint16_t>(from[0]);
+ const numext::uint16_t b = numext::bit_cast<numext::uint16_t>(from[1]);
+ const numext::uint16_t c = numext::bit_cast<numext::uint16_t>(from[2]);
+ const numext::uint16_t d = numext::bit_cast<numext::uint16_t>(from[3]);
+ return _mm_set_epi16(d, d, c, c, b, b, a, a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8h
+ploadquad<Packet8h>(const Eigen::half* from) {
+ const numext::uint16_t a = numext::bit_cast<numext::uint16_t>(from[0]);
+ const numext::uint16_t b = numext::bit_cast<numext::uint16_t>(from[1]);
+ return _mm_set_epi16(b, b, b, b, a, a, a, a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8h ptrue(const Packet8h& a) {
+ return _mm_cmpeq_epi32(a, a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8h pabs(const Packet8h& a) {
+ const __m128i sign_mask = _mm_set1_epi16(static_cast<numext::uint16_t>(0x8000));
+ return _mm_andnot_si128(sign_mask, a);
+}
+
+EIGEN_STRONG_INLINE Packet8f half2float(const Packet8h& a) {
+#ifdef EIGEN_HAS_FP16_C
+ return _mm256_cvtph_ps(a);
+#else
+ EIGEN_ALIGN32 Eigen::half aux[8];
+ pstore(aux, a);
+ float f0(aux[0]);
+ float f1(aux[1]);
+ float f2(aux[2]);
+ float f3(aux[3]);
+ float f4(aux[4]);
+ float f5(aux[5]);
+ float f6(aux[6]);
+ float f7(aux[7]);
+
+ return _mm256_set_ps(f7, f6, f5, f4, f3, f2, f1, f0);
+#endif
+}
+
+EIGEN_STRONG_INLINE Packet8h float2half(const Packet8f& a) {
+#ifdef EIGEN_HAS_FP16_C
+ return _mm256_cvtps_ph(a, _MM_FROUND_TO_NEAREST_INT|_MM_FROUND_NO_EXC);
+#else
+ EIGEN_ALIGN32 float aux[8];
+ pstore(aux, a);
+ const numext::uint16_t s0 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[0]));
+ const numext::uint16_t s1 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[1]));
+ const numext::uint16_t s2 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[2]));
+ const numext::uint16_t s3 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[3]));
+ const numext::uint16_t s4 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[4]));
+ const numext::uint16_t s5 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[5]));
+ const numext::uint16_t s6 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[6]));
+ const numext::uint16_t s7 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[7]));
+ return _mm_set_epi16(s7, s6, s5, s4, s3, s2, s1, s0);
+#endif
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8h pmin<Packet8h>(const Packet8h& a,
+ const Packet8h& b) {
+ return float2half(pmin<Packet8f>(half2float(a), half2float(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8h pmax<Packet8h>(const Packet8h& a,
+ const Packet8h& b) {
+ return float2half(pmax<Packet8f>(half2float(a), half2float(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8h plset<Packet8h>(const half& a) {
+ return float2half(plset<Packet8f>(static_cast<float>(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8h por(const Packet8h& a,const Packet8h& b) {
+ // in some cases Packet4i is a wrapper around __m128i, so we either need to
+ // cast to Packet4i to directly call the intrinsics as below:
+ return _mm_or_si128(a,b);
+}
+template<> EIGEN_STRONG_INLINE Packet8h pxor(const Packet8h& a,const Packet8h& b) {
+ return _mm_xor_si128(a,b);
+}
+template<> EIGEN_STRONG_INLINE Packet8h pand(const Packet8h& a,const Packet8h& b) {
+ return _mm_and_si128(a,b);
+}
+template<> EIGEN_STRONG_INLINE Packet8h pandnot(const Packet8h& a,const Packet8h& b) {
+ return _mm_andnot_si128(b,a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8h pselect(const Packet8h& mask, const Packet8h& a, const Packet8h& b) {
+ return _mm_blendv_epi8(b, a, mask);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8h pround<Packet8h>(const Packet8h& a) {
+ return float2half(pround<Packet8f>(half2float(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8h print<Packet8h>(const Packet8h& a) {
+ return float2half(print<Packet8f>(half2float(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8h pceil<Packet8h>(const Packet8h& a) {
+ return float2half(pceil<Packet8f>(half2float(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8h pfloor<Packet8h>(const Packet8h& a) {
+ return float2half(pfloor<Packet8f>(half2float(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8h pcmp_eq(const Packet8h& a,const Packet8h& b) {
+ return Pack16To8(pcmp_eq(half2float(a), half2float(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8h pcmp_le(const Packet8h& a,const Packet8h& b) {
+ return Pack16To8(pcmp_le(half2float(a), half2float(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8h pcmp_lt(const Packet8h& a,const Packet8h& b) {
+ return Pack16To8(pcmp_lt(half2float(a), half2float(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8h pcmp_lt_or_nan(const Packet8h& a,const Packet8h& b) {
+ return Pack16To8(pcmp_lt_or_nan(half2float(a), half2float(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8h pconj(const Packet8h& a) { return a; }
+
+template<> EIGEN_STRONG_INLINE Packet8h pnegate(const Packet8h& a) {
+ Packet8h sign_mask = _mm_set1_epi16(static_cast<numext::uint16_t>(0x8000));
+ return _mm_xor_si128(a, sign_mask);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8h padd<Packet8h>(const Packet8h& a, const Packet8h& b) {
+ Packet8f af = half2float(a);
+ Packet8f bf = half2float(b);
+ Packet8f rf = padd(af, bf);
+ return float2half(rf);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8h psub<Packet8h>(const Packet8h& a, const Packet8h& b) {
+ Packet8f af = half2float(a);
+ Packet8f bf = half2float(b);
+ Packet8f rf = psub(af, bf);
+ return float2half(rf);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8h pmul<Packet8h>(const Packet8h& a, const Packet8h& b) {
+ Packet8f af = half2float(a);
+ Packet8f bf = half2float(b);
+ Packet8f rf = pmul(af, bf);
+ return float2half(rf);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8h pdiv<Packet8h>(const Packet8h& a, const Packet8h& b) {
+ Packet8f af = half2float(a);
+ Packet8f bf = half2float(b);
+ Packet8f rf = pdiv(af, bf);
+ return float2half(rf);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8h pgather<Eigen::half, Packet8h>(const Eigen::half* from, Index stride)
{
- return _mm256_blend_pd(a,pset1<Packet4d>(b),(1<<3));
+ const numext::uint16_t s0 = numext::bit_cast<numext::uint16_t>(from[0*stride]);
+ const numext::uint16_t s1 = numext::bit_cast<numext::uint16_t>(from[1*stride]);
+ const numext::uint16_t s2 = numext::bit_cast<numext::uint16_t>(from[2*stride]);
+ const numext::uint16_t s3 = numext::bit_cast<numext::uint16_t>(from[3*stride]);
+ const numext::uint16_t s4 = numext::bit_cast<numext::uint16_t>(from[4*stride]);
+ const numext::uint16_t s5 = numext::bit_cast<numext::uint16_t>(from[5*stride]);
+ const numext::uint16_t s6 = numext::bit_cast<numext::uint16_t>(from[6*stride]);
+ const numext::uint16_t s7 = numext::bit_cast<numext::uint16_t>(from[7*stride]);
+ return _mm_set_epi16(s7, s6, s5, s4, s3, s2, s1, s0);
+}
+
+template<> EIGEN_STRONG_INLINE void pscatter<Eigen::half, Packet8h>(Eigen::half* to, const Packet8h& from, Index stride)
+{
+ EIGEN_ALIGN32 Eigen::half aux[8];
+ pstore(aux, from);
+ to[stride*0] = aux[0];
+ to[stride*1] = aux[1];
+ to[stride*2] = aux[2];
+ to[stride*3] = aux[3];
+ to[stride*4] = aux[4];
+ to[stride*5] = aux[5];
+ to[stride*6] = aux[6];
+ to[stride*7] = aux[7];
+}
+
+template<> EIGEN_STRONG_INLINE Eigen::half predux<Packet8h>(const Packet8h& a) {
+ Packet8f af = half2float(a);
+ float reduced = predux<Packet8f>(af);
+ return Eigen::half(reduced);
+}
+
+template<> EIGEN_STRONG_INLINE Eigen::half predux_max<Packet8h>(const Packet8h& a) {
+ Packet8f af = half2float(a);
+ float reduced = predux_max<Packet8f>(af);
+ return Eigen::half(reduced);
+}
+
+template<> EIGEN_STRONG_INLINE Eigen::half predux_min<Packet8h>(const Packet8h& a) {
+ Packet8f af = half2float(a);
+ float reduced = predux_min<Packet8f>(af);
+ return Eigen::half(reduced);
+}
+
+template<> EIGEN_STRONG_INLINE Eigen::half predux_mul<Packet8h>(const Packet8h& a) {
+ Packet8f af = half2float(a);
+ float reduced = predux_mul<Packet8f>(af);
+ return Eigen::half(reduced);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8h preverse(const Packet8h& a)
+{
+ __m128i m = _mm_setr_epi8(14,15,12,13,10,11,8,9,6,7,4,5,2,3,0,1);
+ return _mm_shuffle_epi8(a,m);
+}
+
+EIGEN_STRONG_INLINE void
+ptranspose(PacketBlock<Packet8h,8>& kernel) {
+ __m128i a = kernel.packet[0];
+ __m128i b = kernel.packet[1];
+ __m128i c = kernel.packet[2];
+ __m128i d = kernel.packet[3];
+ __m128i e = kernel.packet[4];
+ __m128i f = kernel.packet[5];
+ __m128i g = kernel.packet[6];
+ __m128i h = kernel.packet[7];
+
+ __m128i a03b03 = _mm_unpacklo_epi16(a, b);
+ __m128i c03d03 = _mm_unpacklo_epi16(c, d);
+ __m128i e03f03 = _mm_unpacklo_epi16(e, f);
+ __m128i g03h03 = _mm_unpacklo_epi16(g, h);
+ __m128i a47b47 = _mm_unpackhi_epi16(a, b);
+ __m128i c47d47 = _mm_unpackhi_epi16(c, d);
+ __m128i e47f47 = _mm_unpackhi_epi16(e, f);
+ __m128i g47h47 = _mm_unpackhi_epi16(g, h);
+
+ __m128i a01b01c01d01 = _mm_unpacklo_epi32(a03b03, c03d03);
+ __m128i a23b23c23d23 = _mm_unpackhi_epi32(a03b03, c03d03);
+ __m128i e01f01g01h01 = _mm_unpacklo_epi32(e03f03, g03h03);
+ __m128i e23f23g23h23 = _mm_unpackhi_epi32(e03f03, g03h03);
+ __m128i a45b45c45d45 = _mm_unpacklo_epi32(a47b47, c47d47);
+ __m128i a67b67c67d67 = _mm_unpackhi_epi32(a47b47, c47d47);
+ __m128i e45f45g45h45 = _mm_unpacklo_epi32(e47f47, g47h47);
+ __m128i e67f67g67h67 = _mm_unpackhi_epi32(e47f47, g47h47);
+
+ __m128i a0b0c0d0e0f0g0h0 = _mm_unpacklo_epi64(a01b01c01d01, e01f01g01h01);
+ __m128i a1b1c1d1e1f1g1h1 = _mm_unpackhi_epi64(a01b01c01d01, e01f01g01h01);
+ __m128i a2b2c2d2e2f2g2h2 = _mm_unpacklo_epi64(a23b23c23d23, e23f23g23h23);
+ __m128i a3b3c3d3e3f3g3h3 = _mm_unpackhi_epi64(a23b23c23d23, e23f23g23h23);
+ __m128i a4b4c4d4e4f4g4h4 = _mm_unpacklo_epi64(a45b45c45d45, e45f45g45h45);
+ __m128i a5b5c5d5e5f5g5h5 = _mm_unpackhi_epi64(a45b45c45d45, e45f45g45h45);
+ __m128i a6b6c6d6e6f6g6h6 = _mm_unpacklo_epi64(a67b67c67d67, e67f67g67h67);
+ __m128i a7b7c7d7e7f7g7h7 = _mm_unpackhi_epi64(a67b67c67d67, e67f67g67h67);
+
+ kernel.packet[0] = a0b0c0d0e0f0g0h0;
+ kernel.packet[1] = a1b1c1d1e1f1g1h1;
+ kernel.packet[2] = a2b2c2d2e2f2g2h2;
+ kernel.packet[3] = a3b3c3d3e3f3g3h3;
+ kernel.packet[4] = a4b4c4d4e4f4g4h4;
+ kernel.packet[5] = a5b5c5d5e5f5g5h5;
+ kernel.packet[6] = a6b6c6d6e6f6g6h6;
+ kernel.packet[7] = a7b7c7d7e7f7g7h7;
+}
+
+EIGEN_STRONG_INLINE void
+ptranspose(PacketBlock<Packet8h,4>& kernel) {
+ EIGEN_ALIGN32 Eigen::half in[4][8];
+ pstore<Eigen::half>(in[0], kernel.packet[0]);
+ pstore<Eigen::half>(in[1], kernel.packet[1]);
+ pstore<Eigen::half>(in[2], kernel.packet[2]);
+ pstore<Eigen::half>(in[3], kernel.packet[3]);
+
+ EIGEN_ALIGN32 Eigen::half out[4][8];
+
+ for (int i = 0; i < 4; ++i) {
+ for (int j = 0; j < 4; ++j) {
+ out[i][j] = in[j][2*i];
+ }
+ for (int j = 0; j < 4; ++j) {
+ out[i][j+4] = in[j][2*i+1];
+ }
+ }
+
+ kernel.packet[0] = pload<Packet8h>(out[0]);
+ kernel.packet[1] = pload<Packet8h>(out[1]);
+ kernel.packet[2] = pload<Packet8h>(out[2]);
+ kernel.packet[3] = pload<Packet8h>(out[3]);
+}
+
+// BFloat16 implementation.
+
+EIGEN_STRONG_INLINE Packet8f Bf16ToF32(const Packet8bf& a) {
+#ifdef EIGEN_VECTORIZE_AVX2
+ __m256i extend = _mm256_cvtepu16_epi32(a);
+ return _mm256_castsi256_ps(_mm256_slli_epi32(extend, 16));
+#else
+ __m128i lo = _mm_cvtepu16_epi32(a);
+ __m128i hi = _mm_cvtepu16_epi32(_mm_srli_si128(a, 8));
+ __m128i lo_shift = _mm_slli_epi32(lo, 16);
+ __m128i hi_shift = _mm_slli_epi32(hi, 16);
+ return _mm256_castsi256_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(lo_shift), hi_shift, 1));
+#endif
+}
+
+// Convert float to bfloat16 according to round-to-nearest-even/denormals algorithm.
+EIGEN_STRONG_INLINE Packet8bf F32ToBf16(const Packet8f& a) {
+ Packet8bf r;
+
+ __m256i input = _mm256_castps_si256(a);
+
+#ifdef EIGEN_VECTORIZE_AVX2
+ // uint32_t lsb = (input >> 16);
+ __m256i t = _mm256_srli_epi32(input, 16);
+ // uint32_t lsb = lsb & 1;
+ t = _mm256_and_si256(t, _mm256_set1_epi32(1));
+ // uint32_t rounding_bias = 0x7fff + lsb;
+ t = _mm256_add_epi32(t, _mm256_set1_epi32(0x7fff));
+ // input += rounding_bias;
+ t = _mm256_add_epi32(t, input);
+ // input = input >> 16;
+ t = _mm256_srli_epi32(t, 16);
+ // Check NaN before converting back to bf16
+ __m256 mask = _mm256_cmp_ps(a, a, _CMP_ORD_Q);
+ __m256i nan = _mm256_set1_epi32(0x7fc0);
+ t = _mm256_blendv_epi8(nan, t, _mm256_castps_si256(mask));
+ // output = numext::bit_cast<uint16_t>(input);
+ return _mm_packus_epi32(_mm256_extractf128_si256(t, 0),
+ _mm256_extractf128_si256(t, 1));
+#else
+ // uint32_t lsb = (input >> 16);
+ __m128i lo = _mm_srli_epi32(_mm256_extractf128_si256(input, 0), 16);
+ __m128i hi = _mm_srli_epi32(_mm256_extractf128_si256(input, 1), 16);
+ // uint32_t lsb = lsb & 1;
+ lo = _mm_and_si128(lo, _mm_set1_epi32(1));
+ hi = _mm_and_si128(hi, _mm_set1_epi32(1));
+ // uint32_t rounding_bias = 0x7fff + lsb;
+ lo = _mm_add_epi32(lo, _mm_set1_epi32(0x7fff));
+ hi = _mm_add_epi32(hi, _mm_set1_epi32(0x7fff));
+ // input += rounding_bias;
+ lo = _mm_add_epi32(lo, _mm256_extractf128_si256(input, 0));
+ hi = _mm_add_epi32(hi, _mm256_extractf128_si256(input, 1));
+ // input = input >> 16;
+ lo = _mm_srli_epi32(lo, 16);
+ hi = _mm_srli_epi32(hi, 16);
+ // Check NaN before converting back to bf16
+ __m256 mask = _mm256_cmp_ps(a, a, _CMP_ORD_Q);
+ __m128i nan = _mm_set1_epi32(0x7fc0);
+ lo = _mm_blendv_epi8(nan, lo, _mm_castps_si128(_mm256_castps256_ps128(mask)));
+ hi = _mm_blendv_epi8(nan, hi, _mm_castps_si128(_mm256_extractf128_ps(mask, 1)));
+ // output = numext::bit_cast<uint16_t>(input);
+ return _mm_packus_epi32(lo, hi);
+#endif
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pset1<Packet8bf>(const bfloat16& from) {
+ return _mm_set1_epi16(numext::bit_cast<numext::uint16_t>(from));
+}
+
+template<> EIGEN_STRONG_INLINE bfloat16 pfirst<Packet8bf>(const Packet8bf& from) {
+ return numext::bit_cast<bfloat16>(static_cast<numext::uint16_t>(_mm_extract_epi16(from, 0)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pload<Packet8bf>(const bfloat16* from) {
+ return _mm_load_si128(reinterpret_cast<const __m128i*>(from));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf ploadu<Packet8bf>(const bfloat16* from) {
+ return _mm_loadu_si128(reinterpret_cast<const __m128i*>(from));
+}
+
+template<> EIGEN_STRONG_INLINE void pstore<bfloat16>(bfloat16* to, const Packet8bf& from) {
+ _mm_store_si128(reinterpret_cast<__m128i*>(to), from);
+}
+
+template<> EIGEN_STRONG_INLINE void pstoreu<bfloat16>(bfloat16* to, const Packet8bf& from) {
+ _mm_storeu_si128(reinterpret_cast<__m128i*>(to), from);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf
+ploaddup<Packet8bf>(const bfloat16* from) {
+ const numext::uint16_t a = numext::bit_cast<numext::uint16_t>(from[0]);
+ const numext::uint16_t b = numext::bit_cast<numext::uint16_t>(from[1]);
+ const numext::uint16_t c = numext::bit_cast<numext::uint16_t>(from[2]);
+ const numext::uint16_t d = numext::bit_cast<numext::uint16_t>(from[3]);
+ return _mm_set_epi16(d, d, c, c, b, b, a, a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf
+ploadquad<Packet8bf>(const bfloat16* from) {
+ const numext::uint16_t a = numext::bit_cast<numext::uint16_t>(from[0]);
+ const numext::uint16_t b = numext::bit_cast<numext::uint16_t>(from[1]);
+ return _mm_set_epi16(b, b, b, b, a, a, a, a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf ptrue(const Packet8bf& a) {
+ return _mm_cmpeq_epi32(a, a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8bf pabs(const Packet8bf& a) {
+ const __m128i sign_mask = _mm_set1_epi16(static_cast<numext::uint16_t>(0x8000));
+ return _mm_andnot_si128(sign_mask, a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8bf pmin<Packet8bf>(const Packet8bf& a,
+ const Packet8bf& b) {
+ return F32ToBf16(pmin<Packet8f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8bf pmax<Packet8bf>(const Packet8bf& a,
+ const Packet8bf& b) {
+ return F32ToBf16(pmax<Packet8f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8bf plset<Packet8bf>(const bfloat16& a) {
+ return F32ToBf16(plset<Packet8f>(static_cast<float>(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf por(const Packet8bf& a,const Packet8bf& b) {
+ return _mm_or_si128(a,b);
+}
+template<> EIGEN_STRONG_INLINE Packet8bf pxor(const Packet8bf& a,const Packet8bf& b) {
+ return _mm_xor_si128(a,b);
+}
+template<> EIGEN_STRONG_INLINE Packet8bf pand(const Packet8bf& a,const Packet8bf& b) {
+ return _mm_and_si128(a,b);
+}
+template<> EIGEN_STRONG_INLINE Packet8bf pandnot(const Packet8bf& a,const Packet8bf& b) {
+ return _mm_andnot_si128(b,a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pselect(const Packet8bf& mask, const Packet8bf& a, const Packet8bf& b) {
+ return _mm_blendv_epi8(b, a, mask);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pround<Packet8bf>(const Packet8bf& a)
+{
+ return F32ToBf16(pround<Packet8f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf print<Packet8bf>(const Packet8bf& a) {
+ return F32ToBf16(print<Packet8f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pceil<Packet8bf>(const Packet8bf& a) {
+ return F32ToBf16(pceil<Packet8f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pfloor<Packet8bf>(const Packet8bf& a) {
+ return F32ToBf16(pfloor<Packet8f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pcmp_eq(const Packet8bf& a,const Packet8bf& b) {
+ return Pack16To8(pcmp_eq(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pcmp_le(const Packet8bf& a,const Packet8bf& b) {
+ return Pack16To8(pcmp_le(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pcmp_lt(const Packet8bf& a,const Packet8bf& b) {
+ return Pack16To8(pcmp_lt(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pcmp_lt_or_nan(const Packet8bf& a,const Packet8bf& b) {
+ return Pack16To8(pcmp_lt_or_nan(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pconj(const Packet8bf& a) { return a; }
+
+template<> EIGEN_STRONG_INLINE Packet8bf pnegate(const Packet8bf& a) {
+ Packet8bf sign_mask = _mm_set1_epi16(static_cast<numext::uint16_t>(0x8000));
+ return _mm_xor_si128(a, sign_mask);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf padd<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
+ return F32ToBf16(padd<Packet8f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf psub<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
+ return F32ToBf16(psub<Packet8f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pmul<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
+ return F32ToBf16(pmul<Packet8f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pdiv<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
+ return F32ToBf16(pdiv<Packet8f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+
+template<> EIGEN_STRONG_INLINE Packet8bf pgather<bfloat16, Packet8bf>(const bfloat16* from, Index stride)
+{
+ const numext::uint16_t s0 = numext::bit_cast<numext::uint16_t>(from[0*stride]);
+ const numext::uint16_t s1 = numext::bit_cast<numext::uint16_t>(from[1*stride]);
+ const numext::uint16_t s2 = numext::bit_cast<numext::uint16_t>(from[2*stride]);
+ const numext::uint16_t s3 = numext::bit_cast<numext::uint16_t>(from[3*stride]);
+ const numext::uint16_t s4 = numext::bit_cast<numext::uint16_t>(from[4*stride]);
+ const numext::uint16_t s5 = numext::bit_cast<numext::uint16_t>(from[5*stride]);
+ const numext::uint16_t s6 = numext::bit_cast<numext::uint16_t>(from[6*stride]);
+ const numext::uint16_t s7 = numext::bit_cast<numext::uint16_t>(from[7*stride]);
+ return _mm_set_epi16(s7, s6, s5, s4, s3, s2, s1, s0);
+}
+
+template<> EIGEN_STRONG_INLINE void pscatter<bfloat16, Packet8bf>(bfloat16* to, const Packet8bf& from, Index stride)
+{
+ EIGEN_ALIGN32 bfloat16 aux[8];
+ pstore(aux, from);
+ to[stride*0] = aux[0];
+ to[stride*1] = aux[1];
+ to[stride*2] = aux[2];
+ to[stride*3] = aux[3];
+ to[stride*4] = aux[4];
+ to[stride*5] = aux[5];
+ to[stride*6] = aux[6];
+ to[stride*7] = aux[7];
+}
+
+template<> EIGEN_STRONG_INLINE bfloat16 predux<Packet8bf>(const Packet8bf& a) {
+ return static_cast<bfloat16>(predux<Packet8f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE bfloat16 predux_max<Packet8bf>(const Packet8bf& a) {
+ return static_cast<bfloat16>(predux_max<Packet8f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE bfloat16 predux_min<Packet8bf>(const Packet8bf& a) {
+ return static_cast<bfloat16>(predux_min<Packet8f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE bfloat16 predux_mul<Packet8bf>(const Packet8bf& a) {
+ return static_cast<bfloat16>(predux_mul<Packet8f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf preverse(const Packet8bf& a)
+{
+ __m128i m = _mm_setr_epi8(14,15,12,13,10,11,8,9,6,7,4,5,2,3,0,1);
+ return _mm_shuffle_epi8(a,m);
+}
+
+EIGEN_STRONG_INLINE void
+ptranspose(PacketBlock<Packet8bf,8>& kernel) {
+ __m128i a = kernel.packet[0];
+ __m128i b = kernel.packet[1];
+ __m128i c = kernel.packet[2];
+ __m128i d = kernel.packet[3];
+ __m128i e = kernel.packet[4];
+ __m128i f = kernel.packet[5];
+ __m128i g = kernel.packet[6];
+ __m128i h = kernel.packet[7];
+
+ __m128i a03b03 = _mm_unpacklo_epi16(a, b);
+ __m128i c03d03 = _mm_unpacklo_epi16(c, d);
+ __m128i e03f03 = _mm_unpacklo_epi16(e, f);
+ __m128i g03h03 = _mm_unpacklo_epi16(g, h);
+ __m128i a47b47 = _mm_unpackhi_epi16(a, b);
+ __m128i c47d47 = _mm_unpackhi_epi16(c, d);
+ __m128i e47f47 = _mm_unpackhi_epi16(e, f);
+ __m128i g47h47 = _mm_unpackhi_epi16(g, h);
+
+ __m128i a01b01c01d01 = _mm_unpacklo_epi32(a03b03, c03d03);
+ __m128i a23b23c23d23 = _mm_unpackhi_epi32(a03b03, c03d03);
+ __m128i e01f01g01h01 = _mm_unpacklo_epi32(e03f03, g03h03);
+ __m128i e23f23g23h23 = _mm_unpackhi_epi32(e03f03, g03h03);
+ __m128i a45b45c45d45 = _mm_unpacklo_epi32(a47b47, c47d47);
+ __m128i a67b67c67d67 = _mm_unpackhi_epi32(a47b47, c47d47);
+ __m128i e45f45g45h45 = _mm_unpacklo_epi32(e47f47, g47h47);
+ __m128i e67f67g67h67 = _mm_unpackhi_epi32(e47f47, g47h47);
+
+ kernel.packet[0] = _mm_unpacklo_epi64(a01b01c01d01, e01f01g01h01);
+ kernel.packet[1] = _mm_unpackhi_epi64(a01b01c01d01, e01f01g01h01);
+ kernel.packet[2] = _mm_unpacklo_epi64(a23b23c23d23, e23f23g23h23);
+ kernel.packet[3] = _mm_unpackhi_epi64(a23b23c23d23, e23f23g23h23);
+ kernel.packet[4] = _mm_unpacklo_epi64(a45b45c45d45, e45f45g45h45);
+ kernel.packet[5] = _mm_unpackhi_epi64(a45b45c45d45, e45f45g45h45);
+ kernel.packet[6] = _mm_unpacklo_epi64(a67b67c67d67, e67f67g67h67);
+ kernel.packet[7] = _mm_unpackhi_epi64(a67b67c67d67, e67f67g67h67);
+}
+
+EIGEN_STRONG_INLINE void
+ptranspose(PacketBlock<Packet8bf,4>& kernel) {
+ __m128i a = kernel.packet[0];
+ __m128i b = kernel.packet[1];
+ __m128i c = kernel.packet[2];
+ __m128i d = kernel.packet[3];
+
+ __m128i ab_03 = _mm_unpacklo_epi16(a, b);
+ __m128i cd_03 = _mm_unpacklo_epi16(c, d);
+ __m128i ab_47 = _mm_unpackhi_epi16(a, b);
+ __m128i cd_47 = _mm_unpackhi_epi16(c, d);
+
+ kernel.packet[0] = _mm_unpacklo_epi32(ab_03, cd_03);
+ kernel.packet[1] = _mm_unpackhi_epi32(ab_03, cd_03);
+ kernel.packet[2] = _mm_unpacklo_epi32(ab_47, cd_47);
+ kernel.packet[3] = _mm_unpackhi_epi32(ab_47, cd_47);
}
} // end namespace internal
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/AVX/TypeCasting.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/AVX/TypeCasting.h
index 83bfdc6..d507fb6 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/AVX/TypeCasting.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/AVX/TypeCasting.h
@@ -35,15 +35,79 @@
};
+#ifndef EIGEN_VECTORIZE_AVX512
+
+template <>
+struct type_casting_traits<Eigen::half, float> {
+ enum {
+ VectorizedCast = 1,
+ SrcCoeffRatio = 1,
+ TgtCoeffRatio = 1
+ };
+};
+
+
+template <>
+struct type_casting_traits<float, Eigen::half> {
+ enum {
+ VectorizedCast = 1,
+ SrcCoeffRatio = 1,
+ TgtCoeffRatio = 1
+ };
+};
+
+template <>
+struct type_casting_traits<bfloat16, float> {
+ enum {
+ VectorizedCast = 1,
+ SrcCoeffRatio = 1,
+ TgtCoeffRatio = 1
+ };
+};
+
+template <>
+struct type_casting_traits<float, bfloat16> {
+ enum {
+ VectorizedCast = 1,
+ SrcCoeffRatio = 1,
+ TgtCoeffRatio = 1
+ };
+};
+
+#endif // EIGEN_VECTORIZE_AVX512
template<> EIGEN_STRONG_INLINE Packet8i pcast<Packet8f, Packet8i>(const Packet8f& a) {
- return _mm256_cvtps_epi32(a);
+ return _mm256_cvttps_epi32(a);
}
template<> EIGEN_STRONG_INLINE Packet8f pcast<Packet8i, Packet8f>(const Packet8i& a) {
return _mm256_cvtepi32_ps(a);
}
+template<> EIGEN_STRONG_INLINE Packet8i preinterpret<Packet8i,Packet8f>(const Packet8f& a) {
+ return _mm256_castps_si256(a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8f preinterpret<Packet8f,Packet8i>(const Packet8i& a) {
+ return _mm256_castsi256_ps(a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8f pcast<Packet8h, Packet8f>(const Packet8h& a) {
+ return half2float(a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8f pcast<Packet8bf, Packet8f>(const Packet8bf& a) {
+ return Bf16ToF32(a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8h pcast<Packet8f, Packet8h>(const Packet8f& a) {
+ return float2half(a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pcast<Packet8f, Packet8bf>(const Packet8f& a) {
+ return F32ToBf16(a);
+}
+
} // end namespace internal
} // end namespace Eigen
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/Default/BFloat16.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/Default/BFloat16.h
new file mode 100644
index 0000000..1c28f4f
--- /dev/null
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/Default/BFloat16.h
@@ -0,0 +1,700 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef EIGEN_BFLOAT16_H
+#define EIGEN_BFLOAT16_H
+
+#define BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, METHOD) /* METHOD<PACKET_BF16>: widen to PACKET_F, apply METHOD<PACKET_F>, narrow back */ \
+ template <> \
+ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED \
+ PACKET_BF16 METHOD<PACKET_BF16>(const PACKET_BF16& _x) { \
+ return F32ToBf16(METHOD<PACKET_F>(Bf16ToF32(_x))); \
+ }
+
+namespace Eigen {
+
+struct bfloat16;
+
+namespace bfloat16_impl {
+
+// Eigen's own raw storage type: 'value' holds the 16-bit bfloat16 bit pattern.
+struct __bfloat16_raw {
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw() : value(0) {}
+ explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw(unsigned short raw) : value(raw) {}
+ unsigned short value; // raw bits; 0 (the default) is +0.0
+};
+// Conversion helpers, declared here and defined later in this namespace.
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw raw_uint16_to_bfloat16(unsigned short value);
+template <bool AssumeArgumentIsNormalOrInfinityOrZero>
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne(float ff);
+// Forward declarations of template specializations, to avoid Visual C++ 2019 errors, saying:
+// > error C2908: explicit specialization; 'float_to_bfloat16_rtne' has already been instantiated
+template <>
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<false>(float ff);
+template <>
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<true>(float ff);
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float bfloat16_to_float(__bfloat16_raw h);
+// Base of bfloat16 holding the raw bits; implicitly constructible from __bfloat16_raw.
+struct bfloat16_base : public __bfloat16_raw {
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16_base() {}
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16_base(const __bfloat16_raw& h) : __bfloat16_raw(h) {}
+};
+
+} // namespace bfloat16_impl
+
+// bfloat16: sign, 8 exponent bits, 7 fraction bits - the upper 16 bits of a float.
+struct bfloat16 : public bfloat16_impl::bfloat16_base {
+ // The raw bits live in bfloat16_base::value; value conversions go through float.
+ typedef bfloat16_impl::__bfloat16_raw __bfloat16_raw;
+ // Default: +0.0 (raw bits 0).
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16() {}
+ // Implicitly wraps an existing raw bit pattern (no numeric conversion).
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(const __bfloat16_raw& h) : bfloat16_impl::bfloat16_base(h) {}
+ // true -> 1.0 (0x3f80), false -> +0.0.
+ explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(bool b)
+ : bfloat16_impl::bfloat16_base(bfloat16_impl::raw_uint16_to_bfloat16(b ? 0x3f80 : 0)) {}
+ // Generic T via static_cast<float>; integral T skips the NaN check (see float_to_bfloat16_rtne<true>).
+ template<class T>
+ explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(T val)
+ : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne<internal::is_integral<T>::value>(static_cast<float>(val))) {}
+ // float: round to nearest even; NaN becomes a quiet NaN.
+ explicit EIGEN_DEVICE_FUNC bfloat16(float f)
+ : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne<false>(f)) {}
+
+ // Following the convention of numpy, converting between complex and
+ // float will lead to loss of imag value.
+ template<typename RealScalar>
+ explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(const std::complex<RealScalar>& val)
+ : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne<false>(static_cast<float>(val.real()))) {}
+
+ EIGEN_DEVICE_FUNC operator float() const { // NOLINT: Allow implicit conversion to float, because it is lossless.
+ return bfloat16_impl::bfloat16_to_float(*this);
+ }
+};
+} // namespace Eigen
+
+namespace std {
+template<>
+struct numeric_limits<Eigen::bfloat16> { // float's exponent range with an 8-bit significand
+ static const bool is_specialized = true;
+ static const bool is_signed = true;
+ static const bool is_integer = false;
+ static const bool is_exact = false;
+ static const bool has_infinity = true;
+ static const bool has_quiet_NaN = true;
+ static const bool has_signaling_NaN = true;
+ static const float_denorm_style has_denorm = std::denorm_present; // raw 0x0001..0x007f are denormals (see denorm_min)
+ static const bool has_denorm_loss = false;
+ static const std::float_round_style round_style = numeric_limits<float>::round_style;
+ static const bool is_iec559 = false;
+ static const bool is_bounded = true;
+ static const bool is_modulo = false;
+ static const int digits = 8;
+ static const int digits10 = 2;
+ static const int max_digits10 = 4;
+ static const int radix = 2;
+ static const int min_exponent = numeric_limits<float>::min_exponent;
+ static const int min_exponent10 = numeric_limits<float>::min_exponent10;
+ static const int max_exponent = numeric_limits<float>::max_exponent;
+ static const int max_exponent10 = numeric_limits<float>::max_exponent10;
+ static const bool traps = numeric_limits<float>::traps;
+ static const bool tinyness_before = numeric_limits<float>::tinyness_before;
+ // Special values as raw bit patterns: sign (1) | exponent (8) | fraction (7).
+ static Eigen::bfloat16 (min)() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x0080); }
+ static Eigen::bfloat16 lowest() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0xff7f); }
+ static Eigen::bfloat16 (max)() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f7f); }
+ static Eigen::bfloat16 epsilon() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x3c00); }
+ static Eigen::bfloat16 round_error() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x3f00); } // 0.5; bfloat16(0x3f00) would convert the int 16128
+ static Eigen::bfloat16 infinity() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f80); }
+ static Eigen::bfloat16 quiet_NaN() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7fc0); }
+ static Eigen::bfloat16 signaling_NaN() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f81); }
+ static Eigen::bfloat16 denorm_min() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x0001); }
+};
+
+// If std::numeric_limits<T> is specialized, should also specialize
+// std::numeric_limits<const T>, std::numeric_limits<volatile T>, and
+// std::numeric_limits<const volatile T>
+// https://stackoverflow.com/a/16519653/
+template<>
+struct numeric_limits<const Eigen::bfloat16> : numeric_limits<Eigen::bfloat16> {};
+template<>
+struct numeric_limits<volatile Eigen::bfloat16> : numeric_limits<Eigen::bfloat16> {};
+template<>
+struct numeric_limits<const volatile Eigen::bfloat16> : numeric_limits<Eigen::bfloat16> {};
+} // namespace std
+
+namespace Eigen {
+
+namespace bfloat16_impl {
+
+// We need to distinguish 'clang as the CUDA compiler' from 'clang as the host compiler,
+// invoked by NVCC' (e.g. on MacOS). The former needs to see both host and device implementation
+// of the functions, while the latter can only deal with one of them.
+#if !defined(EIGEN_HAS_NATIVE_BF16) || (EIGEN_COMP_CLANG && !EIGEN_COMP_NVCC) // Emulate support for bfloat16 floats
+
+#if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC)
+// We need to provide emulated *host-side* BF16 operators for clang.
+#pragma push_macro("EIGEN_DEVICE_FUNC")
+#undef EIGEN_DEVICE_FUNC
+#if defined(EIGEN_HAS_CUDA_BF16) && defined(EIGEN_HAS_NATIVE_BF16)
+#define EIGEN_DEVICE_FUNC __host__
+#else // both host and device need emulated ops.
+#define EIGEN_DEVICE_FUNC __host__ __device__
+#endif
+#endif
+
+// Emulated operators: arithmetic is evaluated in float and rounded back to
+// bfloat16; unary minus only flips the sign bit; comparisons compare as float.
+
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator + (const bfloat16& a, const bfloat16& b) {
+ return bfloat16(float(a) + float(b));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator + (const bfloat16& a, const int& b) {
+ return bfloat16(float(a) + static_cast<float>(b));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator + (const int& a, const bfloat16& b) {
+ return bfloat16(static_cast<float>(a) + float(b));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator * (const bfloat16& a, const bfloat16& b) {
+ return bfloat16(float(a) * float(b));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator - (const bfloat16& a, const bfloat16& b) {
+ return bfloat16(float(a) - float(b));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator / (const bfloat16& a, const bfloat16& b) {
+ return bfloat16(float(a) / float(b));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator - (const bfloat16& a) {
+ bfloat16 result;
+ result.value = a.value ^ 0x8000;
+ return result;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator += (bfloat16& a, const bfloat16& b) {
+ a = bfloat16(float(a) + float(b));
+ return a;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator *= (bfloat16& a, const bfloat16& b) {
+ a = bfloat16(float(a) * float(b));
+ return a;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator -= (bfloat16& a, const bfloat16& b) {
+ a = bfloat16(float(a) - float(b));
+ return a;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator /= (bfloat16& a, const bfloat16& b) {
+ a = bfloat16(float(a) / float(b));
+ return a;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator++(bfloat16& a) { // NOTE(review): returns a copy, not bfloat16&, unlike built-in prefix ++
+ a += bfloat16(1);
+ return a;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator--(bfloat16& a) { // NOTE(review): returns a copy, not bfloat16&, unlike built-in prefix --
+ a -= bfloat16(1);
+ return a;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator++(bfloat16& a, int) {
+ bfloat16 original_value = a;
+ ++a;
+ return original_value;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator--(bfloat16& a, int) {
+ bfloat16 original_value = a;
+ --a;
+ return original_value;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator == (const bfloat16& a, const bfloat16& b) { // compared as float values
+ return numext::equal_strict(float(a),float(b));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator != (const bfloat16& a, const bfloat16& b) {
+ return numext::not_equal_strict(float(a), float(b));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator < (const bfloat16& a, const bfloat16& b) {
+ return float(a) < float(b);
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator <= (const bfloat16& a, const bfloat16& b) {
+ return float(a) <= float(b);
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator > (const bfloat16& a, const bfloat16& b) {
+ return float(a) > float(b);
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator >= (const bfloat16& a, const bfloat16& b) {
+ return float(a) >= float(b);
+}
+
+#if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC)
+#pragma pop_macro("EIGEN_DEVICE_FUNC")
+#endif
+#endif // Emulate support for bfloat16 floats
+
+// Division by an Index, computed in float: converting the integer denominator to
+// bfloat16 first would round it to 8 significant bits and lose accuracy.
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator / (const bfloat16& a, Index b) {
+ return bfloat16(static_cast<float>(a) / static_cast<float>(b));
+}
+
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw truncate_to_bfloat16(const float v) { // rounds toward zero
+ __bfloat16_raw output;
+ if (Eigen::numext::isnan EIGEN_NOT_A_MACRO(v)) {
+ output.value = std::signbit(v) ? 0xFFC0: 0x7FC0; // quiet NaN, sign preserved
+ return output;
+ }
+ // Keep the upper 16 bits of the float (sign, exponent, top 7 fraction bits)
+ // and drop the rest. numext::bit_cast avoids the strict-aliasing UB of
+ // reading the float through a uint16_t pointer and does not depend on byte
+ // order, so no __BYTE_ORDER__ branch is needed.
+ const numext::uint32_t bits = numext::bit_cast<numext::uint32_t>(v);
+ output.value = static_cast<numext::uint16_t>(bits >> 16);
+ return output;
+}
+
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw raw_uint16_to_bfloat16(numext::uint16_t value) { // wraps bits, no conversion
+ return __bfloat16_raw(value);
+}
+// Inverse of raw_uint16_to_bfloat16: returns the stored 16 bits unchanged.
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR numext::uint16_t raw_bfloat16_as_uint16(const __bfloat16_raw& bf) {
+ return bf.value;
+}
+
+// float_to_bfloat16_rtne specialization that makes no assumption about its
+// argument (ff): a NaN becomes a quiet NaN, anything else rounds to nearest even.
+template <>
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<false>(float ff) {
+#if (defined(EIGEN_HAS_CUDA_BF16) && defined(EIGEN_HAS_HIP_BF16))
+ // NOTE(review): no return statement on this path; presumably never compiled - confirm.
+#else
+ __bfloat16_raw output;
+
+ if (Eigen::numext::isnan EIGEN_NOT_A_MACRO(ff)) {
+ // If the value is a NaN, squash it to a qNaN with msb of fraction set,
+ // this makes sure after truncation we don't end up with an inf.
+ //
+ // qNaN magic: All exponent bits set + most significant bit of fraction
+ // set.
+ output.value = std::signbit(ff) ? 0xFFC0: 0x7FC0;
+ } else {
+ // Fast rounding algorithm that rounds a half value to nearest even. This
+ // reduces expected error when we convert a large number of floats. Here
+ // is how it works:
+ //
+ // Definitions:
+ // To convert a float 32 to bfloat16, a float 32 can be viewed as 32 bits
+ // with the following tags:
+ //
+ // Sign | Exp (8 bits) | Frac (23 bits)
+ // S EEEEEEEE FFFFFFLRTTTTTTTTTTTTTTT
+ //
+ // S: Sign bit.
+ // E: Exponent bits.
+ // F: First 6 bits of fraction.
+ // L: Least significant bit of resulting bfloat16 if we truncate away the
+ // rest of the float32. This is also the 7th bit of fraction
+ // R: Rounding bit, 8th bit of fraction.
+ // T: Sticky bits, rest of fraction, 15 bits.
+ //
+ // To round half to nearest even, there are 3 cases where we want to round
+ // down (simply truncate the result of the bits away, which consists of
+ // rounding bit and sticky bits) and two cases where we want to round up
+ // (truncate then add one to the result).
+ //
+ // The fast converting algorithm simply adds lsb (L) to 0x7fff (15 bits of
+ // 1s) as the rounding bias, adds the rounding bias to the input, then
+ // truncates the last 16 bits away.
+ //
+ // To understand how it works, we can analyze this algorithm case by case:
+ //
+ // 1. L = 0, R = 0:
+ // Expect: round down, this is less than half value.
+ //
+ // Algorithm:
+ // - Rounding bias: 0x7fff + 0 = 0x7fff
+ // - Adding rounding bias to input may carry into R (never into L),
+ // depending on whether any of the T bits is 1.
+ // - R may be set to 1 if there is a carry.
+ // - L remains 0.
+ // - Note that this case also handles Inf and -Inf, where all fraction
+ // bits, including L, R and Ts are all 0. The output remains Inf after
+ // this algorithm.
+ //
+ // 2. L = 1, R = 0:
+ // Expect: round down, this is less than half value.
+ //
+ // Algorithm:
+ // - Rounding bias: 0x7fff + 1 = 0x8000
+ // - Adding rounding bias to input doesn't change sticky bits but
+ // adds 1 to rounding bit.
+ // - L remains 1.
+ //
+ // 3. L = 0, R = 1, all of T are 0:
+ // Expect: round down, this is exactly at half, the result is already
+ // even (L=0).
+ //
+ // Algorithm:
+ // - Rounding bias: 0x7fff + 0 = 0x7fff
+ // - Adding rounding bias to input sets all sticky bits to 1, but
+ // doesn't create a carry.
+ // - R remains 1.
+ // - L remains 0.
+ //
+ // 4. L = 1, R = 1:
+ // Expect: round up, this is exactly at half, the result needs to be
+ // rounded to the next even number.
+ //
+ // Algorithm:
+ // - Rounding bias: 0x7fff + 1 = 0x8000
+ // - Adding rounding bias to input doesn't change sticky bits, but
+ // creates a carry from rounding bit.
+ // - The carry sets L to 0, creates another carry bit and propagates
+ // forward to F bits.
+ // - If all the F bits are 1, a carry then propagates to the exponent
+ // bits, which then creates the minimum value with the next exponent
+ // value. Note that we won't have the case where exponents are all 1,
+ // since that's either a NaN (handled in the other if condition) or inf
+ // (handled in case 1).
+ //
+ // 5. L = 0, R = 1, any of T is 1:
+ // Expect: round up, this is greater than half.
+ //
+ // Algorithm:
+ // - Rounding bias: 0x7fff + 0 = 0x7fff
+ // - Adding rounding bias to input creates a carry from sticky bits,
+ // sets rounding bit to 0, then creates another carry.
+ // - The second carry sets L to 1.
+ //
+ // Examples:
+ //
+ // Exact half value that is already even:
+ // Input:
+ // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
+ // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
+ // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1000000000000000
+ //
+ // This falls into case 3. We truncate the rest of 16 bits and no
+ // carry is created into F and L:
+ //
+ // Output:
+ // Sign | Exp (8 bit) | Frac (first 7 bit)
+ // S E E E E E E E E F F F F F F L
+ // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0
+ //
+ // Exact half value, round to next even number:
+ // Input:
+ // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
+ // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
+ // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1000000000000000
+ //
+ // This falls into case 4. We create a carry from R and T,
+ // which then propagates into L and F:
+ //
+ // Output:
+ // Sign | Exp (8 bit) | Frac (first 7 bit)
+ // S E E E E E E E E F F F F F F L
+ // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0
+ //
+ //
+ // Max denormal value round to min normal value:
+ // Input:
+ // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
+ // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
+ // 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1111111111111111
+ //
+ // This falls into case 4. We create a carry from R and T,
+ // propagate into L and F, which then propagates into exponent
+ // bits:
+ //
+ // Output:
+ // Sign | Exp (8 bit) | Frac (first 7 bit)
+ // S E E E E E E E E F F F F F F L
+ // 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0
+ //
+ // Max normal value round to Inf:
+ // Input:
+ // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
+ // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
+ // 0 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1111111111111111
+ //
+ // This falls into case 4. We create a carry from R and T,
+ // propagate into L and F, which then propagates into exponent
+ // bits:
+ //
+ // Sign | Exp (8 bit) | Frac (first 7 bit)
+ // S E E E E E E E E F F F F F F L
+ // 0 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0
+
+ // ff is not NaN here (zero, denormal, normal or +/-inf); the rounding below handles all.
+ output = float_to_bfloat16_rtne<true>(ff);
+ }
+ return output;
+#endif
+}
+
+// float_to_bfloat16_rtne template specialization that assumes that its function
+// argument (ff) is not NaN (zero, denormal, normal or +/-infinity). Skipping the
+// NaN check improves the runtime performance of conversion from an integer
+// type to bfloat16.
+template <>
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<true>(float ff) {
+#if (defined(EIGEN_HAS_CUDA_BF16) && defined(EIGEN_HAS_HIP_BF16))
+ // NOTE(review): no return statement on this path; presumably never compiled - confirm.
+#else
+ numext::uint32_t input = numext::bit_cast<numext::uint32_t>(ff);
+ __bfloat16_raw output;
+
+ // Least significant bit of resulting bfloat.
+ numext::uint32_t lsb = (input >> 16) & 1;
+ numext::uint32_t rounding_bias = 0x7fff + lsb; // 0x7fff, or 0x8000 when L is set: round half to even
+ input += rounding_bias;
+ output.value = static_cast<numext::uint16_t>(input >> 16); // keep the upper 16 bits
+ return output;
+#endif
+}
+
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float bfloat16_to_float(__bfloat16_raw h) {
+ // Widening is exact: a bfloat16 is the upper 16 bits of the float with the
+ // same value (sign, 8 exponent bits, top 7 fraction bits), so place the bits
+ // in the high half of a 32-bit pattern with zero low fraction bits. This holds
+ // for zeros, denormals, infinities and NaN patterns alike.
+ // numext::bit_cast avoids the strict-aliasing UB of writing into a float
+ // through an unsigned short pointer, and does not depend on byte order.
+ const numext::uint32_t bits = static_cast<numext::uint32_t>(h.value) << 16;
+ return numext::bit_cast<float>(bits);
+}
+// --- standard functions ---
+// Classification via float: the widening conversion is exact, so the category is preserved.
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isinf)(const bfloat16& a) {
+ EIGEN_USING_STD(isinf);
+ return (isinf)(float(a));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isnan)(const bfloat16& a) {
+ EIGEN_USING_STD(isnan);
+ return (isnan)(float(a));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isfinite)(const bfloat16& a) {
+ return !(isinf EIGEN_NOT_A_MACRO (a)) && !(isnan EIGEN_NOT_A_MACRO (a));
+}
+
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 abs(const bfloat16& a) { // clears the sign bit (bit 15)
+ bfloat16 result;
+ result.value = a.value & 0x7FFF;
+ return result;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 exp(const bfloat16& a) { // exp..fmod below: computed in float, rounded to bfloat16
+ return bfloat16(::expf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 expm1(const bfloat16& a) {
+ return bfloat16(numext::expm1(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log(const bfloat16& a) {
+ return bfloat16(::logf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log1p(const bfloat16& a) {
+ return bfloat16(numext::log1p(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log10(const bfloat16& a) {
+ return bfloat16(::log10f(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log2(const bfloat16& a) { // log2(x) = log2(e) * ln(x)
+ return bfloat16(static_cast<float>(EIGEN_LOG2E) * ::logf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sqrt(const bfloat16& a) {
+ return bfloat16(::sqrtf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 pow(const bfloat16& a, const bfloat16& b) {
+ return bfloat16(::powf(float(a), float(b)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sin(const bfloat16& a) {
+ return bfloat16(::sinf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 cos(const bfloat16& a) {
+ return bfloat16(::cosf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 tan(const bfloat16& a) {
+ return bfloat16(::tanf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 asin(const bfloat16& a) {
+ return bfloat16(::asinf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 acos(const bfloat16& a) {
+ return bfloat16(::acosf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atan(const bfloat16& a) {
+ return bfloat16(::atanf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sinh(const bfloat16& a) {
+ return bfloat16(::sinhf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 cosh(const bfloat16& a) {
+ return bfloat16(::coshf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 tanh(const bfloat16& a) {
+ return bfloat16(::tanhf(float(a)));
+}
+#if EIGEN_HAS_CXX11_MATH // ::asinhf/::acoshf/::atanhf are C99/C++11 functions
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 asinh(const bfloat16& a) {
+ return bfloat16(::asinhf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 acosh(const bfloat16& a) {
+ return bfloat16(::acoshf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atanh(const bfloat16& a) {
+ return bfloat16(::atanhf(float(a)));
+}
+#endif
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 floor(const bfloat16& a) {
+ return bfloat16(::floorf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 ceil(const bfloat16& a) {
+ return bfloat16(::ceilf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 rint(const bfloat16& a) {
+ return bfloat16(::rintf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 round(const bfloat16& a) {
+ return bfloat16(::roundf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmod(const bfloat16& a, const bfloat16& b) {
+ return bfloat16(::fmodf(float(a), float(b)));
+}
+// min/max compare as float and return one of the arguments unchanged (a on ties or NaN).
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 (min)(const bfloat16& a, const bfloat16& b) {
+ const float f1 = static_cast<float>(a);
+ const float f2 = static_cast<float>(b);
+ return f2 < f1 ? b : a;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 (max)(const bfloat16& a, const bfloat16& b) {
+ const float f1 = static_cast<float>(a);
+ const float f2 = static_cast<float>(b);
+ return f1 < f2 ? b : a;
+}
+// fmin/fmax follow ::fminf/::fmaxf: a NaN operand yields the other operand.
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmin(const bfloat16& a, const bfloat16& b) {
+ const float f1 = static_cast<float>(a);
+ const float f2 = static_cast<float>(b);
+ return bfloat16(::fminf(f1, f2));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmax(const bfloat16& a, const bfloat16& b) {
+ const float f1 = static_cast<float>(a);
+ const float f2 = static_cast<float>(b);
+ return bfloat16(::fmaxf(f1, f2));
+}
+// Stream output prints the value as a float.
+#ifndef EIGEN_NO_IO
+EIGEN_ALWAYS_INLINE std::ostream& operator << (std::ostream& os, const bfloat16& v) {
+ os << static_cast<float>(v);
+ return os;
+}
+#endif
+
+} // namespace bfloat16_impl
+
+namespace internal {
+// Random bfloat16 values: x + (y-x)*u with u = rand()/RAND_MAX in [0, 1].
+template<>
+struct random_default_impl<bfloat16, false, false>
+{
+ static inline bfloat16 run(const bfloat16& x, const bfloat16& y) // draws from std::rand()
+ {
+ return x + (y-x) * bfloat16(float(std::rand()) / float(RAND_MAX));
+ }
+ static inline bfloat16 run() // default range [-1, 1]
+ {
+ return run(bfloat16(-1.f), bfloat16(1.f));
+ }
+};
+// Marks bfloat16 as arithmetic for Eigen's internal::is_arithmetic trait.
+template<> struct is_arithmetic<bfloat16> { enum { value = true }; };
+
+} // namespace internal
+
+template<> struct NumTraits<Eigen::bfloat16>
+ : GenericNumTraits<Eigen::bfloat16>
+{
+ enum {
+ IsSigned = true,
+ IsInteger = false,
+ IsComplex = false,
+ RequireInitialization = false
+ };
+ // Constants given as raw bit patterns so they stay EIGEN_CONSTEXPR.
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 epsilon() {
+ return bfloat16_impl::raw_uint16_to_bfloat16(0x3c00); // 2^-7 = 0.0078125
+ }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 dummy_precision() {
+ return bfloat16_impl::raw_uint16_to_bfloat16(0x3D4D); // bfloat16(5e-2f);
+
+ }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 highest() {
+ return bfloat16_impl::raw_uint16_to_bfloat16(0x7F7F); // largest finite, about 3.39e38
+ }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 lowest() {
+ return bfloat16_impl::raw_uint16_to_bfloat16(0xFF7F); // -highest()
+ }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 infinity() {
+ return bfloat16_impl::raw_uint16_to_bfloat16(0x7f80); // +inf
+ }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 quiet_NaN() {
+ return bfloat16_impl::raw_uint16_to_bfloat16(0x7fc0); // positive quiet NaN
+ }
+};
+
+} // namespace Eigen
+
+namespace Eigen {
+namespace numext {
+// numext overloads for bfloat16: classification forwards to bfloat16_impl; bit_cast maps raw bits.
+template<>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
+bool (isnan)(const Eigen::bfloat16& h) {
+ return (bfloat16_impl::isnan)(h);
+}
+
+template<>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
+bool (isinf)(const Eigen::bfloat16& h) {
+ return (bfloat16_impl::isinf)(h);
+}
+
+template<>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
+bool (isfinite)(const Eigen::bfloat16& h) {
+ return (bfloat16_impl::isfinite)(h);
+}
+
+template <>
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bit_cast<Eigen::bfloat16, uint16_t>(const uint16_t& src) { // raw bits -> bfloat16
+ return Eigen::bfloat16(Eigen::bfloat16_impl::raw_uint16_to_bfloat16(src));
+}
+
+template <>
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint16_t bit_cast<uint16_t, Eigen::bfloat16>(const Eigen::bfloat16& src) { // bfloat16 -> raw bits
+ return Eigen::bfloat16_impl::raw_bfloat16_as_uint16(src);
+}
+
+} // namespace numext
+} // namespace Eigen
+
+#if EIGEN_HAS_STD_HASH
+namespace std {
+template <>
+struct hash<Eigen::bfloat16> { // -0 == +0, so both must hash alike: zero is canonicalized to raw 0
+ EIGEN_STRONG_INLINE std::size_t operator()(const Eigen::bfloat16& a) const {
+ return static_cast<std::size_t>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(a == Eigen::bfloat16() ? Eigen::bfloat16() : a));
+ }
+};
+} // namespace std
+#endif
+
+
+#endif // EIGEN_BFLOAT16_H
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/Default/ConjHelper.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/Default/ConjHelper.h
index 4cfe34e..53830b5 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/Default/ConjHelper.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/Default/ConjHelper.h
@@ -11,19 +11,107 @@
#ifndef EIGEN_ARCH_CONJ_HELPER_H
#define EIGEN_ARCH_CONJ_HELPER_H
-#define EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(PACKET_CPLX, PACKET_REAL) \
- template<> struct conj_helper<PACKET_REAL, PACKET_CPLX, false,false> { \
- EIGEN_STRONG_INLINE PACKET_CPLX pmadd(const PACKET_REAL& x, const PACKET_CPLX& y, const PACKET_CPLX& c) const \
- { return padd(c, pmul(x,y)); } \
- EIGEN_STRONG_INLINE PACKET_CPLX pmul(const PACKET_REAL& x, const PACKET_CPLX& y) const \
- { return PACKET_CPLX(Eigen::internal::pmul<PACKET_REAL>(x, y.v)); } \
- }; \
- \
- template<> struct conj_helper<PACKET_CPLX, PACKET_REAL, false,false> { \
- EIGEN_STRONG_INLINE PACKET_CPLX pmadd(const PACKET_CPLX& x, const PACKET_REAL& y, const PACKET_CPLX& c) const \
- { return padd(c, pmul(x,y)); } \
- EIGEN_STRONG_INLINE PACKET_CPLX pmul(const PACKET_CPLX& x, const PACKET_REAL& y) const \
- { return PACKET_CPLX(Eigen::internal::pmul<PACKET_REAL>(x.v, y)); } \
+#define EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(PACKET_CPLX, PACKET_REAL) \
+ template <> /* real packet * complex packet */ \
+ struct conj_helper<PACKET_REAL, PACKET_CPLX, false, false> { \
+ EIGEN_STRONG_INLINE PACKET_CPLX pmadd(const PACKET_REAL& x, \
+ const PACKET_CPLX& y, \
+ const PACKET_CPLX& c) const { \
+ return padd(c, this->pmul(x, y)); /* c + x*y via the member pmul */ \
+ } \
+ EIGEN_STRONG_INLINE PACKET_CPLX pmul(const PACKET_REAL& x, \
+ const PACKET_CPLX& y) const { \
+ return PACKET_CPLX(Eigen::internal::pmul<PACKET_REAL>(x, y.v)); /* lane-wise on y's .v storage */ \
+ } \
+ }; \
+ \
+ template <> /* complex packet * real packet */ \
+ struct conj_helper<PACKET_CPLX, PACKET_REAL, false, false> { \
+ EIGEN_STRONG_INLINE PACKET_CPLX pmadd(const PACKET_CPLX& x, \
+ const PACKET_REAL& y, \
+ const PACKET_CPLX& c) const { \
+ return padd(c, this->pmul(x, y)); /* c + x*y via the member pmul */ \
+ } \
+ EIGEN_STRONG_INLINE PACKET_CPLX pmul(const PACKET_CPLX& x, \
+ const PACKET_REAL& y) const { \
+ return PACKET_CPLX(Eigen::internal::pmul<PACKET_REAL>(x.v, y)); /* lane-wise on x's .v storage */ \
+ } \
 };
-#endif // EIGEN_ARCH_CONJ_HELPER_H
+namespace Eigen {
+namespace internal {
+// conj_if<true> conjugates (numext::conj for scalars, pconj for packets); conj_if<false> passes values through.
+template<bool Conjugate> struct conj_if;
+
+template<> struct conj_if<true> {
+ template<typename T>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x) const { return numext::conj(x); }
+ template<typename T>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T pconj(const T& x) const { return internal::pconj(x); }
+};
+
+template<> struct conj_if<false> {
+ template<typename T>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T& operator()(const T& x) const { return x; }
+ template<typename T>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T& pconj(const T& x) const { return x; }
+};
+
+// Generic implementation for mixed Lhs/Rhs types; it assumes scalars, since the
+// equal-type (packet) case is specialized below.
+template<typename LhsType, typename RhsType, bool ConjLhs, bool ConjRhs>
+struct conj_helper {
+ typedef typename ScalarBinaryOpTraits<LhsType, RhsType>::ReturnType ResultType;
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ResultType
+ pmadd(const LhsType& x, const RhsType& y, const ResultType& c) const
+ { return this->pmul(x, y) + c; }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ResultType
+ pmul(const LhsType& x, const RhsType& y) const
+ { return conj_if<ConjLhs>()(x) * conj_if<ConjRhs>()(y); }
+};
+// Scalar case with both operands conjugated.
+template<typename LhsScalar, typename RhsScalar>
+struct conj_helper<LhsScalar, RhsScalar, true, true> {
+ typedef typename ScalarBinaryOpTraits<LhsScalar,RhsScalar>::ReturnType ResultType;
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ResultType
+ pmadd(const LhsScalar& x, const RhsScalar& y, const ResultType& c) const
+ { return this->pmul(x, y) + c; }
+
+ // We save a conjugation by using the identity conj(a)*conj(b) = conj(a*b).
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ResultType
+ pmul(const LhsScalar& x, const RhsScalar& y) const
+ { return numext::conj(x * y); }
+};
+
+// Equal Lhs/Rhs types: use the packet primitives (pmadd, pmul, pconj).
+template<typename Packet, bool ConjLhs, bool ConjRhs>
+struct conj_helper<Packet, Packet, ConjLhs, ConjRhs>
+{
+ typedef Packet ResultType;
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet pmadd(const Packet& x, const Packet& y, const Packet& c) const
+ { return Eigen::internal::pmadd(conj_if<ConjLhs>().pconj(x), conj_if<ConjRhs>().pconj(y), c); }
+
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet pmul(const Packet& x, const Packet& y) const
+ { return Eigen::internal::pmul(conj_if<ConjLhs>().pconj(x), conj_if<ConjRhs>().pconj(y)); }
+};
+// Equal-type case with both operands conjugated.
+template<typename Packet>
+struct conj_helper<Packet, Packet, true, true>
+{
+ typedef Packet ResultType;
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet pmadd(const Packet& x, const Packet& y, const Packet& c) const
+ { return Eigen::internal::pmadd(pconj(x), pconj(y), c); }
+ // We save a conjugation by using the identity conj(a)*conj(b) = conj(a*b).
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet pmul(const Packet& x, const Packet& y) const
+ { return pconj(Eigen::internal::pmul(x, y)); }
+};
+
+} // namespace internal
+} // namespace Eigen
+
+#endif // EIGEN_ARCH_CONJ_HELPER_H
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
new file mode 100644
index 0000000..c9fbaf6
--- /dev/null
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
@@ -0,0 +1,1649 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2007 Julien Pommier
+// Copyright (C) 2014 Pedro Gonnet (pedro.gonnet@gmail.com)
+// Copyright (C) 2009-2019 Gael Guennebaud <gael.guennebaud@inria.fr>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+/* The exp and log functions of this file initially come from
+ * Julien Pommier's sse math library: http://gruntthepeon.free.fr/ssemath/
+ */
+
+#ifndef EIGEN_ARCH_GENERIC_PACKET_MATH_FUNCTIONS_H
+#define EIGEN_ARCH_GENERIC_PACKET_MATH_FUNCTIONS_H
+
+namespace Eigen {
+namespace internal {
+
+// Maps a floating-point Scalar to the signed integer type of the same bit-width.
+template<typename T> struct make_integer;
+template<> struct make_integer<float> { typedef numext::int32_t type; };
+template<> struct make_integer<double> { typedef numext::int64_t type; };
+template<> struct make_integer<half> { typedef numext::int16_t type; };
+template<> struct make_integer<bfloat16> { typedef numext::int16_t type; };
+
+template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
+Packet pfrexp_generic_get_biased_exponent(const Packet& a) { // Returns the biased IEEE exponent field of |a|, converted back to Packet.
+ typedef typename unpacket_traits<Packet>::type Scalar;
+ typedef typename unpacket_traits<Packet>::integer_packet PacketI;
+ enum { mantissa_bits = numext::numeric_limits<Scalar>::digits - 1};
+ return pcast<PacketI, Packet>(plogical_shift_right<mantissa_bits>(preinterpret<PacketI>(pabs(a)))); // pabs clears the sign bit; the shift drops the mantissa bits
+}
+
+// Safely applies frexp, correctly handles denormals. Returns m with |m| in [0.5, 1) and the sign of 'a',
+// and sets 'exponent' so that a = m * 2^exponent. Assumes IEEE floating point format.
+template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
+Packet pfrexp_generic(const Packet& a, Packet& exponent) {
+ typedef typename unpacket_traits<Packet>::type Scalar;
+ typedef typename make_unsigned<typename make_integer<Scalar>::type>::type ScalarUI;
+ enum {
+ TotalBits = sizeof(Scalar) * CHAR_BIT,
+ MantissaBits = numext::numeric_limits<Scalar>::digits - 1,
+ ExponentBits = int(TotalBits) - int(MantissaBits) - 1
+ };
+
+ EIGEN_CONSTEXPR ScalarUI scalar_sign_mantissa_mask =
+ ~(((ScalarUI(1) << int(ExponentBits)) - ScalarUI(1)) << int(MantissaBits)); // ~0x7f800000 for float
+ const Packet sign_mantissa_mask = pset1frombits<Packet>(static_cast<ScalarUI>(scalar_sign_mantissa_mask));
+ const Packet half = pset1<Packet>(Scalar(0.5));
+ const Packet zero = pzero(a);
+ const Packet normal_min = pset1<Packet>((numext::numeric_limits<Scalar>::min)()); // Minimum normal value, 2^-126 for float
+
+ // To handle denormals, normalize by multiplying by 2^(int(MantissaBits)+1).
+ const Packet is_denormal = pcmp_lt(pabs(a), normal_min);
+ EIGEN_CONSTEXPR ScalarUI scalar_normalization_offset = ScalarUI(int(MantissaBits) + 1); // 24 for float
+ // The following cannot be constexpr because bfloat16(uint16_t) is not constexpr.
+ const Scalar scalar_normalization_factor = Scalar(ScalarUI(1) << int(scalar_normalization_offset)); // 2^24 for float
+ const Packet normalization_factor = pset1<Packet>(scalar_normalization_factor);
+ const Packet normalized_a = pselect(is_denormal, pmul(a, normalization_factor), a);
+
+ // Determine exponent offset (float values): -126 if normal, -126-24 if denormal
+ const Scalar scalar_exponent_offset = -Scalar((ScalarUI(1)<<(int(ExponentBits)-1)) - ScalarUI(2)); // -126 for float
+ Packet exponent_offset = pset1<Packet>(scalar_exponent_offset);
+ const Packet normalization_offset = pset1<Packet>(-Scalar(scalar_normalization_offset)); // -24 for float
+ exponent_offset = pselect(is_denormal, padd(exponent_offset, normalization_offset), exponent_offset);
+
+ // Determine exponent and mantissa from normalized_a.
+ exponent = pfrexp_generic_get_biased_exponent(normalized_a);
+ // Zero, Inf and NaN return 'a' unmodified, exponent is zero
+ // (technically the exponent is unspecified for inf/NaN, but GCC/Clang set it to zero)
+ const Scalar scalar_non_finite_exponent = Scalar((ScalarUI(1) << int(ExponentBits)) - ScalarUI(1)); // 255 for float
+ const Packet non_finite_exponent = pset1<Packet>(scalar_non_finite_exponent);
+ const Packet is_zero_or_not_finite = por(pcmp_eq(a, zero), pcmp_eq(exponent, non_finite_exponent));
+ const Packet m = pselect(is_zero_or_not_finite, a, por(pand(normalized_a, sign_mantissa_mask), half)); // keep sign+mantissa, use the exponent of 0.5
+ exponent = pselect(is_zero_or_not_finite, zero, padd(exponent, exponent_offset));
+ return m;
+}
+
+// Safely applies ldexp, correctly handles overflows, underflows and denormals.
+// Assumes IEEE floating point format.
+template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
+Packet pldexp_generic(const Packet& a, const Packet& exponent) {
+ // We want to return a * 2^exponent, allowing for all possible integer
+ // exponents without overflowing or underflowing in intermediate
+ // computations.
+ //
+ // Since 'a' and the output can be denormal, the maximum range of 'exponent'
+ // to consider for a float is:
+ // -255-23 -> 255+23
+ // Below -278 any finite float 'a' will become zero, and above +278 any
+ // finite float will become inf, including when 'a' is the smallest possible
+ // denormal.
+ //
+ // Unfortunately, 2^(278) cannot be represented using either one or two
+ // finite normal floats, so we must split the scale factor into at least
+ // three parts. It turns out to be faster to split 'exponent' into four
+ // factors, since [exponent>>2] is much faster to compute than [exponent/3].
+ //
+ // Set e = min(max(exponent, -278), 278);
+ // b = floor(e/4);
+ // out = ((((a * 2^(b)) * 2^(b)) * 2^(b)) * 2^(e-3*b))
+ //
+ // This will avoid any intermediate overflows and correctly handle 0, inf,
+ // NaN cases.
+ typedef typename unpacket_traits<Packet>::integer_packet PacketI;
+ typedef typename unpacket_traits<Packet>::type Scalar;
+ typedef typename unpacket_traits<PacketI>::type ScalarI;
+ enum {
+ TotalBits = sizeof(Scalar) * CHAR_BIT,
+ MantissaBits = numext::numeric_limits<Scalar>::digits - 1,
+ ExponentBits = int(TotalBits) - int(MantissaBits) - 1
+ };
+
+ const Packet max_exponent = pset1<Packet>(Scalar((ScalarI(1)<<int(ExponentBits)) + ScalarI(int(MantissaBits) - 1))); // 278 for float
+ const PacketI bias = pset1<PacketI>((ScalarI(1)<<(int(ExponentBits)-1)) - ScalarI(1)); // 127 for float
+ const PacketI e = pcast<Packet, PacketI>(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent));
+ PacketI b = parithmetic_shift_right<2>(e); // floor(e/4);
+ Packet c = preinterpret<Packet>(plogical_shift_left<int(MantissaBits)>(padd(b, bias))); // 2^b
+ Packet out = pmul(pmul(pmul(a, c), c), c); // a * 2^(3b)
+ b = psub(psub(psub(e, b), b), b); // e - 3b
+ c = preinterpret<Packet>(plogical_shift_left<int(MantissaBits)>(padd(b, bias))); // 2^(e-3*b)
+ out = pmul(out, c);
+ return out;
+}
+
+// Explicitly multiplies
+// a * (2^e)
+// clamping e to the range
+// [NumTraits<Scalar>::min_exponent()-2, NumTraits<Scalar>::max_exponent()]
+//
+// This is approx 7x faster than pldexp_generic, but will prematurely over/underflow
+// if 2^e doesn't fit into a normal floating-point Scalar.
+//
+// Assumes IEEE floating point format
+template<typename Packet>
+struct pldexp_fast_impl {
+ typedef typename unpacket_traits<Packet>::integer_packet PacketI;
+ typedef typename unpacket_traits<Packet>::type Scalar;
+ typedef typename unpacket_traits<PacketI>::type ScalarI;
+ enum {
+ TotalBits = sizeof(Scalar) * CHAR_BIT,
+ MantissaBits = numext::numeric_limits<Scalar>::digits - 1,
+ ExponentBits = int(TotalBits) - int(MantissaBits) - 1
+ };
+
+ static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
+ Packet run(const Packet& a, const Packet& exponent) {
+ const Packet bias = pset1<Packet>(Scalar((ScalarI(1)<<(int(ExponentBits)-1)) - ScalarI(1))); // 127 for float
+ const Packet limit = pset1<Packet>(Scalar((ScalarI(1)<<int(ExponentBits)) - ScalarI(1))); // 255 for float
+ // restrict biased exponent between 0 and 255 for float (biased 0 encodes 0.0, biased 255 encodes inf).
+ const PacketI e = pcast<Packet, PacketI>(pmin(pmax(padd(exponent, bias), pzero(limit)), limit)); // exponent + 127
+ // return a * (2^e), building 2^e by writing the biased exponent directly into the exponent field
+ return pmul(a, preinterpret<Packet>(plogical_shift_left<int(MantissaBits)>(e)));
+ }
+};
+
+// Natural or base 2 logarithm.
+// Computes log(x) as log(2^e * m) = C*e + log(m), where the constant C = log(2)
+// and m is in the range [sqrt(1/2),sqrt(2)). In this range, the logarithm can
+// be easily approximated by a polynomial centered on m=1 for stability.
+// TODO(gonnet): Further reduce the interval allowing for lower-degree
+// polynomial interpolants -> ... -> profit!
+template <typename Packet, bool base2>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet plog_impl_float(const Packet _x)
+{
+ Packet x = _x;
+
+ const Packet cst_1 = pset1<Packet>(1.0f);
+ const Packet cst_neg_half = pset1<Packet>(-0.5f);
+ // The smallest non denormalized float number.
+ const Packet cst_min_norm_pos = pset1frombits<Packet>( 0x00800000u);
+ const Packet cst_minus_inf = pset1frombits<Packet>( 0xff800000u);
+ const Packet cst_pos_inf = pset1frombits<Packet>( 0x7f800000u);
+
+ // Polynomial coefficients.
+ const Packet cst_cephes_SQRTHF = pset1<Packet>(0.707106781186547524f);
+ const Packet cst_cephes_log_p0 = pset1<Packet>(7.0376836292E-2f);
+ const Packet cst_cephes_log_p1 = pset1<Packet>(-1.1514610310E-1f);
+ const Packet cst_cephes_log_p2 = pset1<Packet>(1.1676998740E-1f);
+ const Packet cst_cephes_log_p3 = pset1<Packet>(-1.2420140846E-1f);
+ const Packet cst_cephes_log_p4 = pset1<Packet>(+1.4249322787E-1f);
+ const Packet cst_cephes_log_p5 = pset1<Packet>(-1.6668057665E-1f);
+ const Packet cst_cephes_log_p6 = pset1<Packet>(+2.0000714765E-1f);
+ const Packet cst_cephes_log_p7 = pset1<Packet>(-2.4999993993E-1f);
+ const Packet cst_cephes_log_p8 = pset1<Packet>(+3.3333331174E-1f);
+
+ // Truncate input values to the minimum positive normal.
+ x = pmax(x, cst_min_norm_pos);
+
+ Packet e;
+ // extract significand in the range [0.5,1) and exponent
+ x = pfrexp(x,e);
+
+ // Shift the inputs from the range [0.5,1) to [sqrt(1/2),sqrt(2))
+ // and shift by -1. The values are then centered around 0, which improves
+ // the stability of the polynomial evaluation.
+ // if( x < SQRTHF ) {
+ // e -= 1;
+ // x = x + x - 1.0;
+ // } else { x = x - 1.0; }
+ Packet mask = pcmp_lt(x, cst_cephes_SQRTHF);
+ Packet tmp = pand(x, mask);
+ x = psub(x, cst_1);
+ e = psub(e, pand(cst_1, mask));
+ x = padd(x, tmp);
+
+ Packet x2 = pmul(x, x);
+ Packet x3 = pmul(x2, x);
+
+ // Evaluate the polynomial approximant of degree 8 in three parts, probably
+ // to improve instruction-level parallelism.
+ Packet y, y1, y2;
+ y = pmadd(cst_cephes_log_p0, x, cst_cephes_log_p1);
+ y1 = pmadd(cst_cephes_log_p3, x, cst_cephes_log_p4);
+ y2 = pmadd(cst_cephes_log_p6, x, cst_cephes_log_p7);
+ y = pmadd(y, x, cst_cephes_log_p2);
+ y1 = pmadd(y1, x, cst_cephes_log_p5);
+ y2 = pmadd(y2, x, cst_cephes_log_p8);
+ y = pmadd(y, x3, y1);
+ y = pmadd(y, x3, y2);
+ y = pmul(y, x3);
+
+ y = pmadd(cst_neg_half, x2, y);
+ x = padd(x, y);
+
+ // Add the exponent back: log2(x) = e + log(m)*log2(e), log(x) = e*ln(2) + log(m).
+ if (base2) {
+ const Packet cst_log2e = pset1<Packet>(static_cast<float>(EIGEN_LOG2E));
+ x = pmadd(x, cst_log2e, e);
+ } else {
+ const Packet cst_ln2 = pset1<Packet>(static_cast<float>(EIGEN_LN2));
+ x = pmadd(e, cst_ln2, x);
+ }
+
+ Packet invalid_mask = pcmp_lt_or_nan(_x, pzero(_x));
+ Packet iszero_mask = pcmp_eq(_x,pzero(_x));
+ Packet pos_inf_mask = pcmp_eq(_x,cst_pos_inf);
+ // Filter out invalid inputs, i.e.:
+ // - negative arg will be NAN
+ // - 0 will be -INF
+ // - +INF will be +INF
+ return pselect(iszero_mask, cst_minus_inf,
+ por(pselect(pos_inf_mask,cst_pos_inf,x), invalid_mask));
+}
+
+template <typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet plog_float(const Packet _x)
+{
+ return plog_impl_float<Packet, /* base2 */ false>(_x); // natural logarithm of a float packet
+}
+
+template <typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet plog2_float(const Packet _x)
+{
+ return plog_impl_float<Packet, /* base2 */ true>(_x); // base-2 logarithm of a float packet
+}
+
+/* Returns the base e (2.718...) or base 2 logarithm of x.
+ * The argument is separated into its exponent and fractional parts.
+ * The logarithm of the fraction in the interval [sqrt(1/2), sqrt(2)],
+ * is approximated by
+ *
+ * log(1+x) = x - 0.5 x**2 + x**3 P(x)/Q(x).
+ *
+ * for more detail see: http://www.netlib.org/cephes/
+ */
+template <typename Packet, bool base2>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet plog_impl_double(const Packet _x)
+{
+ Packet x = _x;
+
+ const Packet cst_1 = pset1<Packet>(1.0);
+ const Packet cst_neg_half = pset1<Packet>(-0.5);
+ // The smallest non denormalized double.
+ const Packet cst_min_norm_pos = pset1frombits<Packet>( static_cast<uint64_t>(0x0010000000000000ull));
+ const Packet cst_minus_inf = pset1frombits<Packet>( static_cast<uint64_t>(0xfff0000000000000ull));
+ const Packet cst_pos_inf = pset1frombits<Packet>( static_cast<uint64_t>(0x7ff0000000000000ull));
+
+
+ // Polynomial Coefficients for log(1+x) = x - x**2/2 + x**3 P(x)/Q(x)
+ // 1/sqrt(2) <= x < sqrt(2)
+ const Packet cst_cephes_SQRTHF = pset1<Packet>(0.70710678118654752440E0);
+ const Packet cst_cephes_log_p0 = pset1<Packet>(1.01875663804580931796E-4);
+ const Packet cst_cephes_log_p1 = pset1<Packet>(4.97494994976747001425E-1);
+ const Packet cst_cephes_log_p2 = pset1<Packet>(4.70579119878881725854E0);
+ const Packet cst_cephes_log_p3 = pset1<Packet>(1.44989225341610930846E1);
+ const Packet cst_cephes_log_p4 = pset1<Packet>(1.79368678507819816313E1);
+ const Packet cst_cephes_log_p5 = pset1<Packet>(7.70838733755885391666E0);
+
+ const Packet cst_cephes_log_q0 = pset1<Packet>(1.0);
+ const Packet cst_cephes_log_q1 = pset1<Packet>(1.12873587189167450590E1);
+ const Packet cst_cephes_log_q2 = pset1<Packet>(4.52279145837532221105E1);
+ const Packet cst_cephes_log_q3 = pset1<Packet>(8.29875266912776603211E1);
+ const Packet cst_cephes_log_q4 = pset1<Packet>(7.11544750618563894466E1);
+ const Packet cst_cephes_log_q5 = pset1<Packet>(2.31251620126765340583E1);
+
+ // Truncate input values to the minimum positive normal.
+ x = pmax(x, cst_min_norm_pos);
+
+ Packet e;
+ // extract significand in the range [0.5,1) and exponent
+ x = pfrexp(x,e);
+
+ // Shift the inputs from the range [0.5,1) to [sqrt(1/2),sqrt(2))
+ // and shift by -1. The values are then centered around 0, which improves
+ // the stability of the polynomial evaluation.
+ // if( x < SQRTHF ) {
+ // e -= 1;
+ // x = x + x - 1.0;
+ // } else { x = x - 1.0; }
+ Packet mask = pcmp_lt(x, cst_cephes_SQRTHF);
+ Packet tmp = pand(x, mask);
+ x = psub(x, cst_1);
+ e = psub(e, pand(cst_1, mask));
+ x = padd(x, tmp);
+
+ Packet x2 = pmul(x, x);
+ Packet x3 = pmul(x2, x);
+
+ // Evaluate the rational approximant P/Q, each polynomial split in two parts, probably to improve instruction-level parallelism.
+ // y = x - 0.5*x^2 + x^3 * polevl( x, P, 5 ) / p1evl( x, Q, 5 );
+ Packet y, y1, y_;
+ y = pmadd(cst_cephes_log_p0, x, cst_cephes_log_p1);
+ y1 = pmadd(cst_cephes_log_p3, x, cst_cephes_log_p4);
+ y = pmadd(y, x, cst_cephes_log_p2);
+ y1 = pmadd(y1, x, cst_cephes_log_p5);
+ y_ = pmadd(y, x3, y1);
+
+ y = pmadd(cst_cephes_log_q0, x, cst_cephes_log_q1);
+ y1 = pmadd(cst_cephes_log_q3, x, cst_cephes_log_q4);
+ y = pmadd(y, x, cst_cephes_log_q2);
+ y1 = pmadd(y1, x, cst_cephes_log_q5);
+ y = pmadd(y, x3, y1);
+
+ y_ = pmul(y_, x3);
+ y = pdiv(y_, y);
+
+ y = pmadd(cst_neg_half, x2, y);
+ x = padd(x, y);
+
+ // Add the exponent back: log2(x) = e + log(m)*log2(e), log(x) = e*ln(2) + log(m).
+ if (base2) {
+ const Packet cst_log2e = pset1<Packet>(static_cast<double>(EIGEN_LOG2E));
+ x = pmadd(x, cst_log2e, e);
+ } else {
+ const Packet cst_ln2 = pset1<Packet>(static_cast<double>(EIGEN_LN2));
+ x = pmadd(e, cst_ln2, x);
+ }
+
+ Packet invalid_mask = pcmp_lt_or_nan(_x, pzero(_x));
+ Packet iszero_mask = pcmp_eq(_x,pzero(_x));
+ Packet pos_inf_mask = pcmp_eq(_x,cst_pos_inf);
+ // Filter out invalid inputs, i.e.:
+ // - negative arg will be NAN
+ // - 0 will be -INF
+ // - +INF will be +INF
+ return pselect(iszero_mask, cst_minus_inf,
+ por(pselect(pos_inf_mask,cst_pos_inf,x), invalid_mask));
+}
+
+template <typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet plog_double(const Packet _x)
+{
+ return plog_impl_double<Packet, /* base2 */ false>(_x); // natural logarithm of a double packet
+}
+
+template <typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet plog2_double(const Packet _x)
+{
+ return plog_impl_double<Packet, /* base2 */ true>(_x); // base-2 logarithm of a double packet
+}
+
+/** \internal \returns log(1 + x) computed using W. Kahan's formula.
+ See: http://www.plunk.org/~hatch/rightway.php
+ */
+template<typename Packet>
+Packet generic_plog1p(const Packet& x)
+{
+ typedef typename unpacket_traits<Packet>::type ScalarType;
+ const Packet one = pset1<Packet>(ScalarType(1));
+ Packet xp1 = padd(x, one);
+ Packet small_mask = pcmp_eq(xp1, one); // 1+x rounds to 1: log1p(x) ~= x
+ Packet log1 = plog(xp1);
+ Packet inf_mask = pcmp_eq(xp1, log1); // log(t) == t only holds for t == +inf
+ Packet log_large = pmul(x, pdiv(log1, psub(xp1, one))); // rescale by x/((1+x)-1) to cancel the rounding error of 1+x
+ return pselect(por(small_mask, inf_mask), x, log_large);
+}
+
+/** \internal \returns exp(x)-1 computed using W. Kahan's formula.
+ See: http://www.plunk.org/~hatch/rightway.php
+ */
+template<typename Packet>
+Packet generic_expm1(const Packet& x)
+{
+ typedef typename unpacket_traits<Packet>::type ScalarType;
+ const Packet one = pset1<Packet>(ScalarType(1));
+ const Packet neg_one = pset1<Packet>(ScalarType(-1));
+ Packet u = pexp(x);
+ Packet one_mask = pcmp_eq(u, one); // exp(x) rounds to 1: expm1(x) ~= x
+ Packet u_minus_one = psub(u, one);
+ Packet neg_one_mask = pcmp_eq(u_minus_one, neg_one); // exp(x) is negligible next to 1: result is -1
+ Packet logu = plog(u);
+ // The following comparison is to catch the case where
+ // exp(x) = +inf. It is written in this way to avoid having
+ // to form the constant +inf, which depends on the packet
+ // type.
+ Packet pos_inf_mask = pcmp_eq(logu, u);
+ Packet expm1 = pmul(u_minus_one, pdiv(x, logu)); // rescale by x/log(u) to cancel the rounding error of exp(x)
+ expm1 = pselect(pos_inf_mask, u, expm1);
+ return pselect(one_mask,
+ x,
+ pselect(neg_one_mask,
+ neg_one,
+ expm1));
+}
+
+
+// Exponential function. Works by writing "x = m*log(2) + r" where
+// "m = floor(x/log(2)+1/2)" and "r" is the remainder. The result is then
+// "exp(x) = 2^m*exp(r)" where r is in the range [-log(2)/2, log(2)/2).
+template <typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet pexp_float(const Packet _x)
+{
+ const Packet cst_1 = pset1<Packet>(1.0f);
+ const Packet cst_half = pset1<Packet>(0.5f);
+ const Packet cst_exp_hi = pset1<Packet>( 88.723f);
+ const Packet cst_exp_lo = pset1<Packet>(-88.723f);
+
+ const Packet cst_cephes_LOG2EF = pset1<Packet>(1.44269504088896341f);
+ const Packet cst_cephes_exp_p0 = pset1<Packet>(1.9875691500E-4f);
+ const Packet cst_cephes_exp_p1 = pset1<Packet>(1.3981999507E-3f);
+ const Packet cst_cephes_exp_p2 = pset1<Packet>(8.3334519073E-3f);
+ const Packet cst_cephes_exp_p3 = pset1<Packet>(4.1665795894E-2f);
+ const Packet cst_cephes_exp_p4 = pset1<Packet>(1.6666665459E-1f);
+ const Packet cst_cephes_exp_p5 = pset1<Packet>(5.0000001201E-1f);
+
+ // Clamp x.
+ Packet x = pmax(pmin(_x, cst_exp_hi), cst_exp_lo);
+
+ // Express exp(x) as exp(m*ln(2) + r), start by extracting
+ // m = floor(x/ln(2) + 0.5).
+ Packet m = pfloor(pmadd(x, cst_cephes_LOG2EF, cst_half));
+
+ // Get r = x - m*ln(2). If no FMA instructions are available, m*ln(2) is
+ // subtracted out in two parts, m*C1+m*C2 = m*ln(2), to avoid accumulating
+ // truncation errors.
+ const Packet cst_cephes_exp_C1 = pset1<Packet>(-0.693359375f);
+ const Packet cst_cephes_exp_C2 = pset1<Packet>(2.12194440e-4f);
+ Packet r = pmadd(m, cst_cephes_exp_C1, x);
+ r = pmadd(m, cst_cephes_exp_C2, r);
+
+ Packet r2 = pmul(r, r);
+ Packet r3 = pmul(r2, r);
+
+ // Evaluate the polynomial approximant, improved by instruction-level parallelism.
+ Packet y, y1, y2;
+ y = pmadd(cst_cephes_exp_p0, r, cst_cephes_exp_p1);
+ y1 = pmadd(cst_cephes_exp_p3, r, cst_cephes_exp_p4);
+ y2 = padd(r, cst_1);
+ y = pmadd(y, r, cst_cephes_exp_p2);
+ y1 = pmadd(y1, r, cst_cephes_exp_p5);
+ y = pmadd(y, r3, y1);
+ y = pmadd(y, r2, y2);
+
+ // Return 2^m * exp(r); the pmax with the unclamped _x makes a +inf input return +inf.
+ // TODO: replace pldexp with faster implementation since y in [-1, 1).
+ return pmax(pldexp(y,m), _x);
+}
+
+template <typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet pexp_double(const Packet _x)
+{
+ Packet x = _x;
+
+ const Packet cst_1 = pset1<Packet>(1.0);
+ const Packet cst_2 = pset1<Packet>(2.0);
+ const Packet cst_half = pset1<Packet>(0.5);
+
+ const Packet cst_exp_hi = pset1<Packet>(709.784);
+ const Packet cst_exp_lo = pset1<Packet>(-709.784);
+
+ const Packet cst_cephes_LOG2EF = pset1<Packet>(1.4426950408889634073599);
+ const Packet cst_cephes_exp_p0 = pset1<Packet>(1.26177193074810590878e-4);
+ const Packet cst_cephes_exp_p1 = pset1<Packet>(3.02994407707441961300e-2);
+ const Packet cst_cephes_exp_p2 = pset1<Packet>(9.99999999999999999910e-1);
+ const Packet cst_cephes_exp_q0 = pset1<Packet>(3.00198505138664455042e-6);
+ const Packet cst_cephes_exp_q1 = pset1<Packet>(2.52448340349684104192e-3);
+ const Packet cst_cephes_exp_q2 = pset1<Packet>(2.27265548208155028766e-1);
+ const Packet cst_cephes_exp_q3 = pset1<Packet>(2.00000000000000000009e0);
+ const Packet cst_cephes_exp_C1 = pset1<Packet>(0.693145751953125);
+ const Packet cst_cephes_exp_C2 = pset1<Packet>(1.42860682030941723212e-6);
+
+ Packet tmp, fx;
+
+ // clamp x
+ x = pmax(pmin(x, cst_exp_hi), cst_exp_lo);
+ // Express exp(x) as exp(g + n*log(2)).
+ fx = pmadd(cst_cephes_LOG2EF, x, cst_half);
+
+ // Get the integer modulus of log(2), i.e. the "n" described above.
+ fx = pfloor(fx);
+
+ // Get the remainder modulo log(2), i.e. the "g" described above. Subtract
+ // n*log(2) out in two steps, i.e. n*C1 + n*C2, C1+C2=log2 to get the last
+ // digits right.
+ tmp = pmul(fx, cst_cephes_exp_C1);
+ Packet z = pmul(fx, cst_cephes_exp_C2);
+ x = psub(x, tmp);
+ x = psub(x, z);
+
+ Packet x2 = pmul(x, x);
+
+ // Evaluate the numerator polynomial of the rational interpolant.
+ Packet px = cst_cephes_exp_p0;
+ px = pmadd(px, x2, cst_cephes_exp_p1);
+ px = pmadd(px, x2, cst_cephes_exp_p2);
+ px = pmul(px, x);
+
+ // Evaluate the denominator polynomial of the rational interpolant.
+ Packet qx = cst_cephes_exp_q0;
+ qx = pmadd(qx, x2, cst_cephes_exp_q1);
+ qx = pmadd(qx, x2, cst_cephes_exp_q2);
+ qx = pmadd(qx, x2, cst_cephes_exp_q3);
+
+ // Cephes form: exp(g) ~= (qx + px) / (qx - px), with px = g*P(g^2) odd and
+ // qx = Q(g^2) even. It is evaluated as 1 + 2*px/(qx - px), which is
+ // algebraically equivalent and needs a single division.
+ x = pdiv(px, psub(qx, px));
+ x = pmadd(cst_2, x, cst_1);
+
+ // Construct the result 2^n * exp(g) = 2^fx * x. The max is used to catch
+ // non-finite values in the input.
+ // TODO: replace pldexp with faster implementation since x in [-1, 1).
+ return pmax(pldexp(x,fx), _x);
+}
+
+// The following code is inspired by the following stack-overflow answer:
+// https://stackoverflow.com/questions/30463616/payne-hanek-algorithm-implementation-in-c/30465751#30465751
+// It has been largely optimized:
+// - By-pass calls to frexp.
+// - Aligned loads of required 96 bits of 2/pi. This is accomplished by
+// (1) balancing the mantissa and exponent so that the required bits of 2/pi are
+// aligned on 8-bits, and (2) replicating the storage of the bits of 2/pi.
+// - Avoid a branch in rounding and extraction of the remaining fractional part.
+// Overall, I measured a speed up higher than x2 on x86-64.
+inline float trig_reduce_huge (float xf, int *quadrant)
+{
+ using Eigen::numext::int32_t;
+ using Eigen::numext::uint32_t;
+ using Eigen::numext::int64_t;
+ using Eigen::numext::uint64_t;
+
+ const double pio2_62 = 3.4061215800865545e-19; // pi/2 * 2^-62
+ const uint64_t zero_dot_five = uint64_t(1) << 61; // 0.5 in 2.62-bit fixed-point format
+
+ // 192 bits of 2/pi for Payne-Hanek reduction
+ // Each entry is a 32-bit window of 2/pi shifted by 8 bits (one byte) from the previous one, to enable aligned reads.
+ static const uint32_t two_over_pi [] =
+ {
+ 0x00000028, 0x000028be, 0x0028be60, 0x28be60db,
+ 0xbe60db93, 0x60db9391, 0xdb939105, 0x9391054a,
+ 0x91054a7f, 0x054a7f09, 0x4a7f09d5, 0x7f09d5f4,
+ 0x09d5f47d, 0xd5f47d4d, 0xf47d4d37, 0x7d4d3770,
+ 0x4d377036, 0x377036d8, 0x7036d8a5, 0x36d8a566,
+ 0xd8a5664f, 0xa5664f10, 0x664f10e4, 0x4f10e410,
+ 0x10e41000, 0xe4100000
+ };
+
+ uint32_t xi = numext::bit_cast<uint32_t>(xf);
+ // Below, -118 = -126 + 8.
+ // -126 is to get the exponent,
+ // +8 is to enable alignment of 2/pi's bits on 8 bits.
+ // This is possible because the fractional part of x has only 24 meaningful bits.
+ uint32_t e = (xi >> 23) - 118;
+ // Extract the mantissa and shift it to align it wrt the exponent
+ xi = ((xi & 0x007fffffu)| 0x00800000u) << (e & 0x7);
+
+ uint32_t i = e >> 3;
+ uint32_t twoopi_1 = two_over_pi[i-1];
+ uint32_t twoopi_2 = two_over_pi[i+3];
+ uint32_t twoopi_3 = two_over_pi[i+7];
+
+ // Compute x * 2/pi in 2.62-bit fixed-point format.
+ uint64_t p;
+ p = uint64_t(xi) * twoopi_3;
+ p = uint64_t(xi) * twoopi_2 + (p >> 32);
+ p = (uint64_t(xi * twoopi_1) << 32) + p;
+
+ // Round to nearest: add 0.5 and extract integral part.
+ uint64_t q = (p + zero_dot_five) >> 62;
+ *quadrant = int(q);
+ // Now it remains to compute "r = x - q*pi/2" with high accuracy,
+ // since we have p=x/(pi/2) with high accuracy, we can more efficiently compute r as:
+ // r = (p-q)*pi/2,
+ // where the product can be carried out with sufficient accuracy using double precision.
+ p -= q<<62;
+ return float(double(int64_t(p)) * pio2_62);
+}
+
+template<bool ComputeSine,typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+#if EIGEN_GNUC_AT_LEAST(4,4) && EIGEN_COMP_GNUC_STRICT
+__attribute__((optimize("-fno-unsafe-math-optimizations")))
+#endif
+Packet psincos_float(const Packet& _x)
+{
+ typedef typename unpacket_traits<Packet>::integer_packet PacketI;
+
+ const Packet cst_2oPI = pset1<Packet>(0.636619746685028076171875f); // 2/PI
+ const Packet cst_rounding_magic = pset1<Packet>(12582912); // 1.5 * 2^23, magic constant for rounding
+ const PacketI csti_1 = pset1<PacketI>(1);
+ const Packet cst_sign_mask = pset1frombits<Packet>(0x80000000u);
+
+ Packet x = pabs(_x);
+
+ // Scale x by 2/Pi to find x's quadrant.
+ Packet y = pmul(x, cst_2oPI);
+
+ // Rounding trick: adding 1.5*2^23 rounds y to an integer stored in the low mantissa bits.
+ Packet y_round = padd(y, cst_rounding_magic);
+ EIGEN_OPTIMIZATION_BARRIER(y_round)
+ PacketI y_int = preinterpret<PacketI>(y_round); // last 23 digits represent integer (if abs(x)<2^24)
+ y = psub(y_round, cst_rounding_magic); // nearest integer to x*2/pi
+
+ // Reduce x by y quadrants to get: -Pi/4 <= x <= +Pi/4
+ // using "Extended precision modular arithmetic"
+ #if defined(EIGEN_HAS_SINGLE_INSTRUCTION_MADD)
+ // This version requires true FMA for high accuracy
+ // It provides a max error of 1ULP up to (with absolute_error < 5.9605e-08):
+ const float huge_th = ComputeSine ? 117435.992f : 71476.0625f;
+ x = pmadd(y, pset1<Packet>(-1.57079601287841796875f), x);
+ x = pmadd(y, pset1<Packet>(-3.1391647326017846353352069854736328125e-07f), x);
+ x = pmadd(y, pset1<Packet>(-5.390302529957764765544681040410068817436695098876953125e-15f), x);
+ #else
+ // Without true FMA, the previous set of coefficients maintain 1ULP accuracy
+ // up to x<15.7 (for sin), but accuracy is immediately lost for x>15.7.
+ // We thus use one more iteration to maintain 2ULPs up to reasonably large inputs.
+
+ // The following set of coefficients maintain 1ULP up to 9.43 and 14.16 for sin and cos respectively.
+ // and 2 ULP up to:
+ const float huge_th = ComputeSine ? 25966.f : 18838.f;
+ x = pmadd(y, pset1<Packet>(-1.5703125), x); // = 0xbfc90000
+ EIGEN_OPTIMIZATION_BARRIER(x)
+ x = pmadd(y, pset1<Packet>(-0.000483989715576171875), x); // = 0xb9fdc000
+ EIGEN_OPTIMIZATION_BARRIER(x)
+ x = pmadd(y, pset1<Packet>(1.62865035235881805419921875e-07), x); // = 0x342ee000
+ x = pmadd(y, pset1<Packet>(5.5644315544167710640977020375430583953857421875e-11), x); // = 0x2e74b9ee
+
+ // For the record, the following set of coefficients maintain 2ULP up
+ // to a slightly larger range:
+ // const float huge_th = ComputeSine ? 51981.f : 39086.125f;
+ // but it slightly fails to maintain 1ULP for two values of sin below pi.
+ // x = pmadd(y, pset1<Packet>(-3.140625/2.), x);
+ // x = pmadd(y, pset1<Packet>(-0.00048351287841796875), x);
+ // x = pmadd(y, pset1<Packet>(-3.13855707645416259765625e-07), x);
+ // x = pmadd(y, pset1<Packet>(-6.0771006282767103812147979624569416046142578125e-11), x);
+
+ // For the record, with only 3 iterations it is possible to maintain
+ // 1 ULP up to 3PI (maybe more) and 2ULP up to 255.
+ // The coefficients are: 0xbfc90f80, 0xb7354480, 0x2e74b9ee
+ #endif
+
+ if(predux_any(pcmp_le(pset1<Packet>(huge_th),pabs(_x)))) // slow scalar path: redo the reduction of huge finite lanes with trig_reduce_huge
+ {
+ const int PacketSize = unpacket_traits<Packet>::size;
+ EIGEN_ALIGN_TO_BOUNDARY(sizeof(Packet)) float vals[PacketSize];
+ EIGEN_ALIGN_TO_BOUNDARY(sizeof(Packet)) float x_cpy[PacketSize];
+ EIGEN_ALIGN_TO_BOUNDARY(sizeof(Packet)) int y_int2[PacketSize];
+ pstoreu(vals, pabs(_x));
+ pstoreu(x_cpy, x);
+ pstoreu(y_int2, y_int);
+ for(int k=0; k<PacketSize;++k)
+ {
+ float val = vals[k];
+ if(val>=huge_th && (numext::isfinite)(val))
+ x_cpy[k] = trig_reduce_huge(val,&y_int2[k]);
+ }
+ x = ploadu<Packet>(x_cpy);
+ y_int = ploadu<PacketI>(y_int2);
+ }
+
+ // Compute the sign to apply to the polynomial.
+ // sin: sign = second_bit(y_int) xor signbit(_x)
+ // cos: sign = second_bit(y_int+1)
+ Packet sign_bit = ComputeSine ? pxor(_x, preinterpret<Packet>(plogical_shift_left<30>(y_int)))
+ : preinterpret<Packet>(plogical_shift_left<30>(padd(y_int,csti_1)));
+ sign_bit = pand(sign_bit, cst_sign_mask); // clear all but left most bit
+
+ // Get the polynomial selection mask from the lowest bit of y_int (set when y_int is even).
+ // We'll calculate both (sin and cos) polynomials and then select from the two.
+ Packet poly_mask = preinterpret<Packet>(pcmp_eq(pand(y_int, csti_1), pzero(y_int)));
+
+ Packet x2 = pmul(x,x);
+
+ // Evaluate the cos(x) polynomial. (-Pi/4 <= x <= Pi/4)
+ Packet y1 = pset1<Packet>(2.4372266125283204019069671630859375e-05f);
+ y1 = pmadd(y1, x2, pset1<Packet>(-0.00138865201734006404876708984375f ));
+ y1 = pmadd(y1, x2, pset1<Packet>(0.041666619479656219482421875f ));
+ y1 = pmadd(y1, x2, pset1<Packet>(-0.5f));
+ y1 = pmadd(y1, x2, pset1<Packet>(1.f));
+
+ // Evaluate the sin(x) polynomial. (-Pi/4 <= x <= Pi/4)
+ // octave/matlab code to compute those coefficients:
+ // x = (0:0.0001:pi/4)';
+ // A = [x.^3 x.^5 x.^7];
+ // w = ((1.-(x/(pi/4)).^2).^5)*2000+1; # weights trading relative accuracy
+ // c = (A'*diag(w)*A)\(A'*diag(w)*(sin(x)-x)); # weighted LS, linear coeff forced to 1
+ // printf('%.64f\n %.64f\n%.64f\n', c(3), c(2), c(1))
+ //
+ Packet y2 = pset1<Packet>(-0.0001959234114083702898469196984621021329076029360294342041015625f);
+ y2 = pmadd(y2, x2, pset1<Packet>( 0.0083326873655616851693794799871284340042620897293090820312500000f));
+ y2 = pmadd(y2, x2, pset1<Packet>(-0.1666666203982298255503735617821803316473960876464843750000000000f));
+ y2 = pmul(y2, x2);
+ y2 = pmadd(y2, x, x);
+
+ // Select the correct result from the two polynomials.
+ y = ComputeSine ? pselect(poly_mask,y2,y1)
+ : pselect(poly_mask,y1,y2);
+
+ // Update the sign and filter huge inputs
+ return pxor(y, sign_bit);
+}
+
+template<typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet psin_float(const Packet& x)
+{
+ return psincos_float<true>(x); // ComputeSine = true: sin(x)
+}
+
+template<typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet pcos_float(const Packet& x)
+{
+ return psincos_float<false>(x); // ComputeSine = false: cos(x)
+}
+
+
+template<typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet psqrt_complex(const Packet& a) {
+ typedef typename unpacket_traits<Packet>::type Scalar;
+ typedef typename Scalar::value_type RealScalar;
+ typedef typename unpacket_traits<Packet>::as_real RealPacket;
+
+ // Computes the principal sqrt of the complex numbers in the input.
+ //
+ // For example, for packets containing 2 complex numbers stored in interleaved format
+ // a = [a0, a1] = [x0, y0, x1, y1],
+ // where x0 = real(a0), y0 = imag(a0) etc., this function returns
+ // b = [b0, b1] = [u0, v0, u1, v1],
+ // such that b0^2 = a0, b1^2 = a1.
+ //
+ // To derive the formula for the complex square roots, let's consider the equation for
+ // a single complex square root of the number x + i*y. We want to find real numbers
+ // u and v such that
+ // (u + i*v)^2 = x + i*y <=>
+ // u^2 - v^2 + i*2*u*v = x + i*y.
+ // By equating the real and imaginary parts we get:
+ // u^2 - v^2 = x
+ // 2*u*v = y.
+ //
+ // For x >= 0, this has the numerically stable solution
+ // u = sqrt(0.5 * (x + sqrt(x^2 + y^2)))
+ // v = 0.5 * (y / u)
+ // and for x < 0,
+ // v = sign(y) * sqrt(0.5 * (-x + sqrt(x^2 + y^2)))
+ // u = 0.5 * (y / v)
+ //
+ // To avoid unnecessary over- and underflow, we compute sqrt(x^2 + y^2) as
+ // l = max(|x|, |y|) * sqrt(1 + (min(|x|, |y|) / max(|x|, |y|))^2).
+
+ // In the following, without loss of generality, we have annotated the code, assuming
+ // that the input is a packet of 2 complex numbers.
+ //
+ // Step 1. Compute l = [l0, l0, l1, l1], where
+ // l0 = sqrt(x0^2 + y0^2), l1 = sqrt(x1^2 + y1^2)
+ // To avoid over- and underflow, we use the stable formula for each hypotenuse
+ // l0 = (min0 == 0 ? max0 : max0 * sqrt(1 + (min0/max0)**2)),
+ // where max0 = max(|x0|, |y0|), min0 = min(|x0|, |y0|), and similarly for l1.
+
+ RealPacket a_abs = pabs(a.v); // [|x0|, |y0|, |x1|, |y1|]
+ RealPacket a_abs_flip = pcplxflip(Packet(a_abs)).v; // [|y0|, |x0|, |y1|, |x1|]
+ RealPacket a_max = pmax(a_abs, a_abs_flip);
+ RealPacket a_min = pmin(a_abs, a_abs_flip);
+ RealPacket a_min_zero_mask = pcmp_eq(a_min, pzero(a_min));
+ RealPacket a_max_zero_mask = pcmp_eq(a_max, pzero(a_max));
+ RealPacket r = pdiv(a_min, a_max);
+ const RealPacket cst_one = pset1<RealPacket>(RealScalar(1));
+ RealPacket l = pmul(a_max, psqrt(padd(cst_one, pmul(r, r)))); // [l0, l0, l1, l1]
+ // Set l to a_max if a_min is zero.
+ l = pselect(a_min_zero_mask, a_max, l);
+
+ // Step 2. Compute [rho0, *, rho1, *], where
+ // rho0 = sqrt(0.5 * (l0 + |x0|)), rho1 = sqrt(0.5 * (l1 + |x1|))
+ // We don't care about the imaginary parts computed here. They will be overwritten later.
+ const RealPacket cst_half = pset1<RealPacket>(RealScalar(0.5));
+ Packet rho;
+ rho.v = psqrt(pmul(cst_half, padd(a_abs, l)));
+
+ // Step 3. Compute [rho0, eta0, rho1, eta1], where
+ // eta0 = (y0 / rho0) / 2, and eta1 = (y1 / rho1) / 2.
+ // Set eta = 0 if the input is 0 + i0 (rho is 0 there, so y / rho would be 0/0).
+ RealPacket eta = pandnot(pmul(cst_half, pdiv(a.v, pcplxflip(rho).v)), a_max_zero_mask);
+ RealPacket real_mask = peven_mask(a.v);
+ Packet positive_real_result;
+ // Compute result for inputs with positive real part.
+ positive_real_result.v = pselect(real_mask, rho.v, eta);
+
+ // Step 4. Compute solution for inputs with negative real part:
+ // [|eta0|, sign(y0)*rho0, |eta1|, sign(y1)*rho1]
+ const RealScalar neg_zero = RealScalar(numext::bit_cast<float>(0x80000000u)); // -0.0: only the sign bit set
+ const RealPacket cst_imag_sign_mask = pset1<Packet>(Scalar(RealScalar(0.0), neg_zero)).v;
+ RealPacket imag_signs = pand(a.v, cst_imag_sign_mask);
+ Packet negative_real_result;
+ // Notice that rho is non-negative, so taking its absolute value is a no-op.
+ negative_real_result.v = por(pabs(pcplxflip(positive_real_result).v), imag_signs);
+
+ // Step 5. Select solution branch based on the sign of the real parts.
+ Packet negative_real_mask;
+ negative_real_mask.v = pcmp_lt(pand(real_mask, a.v), pzero(a.v));
+ negative_real_mask.v = por(negative_real_mask.v, pcplxflip(negative_real_mask).v);
+ Packet result = pselect(negative_real_mask, negative_real_result, positive_real_result);
+
+ // Step 6. Handle special cases for infinities:
+ // * If z is (x,+∞), the result is (+∞,+∞) even if x is NaN
+ // * If z is (x,-∞), the result is (+∞,-∞) even if x is NaN
+ // * If z is (-∞,y), the result is (0*|y|,+∞) for finite or NaN y
+ // * If z is (+∞,y), the result is (+∞,0*|y|) for finite or NaN y
+ const RealPacket cst_pos_inf = pset1<RealPacket>(NumTraits<RealScalar>::infinity());
+ Packet is_inf;
+ is_inf.v = pcmp_eq(a_abs, cst_pos_inf);
+ Packet is_real_inf;
+ is_real_inf.v = pand(is_inf.v, real_mask);
+ is_real_inf = por(is_real_inf, pcplxflip(is_real_inf));
+ // prepare packet of (+∞,0*|y|) or (0*|y|,+∞), depending on the sign of the infinite real part.
+ Packet real_inf_result;
+ real_inf_result.v = pmul(a_abs, pset1<Packet>(Scalar(RealScalar(1.0), RealScalar(0.0))).v);
+ real_inf_result.v = pselect(negative_real_mask.v, pcplxflip(real_inf_result).v, real_inf_result.v);
+ // prepare packet of (+∞,+∞) or (+∞,-∞), depending on the sign of the infinite imaginary part.
+ Packet is_imag_inf;
+ is_imag_inf.v = pandnot(is_inf.v, real_mask);
+ is_imag_inf = por(is_imag_inf, pcplxflip(is_imag_inf));
+ Packet imag_inf_result;
+ imag_inf_result.v = por(pand(cst_pos_inf, real_mask), pandnot(a.v, real_mask));
+
+ return pselect(is_imag_inf, imag_inf_result,
+ pselect(is_real_inf, real_inf_result,result));
+}
+
+// TODO(rmlarsen): The following set of utilities for double word arithmetic
+// should perhaps be refactored as a separate file, since it would be generally
+// useful for special function implementation etc. Writing the algorithms in
+// terms of a double word type would also make the code more readable.
+
+// Splits x into its nearest integer n = round(x) and the remainder
+// r = x - n. The subtraction is exact, so x == n + r holds exactly.
+template <typename Packet>
+EIGEN_STRONG_INLINE void absolute_split(const Packet& x, Packet& n, Packet& r) {
+ const Packet nearest = pround(x);
+ n = nearest;
+ r = psub(x, nearest);
+}
+
+// Fast2Sum: computes {s_hi, s_lo} with s_hi = fl(x + y) and
+// x + y == s_hi + s_lo exactly, provided |x| >= |y|.
+template <typename Packet>
+EIGEN_STRONG_INLINE void fast_twosum(const Packet& x, const Packet& y,
+ Packet& s_hi, Packet& s_lo) {
+ s_hi = padd(x, y);
+ // s_hi - x is the part of y absorbed into s_hi; what is left over is the error.
+ s_lo = psub(y, psub(s_hi, x));
+}
+
+#ifdef EIGEN_HAS_SINGLE_INSTRUCTION_MADD
+// Exact product when pmadd is a single (presumably fused, single-rounding)
+// instruction: p_hi = fl(x * y) and p_lo = x * y - p_hi, so that
+// x * y == p_hi + p_lo holds exactly.
+template <typename Packet>
+EIGEN_STRONG_INLINE void twoprod(const Packet& x, const Packet& y,
+ Packet& p_hi, Packet& p_lo) {
+ p_hi = pmul(x, y);
+ // One madd evaluates x * y - p_hi, recovering the rounding error of p_hi.
+ const Packet neg_hi = pnegate(p_hi);
+ p_lo = pmadd(x, y, neg_hi);
+}
+
+#else
+
+// Veltkamp splitting (Algorithm 3 in J.-M. Muller, "Elementary Functions",
+// 3rd edition, Birkh\"auser, 2016). Splits x into {x_hi, x_lo} with
+// x_hi + x_lo == x exactly, where x_hi keeps roughly the upper half of the
+// significand bits of x and x_lo holds the remainder.
+template <typename Packet>
+EIGEN_STRONG_INLINE void veltkamp_splitting(const Packet& x, Packet& x_hi, Packet& x_lo) {
+ typedef typename unpacket_traits<Packet>::type Scalar;
+ // Split point s = ceil(digits / 2), digits being NumTraits<Scalar>::digits().
+ EIGEN_CONSTEXPR int split_bits = (NumTraits<Scalar>::digits() + 1) / 2;
+ // Veltkamp constant C = 2^s + 1, built at runtime (Scalar constructor not necessarily constexpr).
+ const Scalar veltkamp_c = Scalar(uint64_t(1) << split_bits) + Scalar(1);
+ const Packet cx = pmul(pset1<Packet>(veltkamp_c), x);
+ // x_hi = (x - C*x) + C*x; the roundings in these two steps drop the low bits of x.
+ x_hi = padd(psub(x, cx), cx);
+ x_lo = psub(x, x_hi);
+}
+
+// Dekker's product for targets without a single-instruction pmadd: returns
+// {p_hi, p_lo} with p_hi = fl(x * y) and x * y == p_hi + p_lo exactly, by
+// splitting both factors so that every partial product is exact.
+template <typename Packet>
+EIGEN_STRONG_INLINE void twoprod(const Packet& x, const Packet& y,
+ Packet& p_hi, Packet& p_lo) {
+ Packet xh, xl;
+ veltkamp_splitting(x, xh, xl);
+ Packet yh, yl;
+ veltkamp_splitting(y, yh, yl);
+ p_hi = pmul(x, y);
+ // Accumulate the rounding error of p_hi from the partial products,
+ // largest first; the order of these updates matters for exactness.
+ p_lo = pmadd(xh, yh, pnegate(p_hi));
+ p_lo = pmadd(xh, yl, p_lo);
+ p_lo = pmadd(xl, yh, p_lo);
+ p_lo = pmadd(xl, yl, p_lo);
+}
+
+#endif // EIGEN_HAS_SINGLE_INSTRUCTION_MADD
+
+
+// Dekker-style addition of two double-word numbers {x_hi, x_lo} and
+// {y_hi, y_lo} (Algorithm 5 in J.-M. Muller, "Elementary Functions",
+// 3rd edition, Birkh\"auser, 2016). Returns {s_hi, s_lo} approximating
+// x_hi + x_lo + y_hi + y_lo; the result is accurate but not exact in general,
+// since the sum of two double words need not be representable as one.
+// Unlike fast_twosum, no ordering of |x_hi| and |y_hi| is assumed.
+template <typename Packet>
+EIGEN_STRONG_INLINE
+void twosum(const Packet& x_hi, const Packet& x_lo,
+ const Packet& y_hi, const Packet& y_lo,
+ Packet& s_hi, Packet& s_lo) {
+ // Per lane, order the operands so that |big_hi| >= |small_hi| as fast_twosum requires.
+ const Packet x_is_bigger = pcmp_lt(pabs(y_hi), pabs(x_hi));
+ const Packet big_hi = pselect(x_is_bigger, x_hi, y_hi);
+ const Packet big_lo = pselect(x_is_bigger, x_lo, y_lo);
+ const Packet small_hi = pselect(x_is_bigger, y_hi, x_hi);
+ const Packet small_lo = pselect(x_is_bigger, y_lo, x_lo);
+
+ Packet r_hi, r_lo;
+ fast_twosum(big_hi, small_hi, r_hi, r_lo);
+ // Fold in the low-order words: (small_lo + r_lo) + big_lo.
+ const Packet s = padd(padd(small_lo, r_lo), big_lo);
+ fast_twosum(r_hi, s, s_hi, s_lo);
+}
+
+// Double-word + double-word addition that assumes |x_hi| >= |y_hi|,
+// skipping the magnitude comparison done by twosum. Returns {s_hi, s_lo}.
+template <typename Packet>
+EIGEN_STRONG_INLINE
+void fast_twosum(const Packet& x_hi, const Packet& x_lo,
+ const Packet& y_hi, const Packet& y_lo,
+ Packet& s_hi, Packet& s_lo) {
+ Packet hi_sum, hi_err;
+ fast_twosum(x_hi, y_hi, hi_sum, hi_err);
+ // Fold both low words into the error term, then renormalize.
+ fast_twosum(hi_sum, padd(padd(y_lo, hi_err), x_lo), s_hi, s_lo);
+}
+
+// Adds a single word x to the double word {y_hi, y_lo}, assuming
+// |x| >= |y_hi|. The result is returned as the double word {s_hi, s_lo}.
+template <typename Packet>
+EIGEN_STRONG_INLINE
+void fast_twosum(const Packet& x,
+ const Packet& y_hi, const Packet& y_lo,
+ Packet& s_hi, Packet& s_lo) {
+ Packet head, tail;
+ fast_twosum(x, y_hi, head, tail);
+ // Only y_lo remains to be folded into the error term.
+ const Packet tail_sum = padd(y_lo, tail);
+ fast_twosum(head, tail_sum, s_hi, s_lo);
+}
+
+// Multiplies the double word {x_hi, x_lo} by the single word y
+// (Algorithm 7 in J.-M. Muller, "Elementary Functions", 3rd edition,
+// Birkh\"auser, 2016). The returned double word {p_hi, p_lo} satisfies
+// p_hi + p_lo = (x_hi + x_lo) * y up to a relative error below
+// 2 * 2^{-2p}, where p is the number of significand bits of the
+// floating point type.
+template <typename Packet>
+EIGEN_STRONG_INLINE
+void twoprod(const Packet& x_hi, const Packet& x_lo, const Packet& y,
+ Packet& p_hi, Packet& p_lo) {
+ // Exact product of the high word with y.
+ Packet head, head_err;
+ twoprod(x_hi, y, head, head_err);
+ // The low word only needs a plain (rounded) product.
+ const Packet low_prod = pmul(x_lo, y);
+ Packet mid, mid_err;
+ fast_twosum(head, low_prod, mid, mid_err);
+ // Combine both error terms and renormalize.
+ fast_twosum(mid, padd(mid_err, head_err), p_hi, p_lo);
+}
+
+// Multiplies two double words {x_hi, x_lo} and {y_hi, y_lo} as
+// (x_hi + x_lo) * y_hi + (x_hi + x_lo) * y_lo, forming each partial product
+// with the double-word-by-word twoprod overload. {p_hi, p_lo} approximates the
+// product with a relative error below 2*2^{-2p} (p = significand bits).
+template <typename Packet>
+EIGEN_STRONG_INLINE
+void twoprod(const Packet& x_hi, const Packet& x_lo,
+ const Packet& y_hi, const Packet& y_lo,
+ Packet& p_hi, Packet& p_lo) {
+ Packet by_hi_hi, by_hi_lo; // (x_hi + x_lo) * y_hi
+ twoprod(x_hi, x_lo, y_hi, by_hi_hi, by_hi_lo);
+ Packet by_lo_hi, by_lo_lo; // (x_hi + x_lo) * y_lo
+ twoprod(x_hi, x_lo, y_lo, by_lo_hi, by_lo_lo);
+ // Assuming {y_hi, y_lo} is normalized (|y_lo| <= |y_hi|), the first partial
+ // product dominates, so the cheaper fast_twosum is used to combine them.
+ fast_twosum(by_hi_hi, by_hi_lo, by_lo_hi, by_lo_lo, p_hi, p_lo);
+}
+
+// Computes 1/x with extra precision and returns it as the double word
+// {recip_hi, recip_lo}: a working-precision estimate is refined by one
+// Newton-Raphson step carried out in double-word arithmetic.
+template <typename Packet>
+void doubleword_reciprocal(const Packet& x, Packet& recip_hi, Packet& recip_lo) {
+ typedef typename unpacket_traits<Packet>::type Scalar;
+ // Step 1: working-precision seed y ~= 1/x, formed as (1/sqrt(x))^2; its
+ // accuracy is whatever prsqrt provides for this Packet type.
+ const Packet rsqrt_x = prsqrt(x);
+ const Packet seed = pmul(rsqrt_x, rsqrt_x);
+
+ // Step 2: one Newton-Raphson step y' = y * (2 - x * y) for 1/x,
+ // with every intermediate kept as a double word.
+ // -x * y
+ Packet neg_xy_hi, neg_xy_lo;
+ twoprod(pnegate(x), seed, neg_xy_hi, neg_xy_lo);
+ // 2 - x * y (x * y ~= 1, so |2| >= |x * y| and fast_twosum applies)
+ Packet corr_hi, corr_lo;
+ fast_twosum(pset1<Packet>(Scalar(2)), neg_xy_hi, corr_hi, corr_lo);
+ Packet corr2_hi, corr2_lo;
+ fast_twosum(corr_hi, padd(corr_lo, neg_xy_lo), corr2_hi, corr2_lo);
+ // y * (2 - x * y)
+ twoprod(corr2_hi, corr2_lo, seed, recip_hi, recip_lo);
+}
+
+
+// Fallback log2(x) as a double word: plog2(x) in the high word, zero low word.
+template <typename Scalar>
+struct accurate_log2 {
+ template <typename Packet>
+ EIGEN_STRONG_INLINE
+ void operator()(const Packet& x, Packet& log2_x_hi, Packet& log2_x_lo) {
+ log2_x_hi = plog2(x); // working precision only for generic Scalar types
+ log2_x_lo = pzero(x); // no extra-precision correction available
+ }
+};
+
+// This specialization uses a more accurate algorithm to compute log2(x) for
+// floats in [1/sqrt(2);sqrt(2)] with a relative accuracy of ~6.42e-10.
+// This additional accuracy is needed to counter the error-magnification
+// inherent in multiplying by a potentially large exponent in pow(x,y).
+// The minimax polynomial used was calculated using the Sollya tool.
+// See sollya.org.
+template <>
+struct accurate_log2<float> {
+ template <typename Packet>
+ EIGEN_STRONG_INLINE
+ void operator()(const Packet& z, Packet& log2_x_hi, Packet& log2_x_lo) {
+ // The function log2(1+x)/x is approximated in the interval
+ // [1/sqrt(2)-1;sqrt(2)-1] by a degree 10 polynomial of the form
+ // Q(x) = (C0 + x * (C1 + x * (C2 + x * (C3 + x * P(x))))),
+ // where the degree 6 polynomial P(x) is evaluated in single precision,
+ // while the remaining 4 terms of Q(x), as well as the final multiplication by x
+ // to reconstruct log2(1+x) are evaluated in extra precision using
+ // double word arithmetic. C0 through C3 are extra precise constants
+ // stored as double words.
+ //
+ // The polynomial coefficients were calculated using Sollya commands:
+ // > n = 10;
+ // > f = log2(1+x)/x;
+ // > interval = [sqrt(0.5)-1;sqrt(2)-1];
+ // > p = fpminimax(f,n,[|double,double,double,double,single...|],interval,relative,floating);
+
+ const Packet p6 = pset1<Packet>( 9.703654795885e-2f);
+ const Packet p5 = pset1<Packet>(-0.1690667718648f);
+ const Packet p4 = pset1<Packet>( 0.1720575392246f);
+ const Packet p3 = pset1<Packet>(-0.1789081543684f);
+ const Packet p2 = pset1<Packet>( 0.2050433009862f);
+ const Packet p1 = pset1<Packet>(-0.2404672354459f);
+ const Packet p0 = pset1<Packet>( 0.2885761857032f);
+
+ const Packet C3_hi = pset1<Packet>(-0.360674142838f);
+ const Packet C3_lo = pset1<Packet>(-6.13283912543e-09f);
+ const Packet C2_hi = pset1<Packet>(0.480897903442f);
+ const Packet C2_lo = pset1<Packet>(-1.44861207474e-08f);
+ const Packet C1_hi = pset1<Packet>(-0.721347510815f);
+ const Packet C1_lo = pset1<Packet>(-4.84483164698e-09f);
+ const Packet C0_hi = pset1<Packet>(1.44269502163f);
+ const Packet C0_lo = pset1<Packet>(2.01711713999e-08f);
+ const Packet one = pset1<Packet>(1.0f);
+
+ const Packet x = psub(z, one); // so that log2(z) = log2(1 + x)
+ // Evaluate P(x) in working precision.
+ // We evaluate it in multiple parts to improve instruction level
+ // parallelism.
+ Packet x2 = pmul(x,x);
+ Packet p_even = pmadd(p6, x2, p4);
+ p_even = pmadd(p_even, x2, p2);
+ p_even = pmadd(p_even, x2, p0);
+ Packet p_odd = pmadd(p5, x2, p3);
+ p_odd = pmadd(p_odd, x2, p1);
+ Packet p = pmadd(p_odd, x, p_even);
+
+ // Now evaluate the low-order terms of Q(x) in double word precision.
+ // In the following, due to the alternating signs and the fact that
+ // |x| < sqrt(2)-1, we can assume that |C*_hi| >= q_i, and use
+ // fast_twosum instead of the slower twosum.
+ Packet q_hi, q_lo;
+ Packet t_hi, t_lo;
+ // C3 + x * p(x)
+ twoprod(p, x, t_hi, t_lo);
+ fast_twosum(C3_hi, C3_lo, t_hi, t_lo, q_hi, q_lo);
+ // C2 + x * (C3 + x * p(x))
+ twoprod(q_hi, q_lo, x, t_hi, t_lo);
+ fast_twosum(C2_hi, C2_lo, t_hi, t_lo, q_hi, q_lo);
+ // C1 + x * (C2 + ...)
+ twoprod(q_hi, q_lo, x, t_hi, t_lo);
+ fast_twosum(C1_hi, C1_lo, t_hi, t_lo, q_hi, q_lo);
+ // C0 + x * (C1 + ...) = Q(x)
+ twoprod(q_hi, q_lo, x, t_hi, t_lo);
+ fast_twosum(C0_hi, C0_lo, t_hi, t_lo, q_hi, q_lo);
+
+ // log2(z) ~= x * Q(x)
+ twoprod(q_hi, q_lo, x, log2_x_hi, log2_x_lo);
+ }
+};
+
+// This specialization uses a more accurate algorithm to compute log2(x) for
+// doubles in [1/sqrt(2);sqrt(2)] with a relative accuracy of ~1.27e-18.
+// This additional accuracy is needed to counter the error-magnification
+// inherent in multiplying by a potentially large exponent in pow(x,y).
+// The minimax polynomial used was calculated using the Sollya tool.
+// See sollya.org.
+
+template <>
+struct accurate_log2<double> {
+ template <typename Packet>
+ EIGEN_STRONG_INLINE
+ void operator()(const Packet& x, Packet& log2_x_hi, Packet& log2_x_lo) {
+ // We use a transformation of variables:
+ // r = c * (x-1) / (x+1),
+ // such that
+ // log2(x) = log2((1 + r/c) / (1 - r/c)) = f(r).
+ // The function f(r) can be approximated well using an odd polynomial
+ // of the form
+ // P(r) = ((Q(r^2) * r^2 + C) * r^2 + 1) * r,
+ // For the implementation of log2<double> here, Q is of degree 6 with
+ // coefficient represented in working precision (double), while C is a
+ // constant represented in extra precision as a double word to achieve
+ // full accuracy.
+ //
+ // The polynomial coefficients were computed by the Sollya script:
+ //
+ // c = 2 / log(2);
+ // trans = c * (x-1)/(x+1);
+ // itrans = (1+x/c)/(1-x/c);
+ // interval=[trans(sqrt(0.5)); trans(sqrt(2))];
+ // print(interval);
+ // f = log2(itrans(x));
+ // p=fpminimax(f,[|1,3,5,7,9,11,13,15,17|],[|1,DD,double...|],interval,relative,floating);
+ const Packet q12 = pset1<Packet>(2.87074255468000586e-9);
+ const Packet q10 = pset1<Packet>(2.38957980901884082e-8);
+ const Packet q8 = pset1<Packet>(2.31032094540014656e-7);
+ const Packet q6 = pset1<Packet>(2.27279857398537278e-6);
+ const Packet q4 = pset1<Packet>(2.31271023278625638e-5);
+ const Packet q2 = pset1<Packet>(2.47556738444535513e-4);
+ const Packet q0 = pset1<Packet>(2.88543873228900172e-3);
+ const Packet C_hi = pset1<Packet>(0.0400377511598501157);
+ const Packet C_lo = pset1<Packet>(-4.77726582251425391e-19);
+ const Packet one = pset1<Packet>(1.0);
+
+ const Packet cst_2_log2e_hi = pset1<Packet>(2.88539008177792677);
+ const Packet cst_2_log2e_lo = pset1<Packet>(4.07660016854549667e-17);
+ // c * (x - 1)
+ Packet num_hi, num_lo;
+ twoprod(cst_2_log2e_hi, cst_2_log2e_lo, psub(x, one), num_hi, num_lo);
+ // TODO(rmlarsen): Investigate if using the division algorithm by
+ // Muller et al. is faster/more accurate.
+ // 1 / (x + 1)
+ Packet denom_hi, denom_lo;
+ doubleword_reciprocal(padd(x, one), denom_hi, denom_lo);
+ // r = c * (x-1) / (x+1),
+ Packet r_hi, r_lo;
+ twoprod(num_hi, num_lo, denom_hi, denom_lo, r_hi, r_lo);
+ // r2 = r * r
+ Packet r2_hi, r2_lo;
+ twoprod(r_hi, r_lo, r_hi, r_lo, r2_hi, r2_lo);
+ // r4 = r2 * r2
+ Packet r4_hi, r4_lo;
+ twoprod(r2_hi, r2_lo, r2_hi, r2_lo, r4_hi, r4_lo);
+
+ // Evaluate Q(r^2) in working precision. We evaluate it in two parts
+ // (even and odd in r^2) to improve instruction level parallelism.
+ Packet q_even = pmadd(q12, r4_hi, q8);
+ Packet q_odd = pmadd(q10, r4_hi, q6);
+ q_even = pmadd(q_even, r4_hi, q4);
+ q_odd = pmadd(q_odd, r4_hi, q2);
+ q_even = pmadd(q_even, r4_hi, q0);
+ Packet q = pmadd(q_odd, r2_hi, q_even);
+
+ // Now evaluate the low order terms of P(r) in double word precision.
+ // In the following, due to the increasing magnitude of the coefficients
+ // and r being constrained to [-0.5, 0.5] we can use fast_twosum instead
+ // of the slower twosum.
+ // Q(r^2) * r^2
+ Packet p_hi, p_lo;
+ twoprod(r2_hi, r2_lo, q, p_hi, p_lo);
+ // Q(r^2) * r^2 + C
+ Packet p1_hi, p1_lo;
+ fast_twosum(C_hi, C_lo, p_hi, p_lo, p1_hi, p1_lo);
+ // (Q(r^2) * r^2 + C) * r^2
+ Packet p2_hi, p2_lo;
+ twoprod(r2_hi, r2_lo, p1_hi, p1_lo, p2_hi, p2_lo);
+ // ((Q(r^2) * r^2 + C) * r^2 + 1)
+ Packet p3_hi, p3_lo;
+ fast_twosum(one, p2_hi, p2_lo, p3_hi, p3_lo);
+
+ // log2(x) ~= ((Q(r^2) * r^2 + C) * r^2 + 1) * r
+ twoprod(p3_hi, p3_lo, r_hi, r_lo, log2_x_hi, log2_x_lo);
+ }
+};
+
+// Generic exp2(x) = 2^x, evaluated as exp(ln(2) * x) in working precision.
+template <typename Scalar>
+struct fast_accurate_exp2 {
+ template <typename Packet>
+ EIGEN_STRONG_INLINE
+ Packet operator()(const Packet& x) {
+ const Packet ln2 = pset1<Packet>(Scalar(EIGEN_LN2)); // TODO(rmlarsen): Add a pexp2 packetop.
+ return pexp(pmul(ln2, x));
+ }
+};
+
+// This specialization uses a faster algorithm to compute exp2(x) for floats
+// in [-0.5;0.5] (the range of r_z in generic_pow_impl) with a relative accuracy of 1 ulp.
+// The minimax polynomial used was calculated using the Sollya tool.
+// See sollya.org.
+template <>
+struct fast_accurate_exp2<float> {
+ template <typename Packet>
+ EIGEN_STRONG_INLINE
+ Packet operator()(const Packet& x) {
+ // This function approximates exp2(x) by a degree 6 polynomial of the form
+ // Q(x) = 1 + x * (C + x * P(x)), where the degree 4 polynomial P(x) is evaluated in
+ // single precision, and the remaining steps are evaluated with extra precision using
+ // double word arithmetic. C is an extra precise constant stored as a double word.
+ //
+ // The polynomial coefficients were calculated using Sollya commands:
+ // > n = 6;
+ // > f = 2^x;
+ // > interval = [-0.5;0.5];
+ // > p = fpminimax(f,n,[|1,double,single...|],interval,relative,floating);
+
+ const Packet p4 = pset1<Packet>(1.539513905e-4f);
+ const Packet p3 = pset1<Packet>(1.340007293e-3f);
+ const Packet p2 = pset1<Packet>(9.618283249e-3f);
+ const Packet p1 = pset1<Packet>(5.550328270e-2f);
+ const Packet p0 = pset1<Packet>(0.2402264923f);
+
+ const Packet C_hi = pset1<Packet>(0.6931471825f);
+ const Packet C_lo = pset1<Packet>(2.36836577e-08f);
+ const Packet one = pset1<Packet>(1.0f);
+
+ // Evaluate P(x) in working precision.
+ // We evaluate even and odd parts of the polynomial separately
+ // to gain some instruction level parallelism.
+ Packet x2 = pmul(x,x);
+ Packet p_even = pmadd(p4, x2, p2);
+ Packet p_odd = pmadd(p3, x2, p1);
+ p_even = pmadd(p_even, x2, p0);
+ Packet p = pmadd(p_odd, x, p_even);
+
+ // Evaluate the remaining terms of Q(x) with extra precision using
+ // double word arithmetic.
+ Packet p_hi, p_lo;
+ // x * p(x)
+ twoprod(p, x, p_hi, p_lo);
+ // C + x * p(x)
+ Packet q1_hi, q1_lo;
+ twosum(p_hi, p_lo, C_hi, C_lo, q1_hi, q1_lo);
+ // x * (C + x * p(x))
+ Packet q2_hi, q2_lo;
+ twoprod(q1_hi, q1_lo, x, q2_hi, q2_lo);
+ // 1 + x * (C + x * p(x))
+ Packet q3_hi, q3_lo;
+ // Since |q2_hi| <= sqrt(2)-1 < 1, we can use fast_twosum
+ // for adding it to unity here.
+ fast_twosum(one, q2_hi, q3_hi, q3_lo);
+ return padd(q3_hi, padd(q2_lo, q3_lo)); // collapse the double word to working precision
+ }
+};
+
+// Faster exp2(x) for doubles in [-0.5;0.5] with a relative accuracy of 1 ulp.
+// The minimax polynomial used was calculated using the Sollya tool.
+// See sollya.org.
+template <>
+struct fast_accurate_exp2<double> {
+ template <typename Packet>
+ EIGEN_STRONG_INLINE
+ Packet operator()(const Packet& x) {
+ // This function approximates exp2(x) by a degree 11 polynomial of the form
+ // Q(x) = 1 + x * (C + x * P(x)), where the degree 9 polynomial P(x) is evaluated in
+ // double (working) precision, and the remaining steps are evaluated with extra precision using
+ // double word arithmetic. C is an extra precise constant stored as a double word.
+ //
+ // The polynomial coefficients were calculated using Sollya commands:
+ // > n = 11;
+ // > f = 2^x;
+ // > interval = [-0.5;0.5];
+ // > p = fpminimax(f,n,[|1,DD,double...|],interval,relative,floating);
+
+ const Packet p9 = pset1<Packet>(4.431642109085495276e-10);
+ const Packet p8 = pset1<Packet>(7.073829923303358410e-9);
+ const Packet p7 = pset1<Packet>(1.017822306737031311e-7);
+ const Packet p6 = pset1<Packet>(1.321543498017646657e-6);
+ const Packet p5 = pset1<Packet>(1.525273342728892877e-5);
+ const Packet p4 = pset1<Packet>(1.540353045780084423e-4);
+ const Packet p3 = pset1<Packet>(1.333355814685869807e-3);
+ const Packet p2 = pset1<Packet>(9.618129107593478832e-3);
+ const Packet p1 = pset1<Packet>(5.550410866481961247e-2);
+ const Packet p0 = pset1<Packet>(0.240226506959101332);
+ const Packet C_hi = pset1<Packet>(0.693147180559945286);
+ const Packet C_lo = pset1<Packet>(4.81927865669806721e-17);
+ const Packet one = pset1<Packet>(1.0);
+
+ // Evaluate P(x) in working precision.
+ // We evaluate even and odd parts of the polynomial separately
+ // to gain some instruction level parallelism.
+ Packet x2 = pmul(x,x);
+ Packet p_even = pmadd(p8, x2, p6);
+ Packet p_odd = pmadd(p9, x2, p7);
+ p_even = pmadd(p_even, x2, p4);
+ p_odd = pmadd(p_odd, x2, p5);
+ p_even = pmadd(p_even, x2, p2);
+ p_odd = pmadd(p_odd, x2, p3);
+ p_even = pmadd(p_even, x2, p0);
+ p_odd = pmadd(p_odd, x2, p1);
+ Packet p = pmadd(p_odd, x, p_even);
+
+ // Evaluate the remaining terms of Q(x) with extra precision using
+ // double word arithmetic.
+ Packet p_hi, p_lo;
+ // x * p(x)
+ twoprod(p, x, p_hi, p_lo);
+ // C + x * p(x)
+ Packet q1_hi, q1_lo;
+ twosum(p_hi, p_lo, C_hi, C_lo, q1_hi, q1_lo);
+ // x * (C + x * p(x))
+ Packet q2_hi, q2_lo;
+ twoprod(q1_hi, q1_lo, x, q2_hi, q2_lo);
+ // 1 + x * (C + x * p(x))
+ Packet q3_hi, q3_lo;
+ // Since |q2_hi| <= sqrt(2)-1 < 1, we can use fast_twosum
+ // for adding it to unity here.
+ fast_twosum(one, q2_hi, q3_hi, q3_lo);
+ return padd(q3_hi, padd(q2_lo, q3_lo)); // collapse the double word to working precision
+ }
+};
+
+// This function implements the non-trivial case of pow(x,y) where x is
+// positive and y is (possibly) non-integer.
+// Formally, pow(x,y) = exp2(y * log2(x)), where exp2(x) is shorthand for 2^x.
+// TODO(rmlarsen): We should probably add this as a packet op 'ppow', to make it
+// easier to specialize or turn off for specific types and/or backends.
+template <typename Packet>
+EIGEN_STRONG_INLINE Packet generic_pow_impl(const Packet& x, const Packet& y) {
+ typedef typename unpacket_traits<Packet>::type Scalar;
+ // Split x into exponent e_x and mantissa m_x.
+ Packet e_x;
+ Packet m_x = pfrexp(x, e_x);
+
+ // Adjust m_x to lie in [1/sqrt(2):sqrt(2)] to minimize absolute error in log2(m_x).
+ EIGEN_CONSTEXPR Scalar sqrt_half = Scalar(0.70710678118654752440);
+ const Packet m_x_scale_mask = pcmp_lt(m_x, pset1<Packet>(sqrt_half));
+ m_x = pselect(m_x_scale_mask, pmul(pset1<Packet>(Scalar(2)), m_x), m_x);
+ e_x = pselect(m_x_scale_mask, psub(e_x, pset1<Packet>(Scalar(1))), e_x);
+
+ // Compute log2(m_x) with 6 extra bits of accuracy.
+ Packet rx_hi, rx_lo;
+ accurate_log2<Scalar>()(m_x, rx_hi, rx_lo);
+
+ // Compute the two terms {y * e_x, y * r_x} in f = y * log2(x) with doubled
+ // precision using double word arithmetic.
+ Packet f1_hi, f1_lo, f2_hi, f2_lo;
+ twoprod(e_x, y, f1_hi, f1_lo);
+ twoprod(rx_hi, rx_lo, y, f2_hi, f2_lo);
+ // Sum the two terms in f using double word arithmetic. We know
+ // that |e_x| > |log2(m_x)|, except for the case where e_x==0.
+ // This means that we can use fast_twosum(f1,f2).
+ // In the case e_x == 0, e_x * y = f1 = 0, so we don't lose any
+ // accuracy by violating the assumption of fast_twosum, because
+ // it's a no-op.
+ Packet f_hi, f_lo;
+ fast_twosum(f1_hi, f1_lo, f2_hi, f2_lo, f_hi, f_lo);
+
+ // Split f into integer and fractional parts.
+ Packet n_z, r_z;
+ absolute_split(f_hi, n_z, r_z);
+ r_z = padd(r_z, f_lo);
+ Packet n_r;
+ absolute_split(r_z, n_r, r_z);
+ n_z = padd(n_z, n_r);
+
+ // We now have an accurate split of f = n_z + r_z and can compute
+ // x^y = 2**(n_z + r_z) = exp2(r_z) * 2**n_z.
+ // Since r_z is in [-0.5;0.5], we compute the first factor to high accuracy
+ // using a specialized algorithm. Multiplication by the second factor can
+ // be done exactly using pldexp(), since it is an integer power of 2.
+ const Packet e_r = fast_accurate_exp2<Scalar>()(r_z);
+ return pldexp(e_r, n_z);
+}
+
+// Generic implementation of pow(x,y); special cases follow std::pow (but see x = -0 below).
+template<typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet generic_pow(const Packet& x, const Packet& y) {
+ typedef typename unpacket_traits<Packet>::type Scalar;
+
+ const Packet cst_pos_inf = pset1<Packet>(NumTraits<Scalar>::infinity());
+ const Packet cst_zero = pset1<Packet>(Scalar(0));
+ const Packet cst_one = pset1<Packet>(Scalar(1));
+ const Packet cst_nan = pset1<Packet>(NumTraits<Scalar>::quiet_NaN());
+
+ const Packet abs_x = pabs(x);
+ // Predicates for sign and magnitude of x.
+ const Packet x_is_zero = pcmp_eq(x, cst_zero);
+ const Packet x_is_neg = pcmp_lt(x, cst_zero);
+ const Packet abs_x_is_inf = pcmp_eq(abs_x, cst_pos_inf);
+ const Packet abs_x_is_one = pcmp_eq(abs_x, cst_one);
+ const Packet abs_x_is_gt_one = pcmp_lt(cst_one, abs_x);
+ const Packet abs_x_is_lt_one = pcmp_lt(abs_x, cst_one);
+ const Packet x_is_one = pandnot(abs_x_is_one, x_is_neg);
+ const Packet x_is_neg_one = pand(abs_x_is_one, x_is_neg);
+ const Packet x_is_nan = pandnot(ptrue(x), pcmp_eq(x, x));
+
+ // Predicates for sign and magnitude of y.
+ const Packet y_is_one = pcmp_eq(y, cst_one);
+ const Packet y_is_zero = pcmp_eq(y, cst_zero);
+ const Packet y_is_neg = pcmp_lt(y, cst_zero);
+ const Packet y_is_pos = pandnot(ptrue(y), por(y_is_zero, y_is_neg));
+ const Packet y_is_nan = pandnot(ptrue(y), pcmp_eq(y, y));
+ const Packet abs_y_is_inf = pcmp_eq(pabs(y), cst_pos_inf);
+ EIGEN_CONSTEXPR Scalar huge_exponent =
+ (NumTraits<Scalar>::max_exponent() * Scalar(EIGEN_LN2)) /
+ NumTraits<Scalar>::epsilon();
+ const Packet abs_y_is_huge = pcmp_le(pset1<Packet>(huge_exponent), pabs(y));
+
+ // Predicates for whether y is integer and/or even.
+ const Packet y_is_int = pcmp_eq(pfloor(y), y);
+ const Packet y_div_2 = pmul(y, pset1<Packet>(Scalar(0.5)));
+ const Packet y_is_even = pcmp_eq(pround(y_div_2), y_div_2);
+
+ // Predicates encoding special cases for the value of pow(x,y)
+ const Packet invalid_negative_x = pandnot(pandnot(pandnot(x_is_neg, abs_x_is_inf),
+ y_is_int),
+ abs_y_is_inf);
+ const Packet pow_is_one = por(por(x_is_one, y_is_zero),
+ pand(x_is_neg_one,
+ por(abs_y_is_inf, pandnot(y_is_even, invalid_negative_x))));
+ const Packet pow_is_nan = por(invalid_negative_x, por(x_is_nan, y_is_nan));
+ const Packet pow_is_zero = por(por(por(pand(x_is_zero, y_is_pos),
+ pand(abs_x_is_inf, y_is_neg)),
+ pand(pand(abs_x_is_lt_one, abs_y_is_huge),
+ y_is_pos)),
+ pand(pand(abs_x_is_gt_one, abs_y_is_huge),
+ y_is_neg));
+ const Packet pow_is_inf = por(por(por(pand(x_is_zero, y_is_neg),
+ pand(abs_x_is_inf, y_is_pos)),
+ pand(pand(abs_x_is_lt_one, abs_y_is_huge),
+ y_is_neg)),
+ pand(pand(abs_x_is_gt_one, abs_y_is_huge),
+ y_is_pos));
+ // Magnitude of the result: general case for positive x (or negative x with
+ // integer y), overridden by the zero / infinity special cases.
+ const Packet pow_abs = generic_pow_impl(abs_x, y);
+ const Packet abs_result = pselect(pow_is_inf, cst_pos_inf, pselect(pow_is_zero, cst_zero, pow_abs));
+ // Negative x with odd integer y negates the result, including x = -inf
+ // (pow(-inf, 3) = -inf, pow(-inf, -3) = -0). NOTE: x = -0 is not caught by
+ // x < 0, so pow(-0, odd) yields +0 / +inf instead of -0 / -inf.
+ const Packet negate_result = pand(x_is_neg, pandnot(y_is_int, y_is_even));
+ return pselect(y_is_one, x, pselect(pow_is_one, cst_one, pselect(pow_is_nan, cst_nan,
+ pselect(negate_result, pnegate(abs_result), abs_result))));
+}
+
+
+
+/* ppolevl (packet version of Cephes polevl, modified for Eigen)
+ *
+ * Evaluate polynomial
+ *
+ *
+ *
+ * SYNOPSIS:
+ *
+ * const int N = ...; // polynomial degree
+ * Scalar coef[N+1];
+ * Packet x, y;
+ *
+ * y = ppolevl<Packet, N>::run(x, coef);
+ *
+ *
+ * DESCRIPTION:
+ *
+ * Evaluates the polynomial of degree N
+ *
+ * y = C_0 + C_1 x + C_2 x^2 + ... + C_N x^N
+ *
+ * by Horner's scheme, one pmadd per degree.
+ *
+ * Coefficients are stored in reverse order (highest degree first):
+ *
+ * coef[0] = C_N, ..., coef[N] = C_0.
+ *
+ * Cephes' p1evl() variant, which assumes a leading coefficient of
+ * 1.0 omitted from the array, is not implemented by this struct.
+ *
+ * The Eigen implementation is templatized on N, so the recursion
+ * over the degree is resolved at compile time. For best speed, store
+ * coef as a const array (constexpr), e.g.
+ *
+ * const double coef[] = {1.0, 2.0, 3.0, ...};
+ *
+ */
+template <typename Packet, int N>
+struct ppolevl {
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const typename unpacket_traits<Packet>::type coeff[]) {
+ EIGEN_STATIC_ASSERT((N > 0), YOU_MADE_A_PROGRAMMING_MISTAKE);
+ // Horner step: p_N(x) = p_{N-1}(x) * x + coeff[N], where p_{N-1} uses coeff[0..N-1].
+ const Packet lower_terms = ppolevl<Packet, N - 1>::run(x, coeff);
+ return pmadd(lower_terms, x, pset1<Packet>(coeff[N]));
+ }
+};
+
+template <typename Packet>
+struct ppolevl<Packet, 0> {
+ // Degree 0 terminates the recursion: the polynomial is the constant coeff[0].
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& /*x*/, const typename unpacket_traits<Packet>::type coeff[]) {
+ return pset1<Packet>(coeff[0]);
+ }
+};
+
+/* pchebevl (packet version of Cephes chbevl, modified for Eigen)
+ *
+ * Evaluate Chebyshev series
+ *
+ * SYNOPSIS:
+ *
+ * const int N = ...; // number of coefficients
+ * Scalar coef[N];
+ * Packet x, y;
+ *
+ * y = pchebevl<Packet, N>::run(x, coef);
+ *
+ * DESCRIPTION:
+ *
+ * Evaluates the series
+ *
+ * y = sum'_{i=0}^{N-1} c_i T_i(x/2), with c_i = coef[N-1-i],
+ *
+ * of Chebyshev polynomials T_i at argument x/2, where the prime on the
+ * sum means that the zero-order term c_0 enters with weight 1/2.
+ *
+ * Coefficients are stored in reverse order, i.e. the zero
+ * order term is last in the array. Note N is the number of
+ * coefficients, not the order.
+ *
+ * If coefficients are for the interval a to b, x must
+ * have been transformed to x -> 2(2x - b - a)/(b-a) before
+ * entering the routine. This maps x from (a, b) to (-2, 2),
+ * so that x/2 lies in (-1, 1), over which the Chebyshev
+ * polynomials are defined.
+ *
+ * If the coefficients are for the inverted interval, in
+ * which (a, b) is mapped to (1/b, 1/a), the transformation
+ * required is x -> 2(2ab/x - b - a)/(b-a). If b is infinity,
+ * this becomes x -> 4a/x - 2.
+ *
+ *
+ *
+ * SPEED:
+ *
+ * Taking advantage of the recurrence properties of the
+ * Chebyshev polynomials, the routine requires one more
+ * addition per loop than evaluating a nested polynomial of
+ * the same degree.
+ *
+ */
+
+template <typename Packet, int N>
+struct pchebevl {
+ EIGEN_DEVICE_FUNC
+ static EIGEN_STRONG_INLINE Packet run(Packet x, const typename unpacket_traits<Packet>::type coef[]) {
+ typedef typename unpacket_traits<Packet>::type Scalar;
+ const Packet zero = pset1<Packet>(static_cast<Scalar>(0.f));
+ Packet b0 = pset1<Packet>(coef[0]);
+ Packet b1 = zero;
+ // b2 must start at zero as well: for N == 1 the loop body never runs and
+ // b2 is read directly by the final step (previously uninitialized).
+ Packet b2 = zero;
+
+ // Clenshaw recurrence: b0 <- x * b1 - b2 + coef[i].
+ for (int i = 1; i < N; i++) {
+ b2 = b1;
+ b1 = b0;
+ b0 = psub(pmadd(x, b1, pset1<Packet>(coef[i])), b2);
+ }
+
+ // Primed sum (zero-order term halved): y = (b0 - b2) / 2.
+ return pmul(pset1<Packet>(static_cast<Scalar>(0.5f)), psub(b0, b2));
+ }
+};
+
+} // end namespace internal
+} // end namespace Eigen
+
+#endif // EIGEN_ARCH_GENERIC_PACKET_MATH_FUNCTIONS_H
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h
new file mode 100644
index 0000000..177a04e
--- /dev/null
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h
@@ -0,0 +1,110 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2019 Gael Guennebaud <gael.guennebaud@inria.fr>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_ARCH_GENERIC_PACKET_MATH_FUNCTIONS_FWD_H
+#define EIGEN_ARCH_GENERIC_PACKET_MATH_FUNCTIONS_FWD_H
+
+namespace Eigen {
+namespace internal {
+
+// Forward declarations of the generic math functions
+// implemented in GenericPacketMathFunctions.h.
+// This is needed to work around a circular dependency.
+
+/***************************************************************************
+ * Some generic implementations to be used by implementers
+***************************************************************************/
+
+/** Default implementation of pfrexp.
+ * It is expected to be called by implementers of template<> pfrexp.
+ */
+template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
+Packet pfrexp_generic(const Packet& a, Packet& exponent);
+
+// Extracts the biased exponent value from Packet p, and casts the results to
+// a floating-point Packet type. Used by pfrexp_generic. Override this if
+// there is no unpacket_traits<Packet>::integer_packet.
+template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
+Packet pfrexp_generic_get_biased_exponent(const Packet& p);
+
+/** Default implementation of pldexp.
+ * It is expected to be called by implementers of template<> pldexp.
+ */
+template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
+Packet pldexp_generic(const Packet& a, const Packet& exponent);
+
+/** \internal \returns log(x) for single precision float */
+template <typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet plog_float(const Packet _x);
+
+/** \internal \returns log2(x) for single precision float */
+template <typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet plog2_float(const Packet _x);
+
+/** \internal \returns log(x) for double precision real numbers */
+template <typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet plog_double(const Packet _x);
+
+/** \internal \returns log2(x) for double precision real numbers */
+template <typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet plog2_double(const Packet _x);
+
+/** \internal \returns log(1 + x) */
+template<typename Packet>
+Packet generic_plog1p(const Packet& x);
+
+/** \internal \returns exp(x)-1 */
+template<typename Packet>
+Packet generic_expm1(const Packet& x);
+
+/** \internal \returns exp(x) for single precision float */
+template <typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet pexp_float(const Packet _x);
+
+/** \internal \returns exp(x) for double precision real numbers */
+template <typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet pexp_double(const Packet _x);
+
+/** \internal \returns sin(x) for single precision float */
+template<typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet psin_float(const Packet& x);
+
+/** \internal \returns cos(x) for single precision float */
+template<typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet pcos_float(const Packet& x);
+
+/** \internal \returns sqrt(x) for complex types */
+template<typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet psqrt_complex(const Packet& a);
+
+template <typename Packet, int N> struct ppolevl;  // presumably the polynomial evaluator defined in GenericPacketMathFunctions.h; confirm
+
+
+} // end namespace internal
+} // end namespace Eigen
+
+#endif // EIGEN_ARCH_GENERIC_PACKET_MATH_FUNCTIONS_FWD_H
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/Default/Half.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/Default/Half.h
new file mode 100644
index 0000000..9f8e8cc
--- /dev/null
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/Default/Half.h
@@ -0,0 +1,942 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+//
+// The conversion routines are Copyright (c) Fabian Giesen, 2016.
+// The original license follows:
+//
+// Copyright (c) Fabian Giesen, 2016
+// All rights reserved.
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted.
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+
+// Standard 16-bit float type, mostly useful for GPUs. Defines a new
+// type Eigen::half (inheriting either from CUDA's or HIP's __half struct) with
+// operator overloads such that it behaves basically as an arithmetic
+// type. It will be quite slow on CPUs (so it is recommended to stay
+// in fp32 for CPUs, except for simple parameter conversions, I/O
+// to disk and the likes), but fast on GPUs.
+
+
+#ifndef EIGEN_HALF_H
+#define EIGEN_HALF_H
+
+#include <sstream>
+
+#if defined(EIGEN_HAS_GPU_FP16) || defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
+// When compiling with GPU support, the "__half_raw" base class as well as
+// some other routines are defined in the GPU compiler header files
+// (cuda_fp16.h, hip_fp16.h), and they are not tagged constexpr.
+// As a consequence, we get compile failures when compiling Eigen with
+// GPU support, hence the need to disable EIGEN_CONSTEXPR in that case (the
+// #if above also applies this to Arm64 native fp16 scalar arithmetic builds).
+ #pragma push_macro("EIGEN_CONSTEXPR")
+ #undef EIGEN_CONSTEXPR
+ #define EIGEN_CONSTEXPR
+#endif
+// F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, METHOD) specializes METHOD for PACKET_F16: widen with half2float, call METHOD<PACKET_F>, narrow back with float2half.
+#define F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, METHOD) \
+ template <> \
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_UNUSED \
+ PACKET_F16 METHOD<PACKET_F16>(const PACKET_F16& _x) { \
+ return float2half(METHOD<PACKET_F>(half2float(_x))); \
+ }
+
+namespace Eigen {
+
+struct half;  // forward declaration; the definition follows the half_impl helpers below
+
+namespace half_impl {
+
+// We want to use the __half_raw struct from the HIP header file only during the device compile phase.
+// This is required because of a quirk in the way TensorFlow GPU builds are done.
+// When compiling TensorFlow source code with GPU support, files that
+// * contain GPU kernels (i.e. *.cu.cc files) are compiled via hipcc
+// * do not contain GPU kernels (i.e. *.cc files) are compiled via gcc (typically)
+//
+// Tensorflow uses the Eigen::half type as its FP16 type, and there are functions that
+// * are defined in a file that gets compiled via hipcc AND
+// * have Eigen::half as a pass-by-value argument AND
+// * are called in a file that gets compiled via gcc
+//
+// In the scenario described above the caller and callee will see different versions
+// of the Eigen::half base class __half_raw, and they will be compiled by different compilers
+//
+// There appears to be an ABI mismatch between gcc and clang (which is called by hipcc) that results in
+// the callee getting corrupted values for the Eigen::half argument.
+//
+// Making the host side compile phase of hipcc use the same Eigen::half impl as the gcc compile resolves
+// this error, and hence the following convoluted #if condition.
+#if !defined(EIGEN_HAS_GPU_FP16) || !defined(EIGEN_GPU_COMPILE_PHASE)
+// Make our own __half_raw definition that is similar to CUDA's.
+struct __half_raw {
+#if (defined(EIGEN_HAS_GPU_FP16) && !defined(EIGEN_GPU_COMPILE_PHASE))
+ // Eigen::half can be used as the datatype for shared memory declarations (in Eigen and TF)
+ // The element type for shared memory cannot have non-trivial constructors
+ // and hence the following special casing (which skips the zero-initialization).
+ // Note that this check gets done even in the host compilation phase, and
+ // hence the need for this.
+ EIGEN_DEVICE_FUNC __half_raw() {}
+#else
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw() : x(0) {}
+#endif
+#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
+ explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw(numext::uint16_t raw) : x(numext::bit_cast<__fp16>(raw)) {
+ }
+ __fp16 x;  // native Arm fp16 storage; the raw bits are reinterpreted via bit_cast
+#else
+ explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw(numext::uint16_t raw) : x(raw) {}
+ numext::uint16_t x;  // raw 16-bit pattern
+#endif
+};
+
+#elif defined(EIGEN_HAS_HIP_FP16)
+ // Nothing to do here
+ // HIP fp16 header file has a definition for __half_raw
+#elif defined(EIGEN_HAS_CUDA_FP16)
+ #if EIGEN_CUDA_SDK_VER < 90000
+ // In CUDA < 9.0, __half is the equivalent of CUDA 9's __half_raw
+ typedef __half __half_raw;
+ #endif // EIGEN_CUDA_SDK_VER < 90000
+#elif defined(SYCL_DEVICE_ONLY)
+ typedef cl::sycl::half __half_raw;
+#endif
+// Conversion helpers used by Eigen::half; they are defined further down in this header.
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw raw_uint16_to_half(numext::uint16_t x);
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __half_raw float_to_half_rtne(float ff);
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float half_to_float(__half_raw h);
+// Common base of Eigen::half; adds a constructor from the GPU toolkit's __half where available.
+struct half_base : public __half_raw {
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half_base() {}
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half_base(const __half_raw& h) : __half_raw(h) {}
+
+#if defined(EIGEN_HAS_GPU_FP16)
+ #if defined(EIGEN_HAS_HIP_FP16)
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half_base(const __half& h) { x = __half_as_ushort(h); }
+ #elif defined(EIGEN_HAS_CUDA_FP16)
+ #if EIGEN_CUDA_SDK_VER >= 90000
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half_base(const __half& h) : __half_raw(*(__half_raw*)&h) {}
+ #endif
+ #endif
+#endif
+};
+
+} // namespace half_impl
+
+// Class definition. The 16-bit payload is the x member inherited from __half_raw.
+struct half : public half_impl::half_base {
+
+ // Writing this out as separate #if-else blocks to make the code easier to follow
+ // The same applies to most #if-else blocks in this file
+#if !defined(EIGEN_HAS_GPU_FP16) || !defined(EIGEN_GPU_COMPILE_PHASE)
+ // Use the same base class for the following two scenarios
+ // * when compiling without GPU support enabled
+ // * during host compile phase when compiling with GPU support enabled
+ typedef half_impl::__half_raw __half_raw;
+#elif defined(EIGEN_HAS_HIP_FP16)
+ // Nothing to do here
+ // HIP fp16 header file has a definition for __half_raw
+#elif defined(EIGEN_HAS_CUDA_FP16)
+ // Note that EIGEN_CUDA_SDK_VER is set to 0 even when compiling with HIP, so
+ // (EIGEN_CUDA_SDK_VER < 90000) is true even for HIP! So keeping this within
+ // #if defined(EIGEN_HAS_CUDA_FP16) is needed
+ #if defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER < 90000
+ typedef half_impl::__half_raw __half_raw;
+ #endif
+#endif
+
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half() {}
+
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half(const __half_raw& h) : half_impl::half_base(h) {}
+
+#if defined(EIGEN_HAS_GPU_FP16)
+ #if defined(EIGEN_HAS_HIP_FP16)
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half(const __half& h) : half_impl::half_base(h) {}
+ #elif defined(EIGEN_HAS_CUDA_FP16)
+ #if defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER >= 90000
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half(const __half& h) : half_impl::half_base(h) {}
+ #endif
+ #endif
+#endif
+
+ // Value constructors: true maps to 0x3c00 (1.0); other values go through float and round-to-nearest-even.
+ explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half(bool b)
+ : half_impl::half_base(half_impl::raw_uint16_to_half(b ? 0x3c00 : 0)) {}
+ template<class T>
+ explicit EIGEN_DEVICE_FUNC half(T val)
+ : half_impl::half_base(half_impl::float_to_half_rtne(static_cast<float>(val))) {}
+ explicit EIGEN_DEVICE_FUNC half(float f)
+ : half_impl::half_base(half_impl::float_to_half_rtne(f)) {}
+
+ // Following the convention of numpy, converting between complex and
+ // float will lead to loss of imag value.
+ template<typename RealScalar>
+ explicit EIGEN_DEVICE_FUNC half(std::complex<RealScalar> c)
+ : half_impl::half_base(half_impl::float_to_half_rtne(static_cast<float>(c.real()))) {}
+
+ EIGEN_DEVICE_FUNC operator float() const { // NOLINT: Allow implicit conversion to float, because it is lossless.
+ return half_impl::half_to_float(*this);
+ }
+ // Conversion to the GPU toolkit's __half, provided only in the host compile phase of GPU builds.
+#if defined(EIGEN_HAS_GPU_FP16) && !defined(EIGEN_GPU_COMPILE_PHASE)
+ EIGEN_DEVICE_FUNC operator __half() const {
+ ::__half_raw hr;
+ hr.x = x;
+ return __half(hr);
+ }
+#endif
+};
+
+} // end namespace Eigen
+
+namespace std {
+template<>
+struct numeric_limits<Eigen::half> {  // IEEE 754 binary16: 1 sign bit, 5 exponent bits, 10 stored mantissa bits
+ static const bool is_specialized = true;
+ static const bool is_signed = true;
+ static const bool is_integer = false;
+ static const bool is_exact = false;
+ static const bool has_infinity = true;
+ static const bool has_quiet_NaN = true;
+ static const bool has_signaling_NaN = true;
+ static const float_denorm_style has_denorm = denorm_present;
+ static const bool has_denorm_loss = false;
+ static const std::float_round_style round_style = std::round_to_nearest;
+ static const bool is_iec559 = false;
+ static const bool is_bounded = true; // the set of representable values is finite, as for every floating-point type
+ static const bool is_modulo = false;
+ static const int digits = 11;
+ static const int digits10 = 3; // according to http://half.sourceforge.net/structstd_1_1numeric__limits_3_01half__float_1_1half_01_4.html
+ static const int max_digits10 = 5; // according to http://half.sourceforge.net/structstd_1_1numeric__limits_3_01half__float_1_1half_01_4.html
+ static const int radix = 2;
+ static const int min_exponent = -13;
+ static const int min_exponent10 = -4;
+ static const int max_exponent = 16;
+ static const int max_exponent10 = 4;
+ static const bool traps = numeric_limits<float>::traps; // emulated half arithmetic is carried out in float, so report float's behaviour
+ static const bool tinyness_before = false;
+ // Raw binary16 bit patterns: 0x0400 = 2^-14, 0x7bff = 65504, 0x1400 = 2^-10 = 2^(1 - digits), 0x0001 = 2^-24.
+ static Eigen::half (min)() { return Eigen::half_impl::raw_uint16_to_half(0x400); }
+ static Eigen::half lowest() { return Eigen::half_impl::raw_uint16_to_half(0xfbff); }
+ static Eigen::half (max)() { return Eigen::half_impl::raw_uint16_to_half(0x7bff); }
+ static Eigen::half epsilon() { return Eigen::half_impl::raw_uint16_to_half(0x1400); }
+ static Eigen::half round_error() { return Eigen::half(0.5); }
+ static Eigen::half infinity() { return Eigen::half_impl::raw_uint16_to_half(0x7c00); }
+ static Eigen::half quiet_NaN() { return Eigen::half_impl::raw_uint16_to_half(0x7e00); }
+ static Eigen::half signaling_NaN() { return Eigen::half_impl::raw_uint16_to_half(0x7d00); }
+ static Eigen::half denorm_min() { return Eigen::half_impl::raw_uint16_to_half(0x1); }
+};
+
+// If std::numeric_limits<T> is specialized, should also specialize
+// std::numeric_limits<const T>, std::numeric_limits<volatile T>, and
+// std::numeric_limits<const volatile T>
+// https://stackoverflow.com/a/16519653/
+template<>
+struct numeric_limits<const Eigen::half> : numeric_limits<Eigen::half> {};
+template<>
+struct numeric_limits<volatile Eigen::half> : numeric_limits<Eigen::half> {};
+template<>
+struct numeric_limits<const volatile Eigen::half> : numeric_limits<Eigen::half> {};
+} // end namespace std
+
+namespace Eigen {
+
+namespace half_impl {
+
+#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && \
+ EIGEN_CUDA_ARCH >= 530) || \
+ (defined(EIGEN_HAS_HIP_FP16) && defined(HIP_DEVICE_COMPILE))
+// Note: We deliberately do *not* define this even if we have Arm's native
+// fp16 type since GPU halfs are rather different from native CPU halfs.
+// TODO: Rename to something like EIGEN_HAS_NATIVE_GPU_FP16. NOTE(review): the HIP test above uses HIP_DEVICE_COMPILE, whereas the rest of this file uses EIGEN_HIP_DEVICE_COMPILE; confirm this is intended.
+#define EIGEN_HAS_NATIVE_FP16
+#endif // native GPU fp16 arithmetic (CUDA arch >= 530, or HIP device compile)
+
+// Intrinsics for native fp16 support. Note that on current hardware,
+// these are no faster than fp32 arithmetic (you need to use the half2
+// versions to get the ALU speed increased), but you do save the
+// conversion steps back and forth.
+// Compound assignments reuse the binary operators; comparisons map to __heq/__hne/__hlt/__hle/__hgt/__hge.
+#if defined(EIGEN_HAS_NATIVE_FP16)
+EIGEN_STRONG_INLINE __device__ half operator + (const half& a, const half& b) {
+#if defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER >= 90000
+ return __hadd(::__half(a), ::__half(b));
+#else
+ return __hadd(a, b);
+#endif
+}
+EIGEN_STRONG_INLINE __device__ half operator * (const half& a, const half& b) {
+ return __hmul(a, b);
+}
+EIGEN_STRONG_INLINE __device__ half operator - (const half& a, const half& b) {
+ return __hsub(a, b);
+}
+EIGEN_STRONG_INLINE __device__ half operator / (const half& a, const half& b) {
+#if defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER >= 90000
+ return __hdiv(a, b);
+#else
+ float num = __half2float(a);  // pre-CUDA-9 fallback: divide in float and convert back
+ float denom = __half2float(b);
+ return __float2half(num / denom);
+#endif
+}
+EIGEN_STRONG_INLINE __device__ half operator - (const half& a) {
+ return __hneg(a);
+}
+EIGEN_STRONG_INLINE __device__ half& operator += (half& a, const half& b) {
+ a = a + b;
+ return a;
+}
+EIGEN_STRONG_INLINE __device__ half& operator *= (half& a, const half& b) {
+ a = a * b;
+ return a;
+}
+EIGEN_STRONG_INLINE __device__ half& operator -= (half& a, const half& b) {
+ a = a - b;
+ return a;
+}
+EIGEN_STRONG_INLINE __device__ half& operator /= (half& a, const half& b) {
+ a = a / b;
+ return a;
+}
+EIGEN_STRONG_INLINE __device__ bool operator == (const half& a, const half& b) {
+ return __heq(a, b);
+}
+EIGEN_STRONG_INLINE __device__ bool operator != (const half& a, const half& b) {
+ return __hne(a, b);
+}
+EIGEN_STRONG_INLINE __device__ bool operator < (const half& a, const half& b) {
+ return __hlt(a, b);
+}
+EIGEN_STRONG_INLINE __device__ bool operator <= (const half& a, const half& b) {
+ return __hle(a, b);
+}
+EIGEN_STRONG_INLINE __device__ bool operator > (const half& a, const half& b) {
+ return __hgt(a, b);
+}
+EIGEN_STRONG_INLINE __device__ bool operator >= (const half& a, const half& b) {
+ return __hge(a, b);
+}
+#endif
+
+#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator + (const half& a, const half& b) {
+ return half(vaddh_f16(a.x, b.x));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator * (const half& a, const half& b) {
+ return half(vmulh_f16(a.x, b.x));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator - (const half& a, const half& b) {
+ return half(vsubh_f16(a.x, b.x));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator / (const half& a, const half& b) {
+ return half(vdivh_f16(a.x, b.x));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator - (const half& a) {
+ return half(vnegh_f16(a.x));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator += (half& a, const half& b) {
+ a = half(vaddh_f16(a.x, b.x));
+ return a;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator *= (half& a, const half& b) {
+ a = half(vmulh_f16(a.x, b.x));
+ return a;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator -= (half& a, const half& b) {
+ a = half(vsubh_f16(a.x, b.x));
+ return a;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator /= (half& a, const half& b) {
+ a = half(vdivh_f16(a.x, b.x));
+ return a;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator == (const half& a, const half& b) {
+ return vceqh_f16(a.x, b.x);
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator != (const half& a, const half& b) {
+ return !vceqh_f16(a.x, b.x);
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator < (const half& a, const half& b) {
+ return vclth_f16(a.x, b.x);
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator <= (const half& a, const half& b) {
+ return vcleh_f16(a.x, b.x);
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator > (const half& a, const half& b) {
+ return vcgth_f16(a.x, b.x);
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator >= (const half& a, const half& b) {
+ return vcgeh_f16(a.x, b.x);
+}
+// We need to distinguish 'clang as the CUDA compiler' from 'clang as the host compiler,
+// invoked by NVCC' (e.g. on MacOS). The former needs to see both host and device implementation
+// of the functions, while the latter can only deal with one of them.
+#elif !defined(EIGEN_HAS_NATIVE_FP16) || (EIGEN_COMP_CLANG && !EIGEN_COMP_NVCC) // Emulate support for half floats
+// Emulated path: operators convert to float, compute, and round back to half (unary minus only flips the sign bit).
+#if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC)
+// We need to provide emulated *host-side* FP16 operators for clang.
+#pragma push_macro("EIGEN_DEVICE_FUNC")
+#undef EIGEN_DEVICE_FUNC
+#if defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_HAS_NATIVE_FP16)
+#define EIGEN_DEVICE_FUNC __host__
+#else // both host and device need emulated ops.
+#define EIGEN_DEVICE_FUNC __host__ __device__
+#endif
+#endif
+// (When the push above happened, EIGEN_DEVICE_FUNC stays redefined until the matching pop below.)
+// Definitions for CPUs and older HIP+CUDA, mostly working through conversion
+// to/from fp32.
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator + (const half& a, const half& b) {
+ return half(float(a) + float(b));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator * (const half& a, const half& b) {
+ return half(float(a) * float(b));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator - (const half& a, const half& b) {
+ return half(float(a) - float(b));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator / (const half& a, const half& b) {
+ return half(float(a) / float(b));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator - (const half& a) {
+ half result;
+ result.x = a.x ^ 0x8000;
+ return result;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator += (half& a, const half& b) {
+ a = half(float(a) + float(b));
+ return a;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator *= (half& a, const half& b) {
+ a = half(float(a) * float(b));
+ return a;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator -= (half& a, const half& b) {
+ a = half(float(a) - float(b));
+ return a;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator /= (half& a, const half& b) {
+ a = half(float(a) / float(b));
+ return a;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator == (const half& a, const half& b) {
+ return numext::equal_strict(float(a),float(b));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator != (const half& a, const half& b) {
+ return numext::not_equal_strict(float(a), float(b));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator < (const half& a, const half& b) {
+ return float(a) < float(b);
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator <= (const half& a, const half& b) {
+ return float(a) <= float(b);
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator > (const half& a, const half& b) {
+ return float(a) > float(b);
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator >= (const half& a, const half& b) {
+ return float(a) >= float(b);
+}
+// Restore EIGEN_DEVICE_FUNC under exactly the condition that pushed it above, so push/pop always pair up.
+#if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC)
+#pragma pop_macro("EIGEN_DEVICE_FUNC")
+#endif
+#endif // Emulate support for half floats
+
+// half / Index: both operands are widened to float before dividing, so the
+// integer denominator is never rounded to half precision first.
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator / (const half& num, Index den) {
+ return half(float(num) / float(den));
+}
+
+// Prefix ++: adds one in place and returns the updated value (by copy).
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator++(half& value) {
+ return value += half(1);
+}
+
+// Prefix --: subtracts one in place and returns the updated value (by copy).
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator--(half& value) {
+ return value -= half(1);
+}
+// Postfix ++: bumps the operand and returns the value it held beforehand.
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator++(half& value, int) {
+ const half before = value;
+ value += half(1);
+ return before;
+}
+// Postfix --: decrements the operand and returns the value it held beforehand.
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator--(half& value, int) {
+ const half before = value;
+ value -= half(1);
+ return before;
+}
+
+// Conversion routines, including fallbacks for the host or older CUDA.
+// Note that newer Intel CPUs (Haswell or newer) have vectorized versions of
+// these in hardware. If we need more performance on older/other CPUs, they are
+// also possible to vectorize directly.
+// raw_uint16_to_half: wraps a raw 16-bit pattern in __half_raw without any numeric conversion.
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw raw_uint16_to_half(numext::uint16_t x) {
+ // We cannot simply do a "return __half_raw(x)" here, because __half_raw is a union type
+ // in the hip_fp16 header file, and that will trigger a compile error.
+ // On the other hand, having anything but a return statement also triggers a compile error
+ // because this is a constexpr function.
+ // Fortunately, since we need to disable EIGEN_CONSTEXPR for GPU anyway, we can get out
+ // of this catch-22 by having separate bodies for GPU / non GPU.
+#if defined(EIGEN_HAS_GPU_FP16)
+ __half_raw h;
+ h.x = x;
+ return h;
+#else
+ return __half_raw(x);
+#endif
+}
+// raw_half_as_uint16: the inverse of raw_uint16_to_half, returning the stored 16-bit pattern.
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC numext::uint16_t raw_half_as_uint16(const __half_raw& h) {
+ // HIP/CUDA/Default have a member 'x' of type uint16_t.
+ // For ARM64 native half, the member 'x' is of type __fp16, so we need to bit-cast.
+ // For SYCL, cl::sycl::half is _Float16, so bit-cast the whole value.
+#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
+ return numext::bit_cast<numext::uint16_t>(h.x);
+#elif defined(SYCL_DEVICE_ONLY)
+ return numext::bit_cast<numext::uint16_t>(h);
+#else
+ return h.x;
+#endif
+}
+// Scratch union for reinterpreting float bits in the portable conversions below.
+union float32_bits {  // NOTE(review): relies on compiler support for union type punning, which ISO C++ does not guarantee
+ unsigned int u;
+ float f;
+};
+
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __half_raw float_to_half_rtne(float ff) {  // float -> 16-bit half, round to nearest even
+#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \
+ (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
+ __half tmp_ff = __float2half(ff);
+ return *(__half_raw*)&tmp_ff;
+// x86 F16C path: rounding-control immediate 0 selects round-to-nearest-even.
+#elif defined(EIGEN_HAS_FP16_C)
+ __half_raw h;
+ h.x = _cvtss_sh(ff, 0);
+ return h;
+// Arm64 path: native float -> __fp16 conversion.
+#elif defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
+ __half_raw h;
+ h.x = static_cast<__fp16>(ff);
+ return h;
+// Portable path: integer manipulation of the single-precision bit pattern.
+#else
+ float32_bits f; f.f = ff;
+ // Thresholds expressed as float bit patterns (biased exponent << 23):
+ const float32_bits f32infty = { 255 << 23 }; // +inf
+ const float32_bits f16max = { (127 + 16) << 23 }; // 2^16: |ff| at or above this (or NaN) yields Inf/NaN
+ const float32_bits denorm_magic = { ((127 - 15) + (23 - 10) + 1) << 23 }; // 0.5f
+ unsigned int sign_mask = 0x80000000u;
+ __half_raw o;
+ o.x = static_cast<numext::uint16_t>(0x0u);
+ // Work on |ff|; the sign is OR-ed back into the result at the end.
+ unsigned int sign = f.u & sign_mask;
+ f.u ^= sign;
+
+ // NOTE all the integer compares in this function can be safely
+ // compiled into signed compares since all operands are below
+ // 0x80000000. Important if you want fast straight SSE2 code
+ // (since there's no unsigned PCMPGTD).
+
+ if (f.u >= f16max.u) { // result is Inf or NaN (all exponent bits set)
+ o.x = (f.u > f32infty.u) ? 0x7e00 : 0x7c00; // NaN->qNaN and Inf->Inf
+ } else { // (De)normalized number or zero
+ if (f.u < (113 << 23)) { // resulting FP16 is subnormal or zero
+ // use a magic value to align our 10 mantissa bits at the bottom of
+ // the float. as long as FP addition is round-to-nearest-even this
+ // just works.
+ f.f += denorm_magic.f;
+
+ // and one integer subtract of the bias later, we have our final half bits!
+ o.x = static_cast<numext::uint16_t>(f.u - denorm_magic.u);
+ } else {
+ unsigned int mant_odd = (f.u >> 13) & 1; // resulting mantissa is odd
+
+ // update exponent, rounding bias part 1
+ // Equivalent to `f.u += ((unsigned int)(15 - 127) << 23) + 0xfff`, but
+ // without arithmetic overflow.
+ f.u += 0xc8000fffU;
+ // rounding bias part 2
+ f.u += mant_odd;
+ // take the bits!
+ o.x = static_cast<numext::uint16_t>(f.u >> 13);
+ }
+ }
+
+ o.x |= static_cast<numext::uint16_t>(sign >> 16);
+ return o;
+#endif
+}
+
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float half_to_float(__half_raw h) {  // lossless widening of a 16-bit half to float
+#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \
+ (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
+ return __half2float(h);
+#elif defined(EIGEN_HAS_FP16_C)
+ return _cvtsh_ss(h.x);
+#elif defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
+ return static_cast<float>(h.x);
+#else
+ const float32_bits magic = { 113 << 23 }; // 2^-14 (smallest normal half) as a float
+ const unsigned int shifted_exp = 0x7c00 << 13; // exponent mask after shift
+ float32_bits o;
+ // Portable path: move exponent+mantissa into float position and rebias 15 -> 127.
+ o.u = (h.x & 0x7fff) << 13; // exponent/mantissa bits
+ unsigned int exp = shifted_exp & o.u; // just the exponent
+ o.u += (127 - 15) << 23; // exponent adjust
+
+ // handle exponent special cases
+ if (exp == shifted_exp) { // Inf/NaN?
+ o.u += (128 - 16) << 23; // extra exp adjust
+ } else if (exp == 0) { // Zero/Denormal?
+ o.u += 1 << 23; // extra exp adjust
+ o.f -= magic.f; // renormalize
+ }
+
+ o.u |= (h.x & 0x8000) << 16; // sign bit
+ return o.f;
+#endif
+}
+
+// --- standard functions ---
+// Classification inspects the 16-bit pattern with the sign bit masked off (& 0x7fff).
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isinf)(const half& a) {  // exponent bits all set, mantissa zero
+#ifdef EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC
+ return (numext::bit_cast<numext::uint16_t>(a.x) & 0x7fff) == 0x7c00;
+#else
+ return (a.x & 0x7fff) == 0x7c00;
+#endif
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isnan)(const half& a) {  // exponent bits all set, mantissa non-zero
+#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 530) || \
+ (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
+ return __hisnan(a);
+#elif defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
+ return (numext::bit_cast<numext::uint16_t>(a.x) & 0x7fff) > 0x7c00;
+#else
+ return (a.x & 0x7fff) > 0x7c00;
+#endif
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isfinite)(const half& a) {  // EIGEN_NOT_A_MACRO presumably blocks function-like macro expansion of isinf/isnan
+ return !(isinf EIGEN_NOT_A_MACRO (a)) && !(isnan EIGEN_NOT_A_MACRO (a));
+}
+
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half abs(const half& a) {  // generic path clears only the sign bit (NaN payloads preserved)
+#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
+ return half(vabsh_f16(a.x));
+#else
+ half result;
+ result.x = a.x & 0x7FFF;
+ return result;
+#endif
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half exp(const half& a) {  // NOTE(review): unlike log(), this GPU guard does not test defined(EIGEN_HAS_CUDA_FP16)
+#if (EIGEN_CUDA_SDK_VER >= 80000 && defined EIGEN_CUDA_ARCH && EIGEN_CUDA_ARCH >= 530) || \
+ defined(EIGEN_HIP_DEVICE_COMPILE)
+ return half(hexp(a));
+#else
+ return half(::expf(float(a)));
+#endif
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half expm1(const half& a) {  // evaluated in float, rounded back to half
+ return half(numext::expm1(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half log(const half& a) {
+#if (defined(EIGEN_HAS_CUDA_FP16) && EIGEN_CUDA_SDK_VER >= 80000 && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 530) || \
+ (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
+ return half(::hlog(a));
+#else
+ return half(::logf(float(a)));
+#endif
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half log1p(const half& a) {
+ return half(numext::log1p(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half log10(const half& a) {
+ return half(::log10f(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half log2(const half& a) {  // log2(a) = log2(e) * ln(a), evaluated in float
+ return half(static_cast<float>(EIGEN_LOG2E) * ::logf(float(a)));
+}
+// The functions below evaluate in float and round back to half, except where a GPU intrinsic path is selected.
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half sqrt(const half& a) {
+#if (EIGEN_CUDA_SDK_VER >= 80000 && defined EIGEN_CUDA_ARCH && EIGEN_CUDA_ARCH >= 530) || \
+ defined(EIGEN_HIP_DEVICE_COMPILE)
+ return half(hsqrt(a));
+#else
+ return half(::sqrtf(float(a)));
+#endif
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half pow(const half& a, const half& b) {
+ return half(::powf(float(a), float(b)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half sin(const half& a) {
+ return half(::sinf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half cos(const half& a) {
+ return half(::cosf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half tan(const half& a) {
+ return half(::tanf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half tanh(const half& a) {
+ return half(::tanhf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half asin(const half& a) {
+ return half(::asinf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half acos(const half& a) {
+ return half(::acosf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half floor(const half& a) {
+#if (EIGEN_CUDA_SDK_VER >= 80000 && defined EIGEN_CUDA_ARCH && EIGEN_CUDA_ARCH >= 300) || \
+ defined(EIGEN_HIP_DEVICE_COMPILE)
+ return half(hfloor(a));
+#else
+ return half(::floorf(float(a)));
+#endif
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half ceil(const half& a) {
+#if (EIGEN_CUDA_SDK_VER >= 80000 && defined EIGEN_CUDA_ARCH && EIGEN_CUDA_ARCH >= 300) || \
+ defined(EIGEN_HIP_DEVICE_COMPILE)
+ return half(hceil(a));
+#else
+ return half(::ceilf(float(a)));
+#endif
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half rint(const half& a) {  // follows the current floating-point rounding mode via ::rintf
+ return half(::rintf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half round(const half& a) {
+ return half(::roundf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half fmod(const half& a, const half& b) {
+ return half(::fmodf(float(a), float(b)));
+}
+
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half (min)(const half& a, const half& b) {
+#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 530) || \
+ (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
+ return __hlt(b, a) ? b : a;
+#else
+ const float f1 = static_cast<float>(a);
+ const float f2 = static_cast<float>(b);
+ return f2 < f1 ? b : a;
+#endif
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half (max)(const half& a, const half& b) {
+#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 530) || \
+ (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
+ return __hlt(a, b) ? b : a;
+#else
+ const float f1 = static_cast<float>(a);
+ const float f2 = static_cast<float>(b);
+ return f1 < f2 ? b : a;
+#endif
+}
+
+#ifndef EIGEN_NO_IO
+EIGEN_ALWAYS_INLINE std::ostream& operator << (std::ostream& os, const half& v) {
+ os << static_cast<float>(v);
+ return os;
+}
+#endif
+
+} // end namespace half_impl
+
+// import Eigen::half_impl::half into Eigen namespace
+// using half_impl::half;
+
+namespace internal {
+
+template<>
+struct random_default_impl<half, false, false>
+{
+ static inline half run(const half& x, const half& y)
+ {
+ return x + (y-x) * half(float(std::rand()) / float(RAND_MAX));
+ }
+ static inline half run()
+ {
+ return run(half(-1.f), half(1.f));
+ }
+};
+
+template<> struct is_arithmetic<half> { enum { value = true }; };
+
+} // end namespace internal
+
+template<> struct NumTraits<Eigen::half>
+ : GenericNumTraits<Eigen::half>
+{
+ enum {
+ IsSigned = true,
+ IsInteger = false,
+ IsComplex = false,
+ RequireInitialization = false
+ };
+
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half epsilon() {
+ return half_impl::raw_uint16_to_half(0x0800);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half dummy_precision() {
+ return half_impl::raw_uint16_to_half(0x211f); // Eigen::half(1e-2f);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half highest() {
+ return half_impl::raw_uint16_to_half(0x7bff);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half lowest() {
+ return half_impl::raw_uint16_to_half(0xfbff);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half infinity() {
+ return half_impl::raw_uint16_to_half(0x7c00);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half quiet_NaN() {
+ return half_impl::raw_uint16_to_half(0x7e00);
+ }
+};
+
+} // end namespace Eigen
+
+#if defined(EIGEN_HAS_GPU_FP16) || defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
+ #pragma pop_macro("EIGEN_CONSTEXPR")
+#endif
+
+namespace Eigen {
+namespace numext {
+
+#if defined(EIGEN_GPU_COMPILE_PHASE)
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(isnan)(const Eigen::half& h) {
+ return (half_impl::isnan)(h);
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(isinf)(const Eigen::half& h) {
+ return (half_impl::isinf)(h);
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(isfinite)(const Eigen::half& h) {
+ return (half_impl::isfinite)(h);
+}
+
+#endif
+
+template <>
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half bit_cast<Eigen::half, uint16_t>(const uint16_t& src) {
+ return Eigen::half(Eigen::half_impl::raw_uint16_to_half(src));
+}
+
+template <>
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint16_t bit_cast<uint16_t, Eigen::half>(const Eigen::half& src) {
+ return Eigen::half_impl::raw_half_as_uint16(src);
+}
+
+} // namespace numext
+} // namespace Eigen
+
+// Add the missing shfl* intrinsics.
+// The __shfl* functions are only valid on HIP or _CUDA_ARCH_ >= 300.
+// CUDA defines them for (__CUDA_ARCH__ >= 300 || !defined(__CUDA_ARCH__))
+//
+// HIP and CUDA prior to SDK 9.0 define
+// __shfl, __shfl_up, __shfl_down, __shfl_xor for int and float
+// CUDA since 9.0 deprecates those and instead defines
+// __shfl_sync, __shfl_up_sync, __shfl_down_sync, __shfl_xor_sync,
+// with native support for __half and __nv_bfloat16
+//
+// Note that the following are __device__ - only functions.
+#if (defined(EIGEN_CUDACC) && (!defined(EIGEN_CUDA_ARCH) || EIGEN_CUDA_ARCH >= 300)) \
+ || defined(EIGEN_HIPCC)
+
+#if defined(EIGEN_HAS_CUDA_FP16) && EIGEN_CUDA_SDK_VER >= 90000
+
+__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_sync(unsigned mask, Eigen::half var, int srcLane, int width=warpSize) {
+ const __half h = var;
+ return static_cast<Eigen::half>(__shfl_sync(mask, h, srcLane, width));
+}
+
+__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_up_sync(unsigned mask, Eigen::half var, unsigned int delta, int width=warpSize) {
+ const __half h = var;
+ return static_cast<Eigen::half>(__shfl_up_sync(mask, h, delta, width));
+}
+
+__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_down_sync(unsigned mask, Eigen::half var, unsigned int delta, int width=warpSize) {
+ const __half h = var;
+ return static_cast<Eigen::half>(__shfl_down_sync(mask, h, delta, width));
+}
+
+__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_xor_sync(unsigned mask, Eigen::half var, int laneMask, int width=warpSize) {
+ const __half h = var;
+ return static_cast<Eigen::half>(__shfl_xor_sync(mask, h, laneMask, width));
+}
+
+#else // HIP or CUDA SDK < 9.0
+
+__device__ EIGEN_STRONG_INLINE Eigen::half __shfl(Eigen::half var, int srcLane, int width=warpSize) {
+ const int ivar = static_cast<int>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(var));
+ return Eigen::numext::bit_cast<Eigen::half>(static_cast<Eigen::numext::uint16_t>(__shfl(ivar, srcLane, width)));
+}
+
+__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_up(Eigen::half var, unsigned int delta, int width=warpSize) {
+ const int ivar = static_cast<int>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(var));
+ return Eigen::numext::bit_cast<Eigen::half>(static_cast<Eigen::numext::uint16_t>(__shfl_up(ivar, delta, width)));
+}
+
+__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_down(Eigen::half var, unsigned int delta, int width=warpSize) {
+ const int ivar = static_cast<int>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(var));
+ return Eigen::numext::bit_cast<Eigen::half>(static_cast<Eigen::numext::uint16_t>(__shfl_down(ivar, delta, width)));
+}
+
+__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_xor(Eigen::half var, int laneMask, int width=warpSize) {
+ const int ivar = static_cast<int>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(var));
+ return Eigen::numext::bit_cast<Eigen::half>(static_cast<Eigen::numext::uint16_t>(__shfl_xor(ivar, laneMask, width)));
+}
+
+#endif // HIP vs CUDA
+#endif // __shfl*
+
+// ldg() has an overload for __half_raw, but we also need one for Eigen::half.
+#if (defined(EIGEN_CUDACC) && (!defined(EIGEN_CUDA_ARCH) || EIGEN_CUDA_ARCH >= 350)) \
+ || defined(EIGEN_HIPCC)
+EIGEN_STRONG_INLINE __device__ Eigen::half __ldg(const Eigen::half* ptr) {
+ return Eigen::half_impl::raw_uint16_to_half(__ldg(reinterpret_cast<const Eigen::numext::uint16_t*>(ptr)));
+}
+#endif // __ldg
+
+#if EIGEN_HAS_STD_HASH
+namespace std {
+template <>
+struct hash<Eigen::half> {
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::size_t operator()(const Eigen::half& a) const {
+ return static_cast<std::size_t>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(a));
+ }
+};
+} // end namespace std
+#endif
+
+#endif // EIGEN_HALF_H
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/Default/Settings.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/Default/Settings.h
index 097373c..a5c3ada 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/Default/Settings.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/Default/Settings.h
@@ -21,7 +21,7 @@
* it does not correspond to the number of iterations or the number of instructions
*/
#ifndef EIGEN_UNROLLING_LIMIT
-#define EIGEN_UNROLLING_LIMIT 100
+#define EIGEN_UNROLLING_LIMIT 110
#endif
/** Defines the threshold between a "small" and a "large" matrix.
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/Default/TypeCasting.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/Default/TypeCasting.h
new file mode 100644
index 0000000..fb8183b
--- /dev/null
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/Default/TypeCasting.h
@@ -0,0 +1,120 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2016 Benoit Steiner <benoit.steiner.goog@gmail.com>
+// Copyright (C) 2019 Rasmus Munk Larsen <rmlarsen@google.com>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_GENERIC_TYPE_CASTING_H
+#define EIGEN_GENERIC_TYPE_CASTING_H
+
+namespace Eigen {
+
+namespace internal {
+
+template<>
+struct scalar_cast_op<float, Eigen::half> {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
+ typedef Eigen::half result_type;
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half operator() (const float& a) const {
+ #if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \
+ (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
+ return __float2half(a);
+ #else
+ return Eigen::half(a);
+ #endif
+ }
+};
+
+template<>
+struct functor_traits<scalar_cast_op<float, Eigen::half> >
+{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
+
+
+template<>
+struct scalar_cast_op<int, Eigen::half> {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
+ typedef Eigen::half result_type;
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half operator() (const int& a) const {
+ #if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \
+ (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
+ return __float2half(static_cast<float>(a));
+ #else
+ return Eigen::half(static_cast<float>(a));
+ #endif
+ }
+};
+
+template<>
+struct functor_traits<scalar_cast_op<int, Eigen::half> >
+{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
+
+
+template<>
+struct scalar_cast_op<Eigen::half, float> {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
+ typedef float result_type;
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator() (const Eigen::half& a) const {
+ #if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \
+ (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
+ return __half2float(a);
+ #else
+ return static_cast<float>(a);
+ #endif
+ }
+};
+
+template<>
+struct functor_traits<scalar_cast_op<Eigen::half, float> >
+{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
+
+
+template<>
+struct scalar_cast_op<float, Eigen::bfloat16> {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
+ typedef Eigen::bfloat16 result_type;
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::bfloat16 operator() (const float& a) const {
+ return Eigen::bfloat16(a);
+ }
+};
+
+template<>
+struct functor_traits<scalar_cast_op<float, Eigen::bfloat16> >
+{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
+
+
+template<>
+struct scalar_cast_op<int, Eigen::bfloat16> {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
+ typedef Eigen::bfloat16 result_type;
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::bfloat16 operator() (const int& a) const {
+ return Eigen::bfloat16(static_cast<float>(a));
+ }
+};
+
+template<>
+struct functor_traits<scalar_cast_op<int, Eigen::bfloat16> >
+{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
+
+
+template<>
+struct scalar_cast_op<Eigen::bfloat16, float> {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
+ typedef float result_type;
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator() (const Eigen::bfloat16& a) const {
+ return static_cast<float>(a);
+ }
+};
+
+template<>
+struct functor_traits<scalar_cast_op<Eigen::bfloat16, float> >
+{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
+
+
+}
+}
+
+#endif // EIGEN_GENERIC_TYPE_CASTING_H
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/NEON/Complex.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/NEON/Complex.h
index 306a309..f40af7f 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/NEON/Complex.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/NEON/Complex.h
@@ -15,9 +15,10 @@
namespace internal {
-inline uint32x4_t p4ui_CONJ_XOR() {
+inline uint32x4_t p4ui_CONJ_XOR()
+{
// See bug 1325, clang fails to call vld1q_u64.
-#if EIGEN_COMP_CLANG
+#if EIGEN_COMP_CLANG || EIGEN_COMP_CASTXML
uint32x4_t ret = { 0x00000000, 0x80000000, 0x00000000, 0x80000000 };
return ret;
#else
@@ -26,61 +27,136 @@
#endif
}
-inline uint32x2_t p2ui_CONJ_XOR() {
+inline uint32x2_t p2ui_CONJ_XOR()
+{
static const uint32_t conj_XOR_DATA[] = { 0x00000000, 0x80000000 };
return vld1_u32( conj_XOR_DATA );
}
//---------- float ----------
+
+struct Packet1cf
+{
+ EIGEN_STRONG_INLINE Packet1cf() {}
+ EIGEN_STRONG_INLINE explicit Packet1cf(const Packet2f& a) : v(a) {}
+ Packet2f v;
+};
struct Packet2cf
{
EIGEN_STRONG_INLINE Packet2cf() {}
EIGEN_STRONG_INLINE explicit Packet2cf(const Packet4f& a) : v(a) {}
- Packet4f v;
+ Packet4f v;
};
-template<> struct packet_traits<std::complex<float> > : default_packet_traits
+template<> struct packet_traits<std::complex<float> > : default_packet_traits
{
typedef Packet2cf type;
- typedef Packet2cf half;
- enum {
+ typedef Packet1cf half;
+ enum
+ {
Vectorizable = 1,
AlignedOnScalar = 1,
size = 2,
- HasHalfPacket = 0,
+ HasHalfPacket = 1,
- HasAdd = 1,
- HasSub = 1,
- HasMul = 1,
- HasDiv = 1,
- HasNegate = 1,
- HasAbs = 0,
- HasAbs2 = 0,
- HasMin = 0,
- HasMax = 0,
+ HasAdd = 1,
+ HasSub = 1,
+ HasMul = 1,
+ HasDiv = 1,
+ HasNegate = 1,
+ HasAbs = 0,
+ HasAbs2 = 0,
+ HasMin = 0,
+ HasMax = 0,
HasSetLinear = 0
};
};
-template<> struct unpacket_traits<Packet2cf> { typedef std::complex<float> type; enum {size=2, alignment=Aligned16}; typedef Packet2cf half; };
-
-template<> EIGEN_STRONG_INLINE Packet2cf pset1<Packet2cf>(const std::complex<float>& from)
+template<> struct unpacket_traits<Packet1cf>
{
- float32x2_t r64;
- r64 = vld1_f32((const float *)&from);
+ typedef std::complex<float> type;
+ typedef Packet1cf half;
+ typedef Packet2f as_real;
+ enum
+ {
+ size = 1,
+ alignment = Aligned16,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false
+ };
+};
+template<> struct unpacket_traits<Packet2cf>
+{
+ typedef std::complex<float> type;
+ typedef Packet1cf half;
+ typedef Packet4f as_real;
+ enum
+ {
+ size = 2,
+ alignment = Aligned16,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false
+ };
+};
+template<> EIGEN_STRONG_INLINE Packet1cf pcast<float,Packet1cf>(const float& a)
+{ return Packet1cf(vset_lane_f32(a, vdup_n_f32(0.f), 0)); }
+template<> EIGEN_STRONG_INLINE Packet2cf pcast<Packet2f,Packet2cf>(const Packet2f& a)
+{ return Packet2cf(vreinterpretq_f32_u64(vmovl_u32(vreinterpret_u32_f32(a)))); }
+
+template<> EIGEN_STRONG_INLINE Packet1cf pset1<Packet1cf>(const std::complex<float>& from)
+{ return Packet1cf(vld1_f32(reinterpret_cast<const float*>(&from))); }
+template<> EIGEN_STRONG_INLINE Packet2cf pset1<Packet2cf>(const std::complex<float>& from)
+{
+ const float32x2_t r64 = vld1_f32(reinterpret_cast<const float*>(&from));
return Packet2cf(vcombine_f32(r64, r64));
}
-template<> EIGEN_STRONG_INLINE Packet2cf padd<Packet2cf>(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(padd<Packet4f>(a.v,b.v)); }
-template<> EIGEN_STRONG_INLINE Packet2cf psub<Packet2cf>(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(psub<Packet4f>(a.v,b.v)); }
+template<> EIGEN_STRONG_INLINE Packet1cf padd<Packet1cf>(const Packet1cf& a, const Packet1cf& b)
+{ return Packet1cf(padd<Packet2f>(a.v, b.v)); }
+template<> EIGEN_STRONG_INLINE Packet2cf padd<Packet2cf>(const Packet2cf& a, const Packet2cf& b)
+{ return Packet2cf(padd<Packet4f>(a.v, b.v)); }
+
+template<> EIGEN_STRONG_INLINE Packet1cf psub<Packet1cf>(const Packet1cf& a, const Packet1cf& b)
+{ return Packet1cf(psub<Packet2f>(a.v, b.v)); }
+template<> EIGEN_STRONG_INLINE Packet2cf psub<Packet2cf>(const Packet2cf& a, const Packet2cf& b)
+{ return Packet2cf(psub<Packet4f>(a.v, b.v)); }
+
+template<> EIGEN_STRONG_INLINE Packet1cf pnegate(const Packet1cf& a) { return Packet1cf(pnegate<Packet2f>(a.v)); }
template<> EIGEN_STRONG_INLINE Packet2cf pnegate(const Packet2cf& a) { return Packet2cf(pnegate<Packet4f>(a.v)); }
+
+template<> EIGEN_STRONG_INLINE Packet1cf pconj(const Packet1cf& a)
+{
+ const Packet2ui b = vreinterpret_u32_f32(a.v);
+ return Packet1cf(vreinterpret_f32_u32(veor_u32(b, p2ui_CONJ_XOR())));
+}
template<> EIGEN_STRONG_INLINE Packet2cf pconj(const Packet2cf& a)
{
- Packet4ui b = vreinterpretq_u32_f32(a.v);
+ const Packet4ui b = vreinterpretq_u32_f32(a.v);
return Packet2cf(vreinterpretq_f32_u32(veorq_u32(b, p4ui_CONJ_XOR())));
}
+template<> EIGEN_STRONG_INLINE Packet1cf pmul<Packet1cf>(const Packet1cf& a, const Packet1cf& b)
+{
+ Packet2f v1, v2;
+
+ // Get the real values of a | a1_re | a1_re |
+ v1 = vdup_lane_f32(a.v, 0);
+ // Get the imag values of a | a1_im | a1_im |
+ v2 = vdup_lane_f32(a.v, 1);
+ // Multiply the real a with b
+ v1 = vmul_f32(v1, b.v);
+ // Multiply the imag a with b
+ v2 = vmul_f32(v2, b.v);
+ // Conjugate v2
+ v2 = vreinterpret_f32_u32(veor_u32(vreinterpret_u32_f32(v2), p2ui_CONJ_XOR()));
+ // Swap real/imag elements in v2.
+ v2 = vrev64_f32(v2);
+ // Add and return the result
+ return Packet1cf(vadd_f32(v1, v2));
+}
template<> EIGEN_STRONG_INLINE Packet2cf pmul<Packet2cf>(const Packet2cf& a, const Packet2cf& b)
{
Packet4f v1, v2;
@@ -93,7 +169,7 @@
v1 = vmulq_f32(v1, b.v);
// Multiply the imag a with b
v2 = vmulq_f32(v2, b.v);
- // Conjugate v2
+ // Conjugate v2
v2 = vreinterpretq_f32_u32(veorq_u32(vreinterpretq_u32_f32(v2), p4ui_CONJ_XOR()));
// Swap real/imag elements in v2.
v2 = vrev64q_f32(v2);
@@ -101,98 +177,144 @@
return Packet2cf(vaddq_f32(v1, v2));
}
-template<> EIGEN_STRONG_INLINE Packet2cf pand <Packet2cf>(const Packet2cf& a, const Packet2cf& b)
+template<> EIGEN_STRONG_INLINE Packet1cf pcmp_eq(const Packet1cf& a, const Packet1cf& b)
{
- return Packet2cf(vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(a.v),vreinterpretq_u32_f32(b.v))));
+ // Compare real and imaginary parts of a and b to get the mask vector:
+ // [re(a[0])==re(b[0]), im(a[0])==im(b[0])]
+ Packet2f eq = pcmp_eq<Packet2f>(a.v, b.v);
+ // Swap real/imag elements in the mask in to get:
+ // [im(a[0])==im(b[0]), re(a[0])==re(b[0])]
+ Packet2f eq_swapped = vrev64_f32(eq);
+ // Return re(a)==re(b) && im(a)==im(b) by computing bitwise AND of eq and eq_swapped
+ return Packet1cf(pand<Packet2f>(eq, eq_swapped));
}
-template<> EIGEN_STRONG_INLINE Packet2cf por <Packet2cf>(const Packet2cf& a, const Packet2cf& b)
+template<> EIGEN_STRONG_INLINE Packet2cf pcmp_eq(const Packet2cf& a, const Packet2cf& b)
{
- return Packet2cf(vreinterpretq_f32_u32(vorrq_u32(vreinterpretq_u32_f32(a.v),vreinterpretq_u32_f32(b.v))));
+ // Compare real and imaginary parts of a and b to get the mask vector:
+ // [re(a[0])==re(b[0]), im(a[0])==im(b[0]), re(a[1])==re(b[1]), im(a[1])==im(b[1])]
+ Packet4f eq = pcmp_eq<Packet4f>(a.v, b.v);
+ // Swap real/imag elements in the mask in to get:
+ // [im(a[0])==im(b[0]), re(a[0])==re(b[0]), im(a[1])==im(b[1]), re(a[1])==re(b[1])]
+ Packet4f eq_swapped = vrev64q_f32(eq);
+ // Return re(a)==re(b) && im(a)==im(b) by computing bitwise AND of eq and eq_swapped
+ return Packet2cf(pand<Packet4f>(eq, eq_swapped));
}
-template<> EIGEN_STRONG_INLINE Packet2cf pxor <Packet2cf>(const Packet2cf& a, const Packet2cf& b)
-{
- return Packet2cf(vreinterpretq_f32_u32(veorq_u32(vreinterpretq_u32_f32(a.v),vreinterpretq_u32_f32(b.v))));
-}
+
+template<> EIGEN_STRONG_INLINE Packet1cf pand<Packet1cf>(const Packet1cf& a, const Packet1cf& b)
+{ return Packet1cf(vreinterpret_f32_u32(vand_u32(vreinterpret_u32_f32(a.v), vreinterpret_u32_f32(b.v)))); }
+template<> EIGEN_STRONG_INLINE Packet2cf pand<Packet2cf>(const Packet2cf& a, const Packet2cf& b)
+{ return Packet2cf(vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(a.v), vreinterpretq_u32_f32(b.v)))); }
+
+template<> EIGEN_STRONG_INLINE Packet1cf por<Packet1cf>(const Packet1cf& a, const Packet1cf& b)
+{ return Packet1cf(vreinterpret_f32_u32(vorr_u32(vreinterpret_u32_f32(a.v), vreinterpret_u32_f32(b.v)))); }
+template<> EIGEN_STRONG_INLINE Packet2cf por<Packet2cf>(const Packet2cf& a, const Packet2cf& b)
+{ return Packet2cf(vreinterpretq_f32_u32(vorrq_u32(vreinterpretq_u32_f32(a.v), vreinterpretq_u32_f32(b.v)))); }
+
+template<> EIGEN_STRONG_INLINE Packet1cf pxor<Packet1cf>(const Packet1cf& a, const Packet1cf& b)
+{ return Packet1cf(vreinterpret_f32_u32(veor_u32(vreinterpret_u32_f32(a.v), vreinterpret_u32_f32(b.v)))); }
+template<> EIGEN_STRONG_INLINE Packet2cf pxor<Packet2cf>(const Packet2cf& a, const Packet2cf& b)
+{ return Packet2cf(vreinterpretq_f32_u32(veorq_u32(vreinterpretq_u32_f32(a.v), vreinterpretq_u32_f32(b.v)))); }
+
+template<> EIGEN_STRONG_INLINE Packet1cf pandnot<Packet1cf>(const Packet1cf& a, const Packet1cf& b)
+{ return Packet1cf(vreinterpret_f32_u32(vbic_u32(vreinterpret_u32_f32(a.v), vreinterpret_u32_f32(b.v)))); }
template<> EIGEN_STRONG_INLINE Packet2cf pandnot<Packet2cf>(const Packet2cf& a, const Packet2cf& b)
+{ return Packet2cf(vreinterpretq_f32_u32(vbicq_u32(vreinterpretq_u32_f32(a.v), vreinterpretq_u32_f32(b.v)))); }
+
+template<> EIGEN_STRONG_INLINE Packet1cf pload<Packet1cf>(const std::complex<float>* from)
+{ EIGEN_DEBUG_ALIGNED_LOAD return Packet1cf(pload<Packet2f>((const float*)from)); }
+template<> EIGEN_STRONG_INLINE Packet2cf pload<Packet2cf>(const std::complex<float>* from)
+{ EIGEN_DEBUG_ALIGNED_LOAD return Packet2cf(pload<Packet4f>(reinterpret_cast<const float*>(from))); }
+
+template<> EIGEN_STRONG_INLINE Packet1cf ploadu<Packet1cf>(const std::complex<float>* from)
+{ EIGEN_DEBUG_UNALIGNED_LOAD return Packet1cf(ploadu<Packet2f>((const float*)from)); }
+template<> EIGEN_STRONG_INLINE Packet2cf ploadu<Packet2cf>(const std::complex<float>* from)
+{ EIGEN_DEBUG_UNALIGNED_LOAD return Packet2cf(ploadu<Packet4f>(reinterpret_cast<const float*>(from))); }
+
+template<> EIGEN_STRONG_INLINE Packet1cf ploaddup<Packet1cf>(const std::complex<float>* from)
+{ return pset1<Packet1cf>(*from); }
+template<> EIGEN_STRONG_INLINE Packet2cf ploaddup<Packet2cf>(const std::complex<float>* from)
+{ return pset1<Packet2cf>(*from); }
+
+template<> EIGEN_STRONG_INLINE void pstore <std::complex<float> >(std::complex<float> *to, const Packet1cf& from)
+{ EIGEN_DEBUG_ALIGNED_STORE pstore((float*)to, from.v); }
+template<> EIGEN_STRONG_INLINE void pstore <std::complex<float> >(std::complex<float> *to, const Packet2cf& from)
+{ EIGEN_DEBUG_ALIGNED_STORE pstore(reinterpret_cast<float*>(to), from.v); }
+
+template<> EIGEN_STRONG_INLINE void pstoreu<std::complex<float> >(std::complex<float> *to, const Packet1cf& from)
+{ EIGEN_DEBUG_UNALIGNED_STORE pstoreu((float*)to, from.v); }
+template<> EIGEN_STRONG_INLINE void pstoreu<std::complex<float> >(std::complex<float> *to, const Packet2cf& from)
+{ EIGEN_DEBUG_UNALIGNED_STORE pstoreu(reinterpret_cast<float*>(to), from.v); }
+
+template<> EIGEN_DEVICE_FUNC inline Packet1cf pgather<std::complex<float>, Packet1cf>(
+ const std::complex<float>* from, Index stride)
{
- return Packet2cf(vreinterpretq_f32_u32(vbicq_u32(vreinterpretq_u32_f32(a.v),vreinterpretq_u32_f32(b.v))));
+ const Packet2f tmp = vdup_n_f32(std::real(from[0*stride]));
+ return Packet1cf(vset_lane_f32(std::imag(from[0*stride]), tmp, 1));
}
-
-template<> EIGEN_STRONG_INLINE Packet2cf pload<Packet2cf>(const std::complex<float>* from) { EIGEN_DEBUG_ALIGNED_LOAD return Packet2cf(pload<Packet4f>((const float*)from)); }
-template<> EIGEN_STRONG_INLINE Packet2cf ploadu<Packet2cf>(const std::complex<float>* from) { EIGEN_DEBUG_UNALIGNED_LOAD return Packet2cf(ploadu<Packet4f>((const float*)from)); }
-
-template<> EIGEN_STRONG_INLINE Packet2cf ploaddup<Packet2cf>(const std::complex<float>* from) { return pset1<Packet2cf>(*from); }
-
-template<> EIGEN_STRONG_INLINE void pstore <std::complex<float> >(std::complex<float> * to, const Packet2cf& from) { EIGEN_DEBUG_ALIGNED_STORE pstore((float*)to, from.v); }
-template<> EIGEN_STRONG_INLINE void pstoreu<std::complex<float> >(std::complex<float> * to, const Packet2cf& from) { EIGEN_DEBUG_UNALIGNED_STORE pstoreu((float*)to, from.v); }
-
-template<> EIGEN_DEVICE_FUNC inline Packet2cf pgather<std::complex<float>, Packet2cf>(const std::complex<float>* from, Index stride)
+template<> EIGEN_DEVICE_FUNC inline Packet2cf pgather<std::complex<float>, Packet2cf>(
+ const std::complex<float>* from, Index stride)
{
- Packet4f res = pset1<Packet4f>(0.f);
- res = vsetq_lane_f32(std::real(from[0*stride]), res, 0);
+ Packet4f res = vdupq_n_f32(std::real(from[0*stride]));
res = vsetq_lane_f32(std::imag(from[0*stride]), res, 1);
res = vsetq_lane_f32(std::real(from[1*stride]), res, 2);
res = vsetq_lane_f32(std::imag(from[1*stride]), res, 3);
return Packet2cf(res);
}
-template<> EIGEN_DEVICE_FUNC inline void pscatter<std::complex<float>, Packet2cf>(std::complex<float>* to, const Packet2cf& from, Index stride)
+template<> EIGEN_DEVICE_FUNC inline void pscatter<std::complex<float>, Packet1cf>(
+ std::complex<float>* to, const Packet1cf& from, Index stride)
+{ to[stride*0] = std::complex<float>(vget_lane_f32(from.v, 0), vget_lane_f32(from.v, 1)); }
+template<> EIGEN_DEVICE_FUNC inline void pscatter<std::complex<float>, Packet2cf>(
+ std::complex<float>* to, const Packet2cf& from, Index stride)
{
to[stride*0] = std::complex<float>(vgetq_lane_f32(from.v, 0), vgetq_lane_f32(from.v, 1));
to[stride*1] = std::complex<float>(vgetq_lane_f32(from.v, 2), vgetq_lane_f32(from.v, 3));
}
-template<> EIGEN_STRONG_INLINE void prefetch<std::complex<float> >(const std::complex<float> * addr) { EIGEN_ARM_PREFETCH((const float *)addr); }
+template<> EIGEN_STRONG_INLINE void prefetch<std::complex<float> >(const std::complex<float> *addr)
+{ EIGEN_ARM_PREFETCH(reinterpret_cast<const float*>(addr)); }
-template<> EIGEN_STRONG_INLINE std::complex<float> pfirst<Packet2cf>(const Packet2cf& a)
+template<> EIGEN_STRONG_INLINE std::complex<float> pfirst<Packet1cf>(const Packet1cf& a)
{
- std::complex<float> EIGEN_ALIGN16 x[2];
- vst1q_f32((float *)x, a.v);
+ EIGEN_ALIGN16 std::complex<float> x;
+ vst1_f32(reinterpret_cast<float*>(&x), a.v);
+ return x;
+}
+template<> EIGEN_STRONG_INLINE std::complex<float> pfirst<Packet2cf>(const Packet2cf& a)
+{
+ EIGEN_ALIGN16 std::complex<float> x[2];
+ vst1q_f32(reinterpret_cast<float*>(x), a.v);
return x[0];
}
+template<> EIGEN_STRONG_INLINE Packet1cf preverse(const Packet1cf& a) { return a; }
template<> EIGEN_STRONG_INLINE Packet2cf preverse(const Packet2cf& a)
-{
- float32x2_t a_lo, a_hi;
- Packet4f a_r128;
+{ return Packet2cf(vcombine_f32(vget_high_f32(a.v), vget_low_f32(a.v))); }
- a_lo = vget_low_f32(a.v);
- a_hi = vget_high_f32(a.v);
- a_r128 = vcombine_f32(a_hi, a_lo);
-
- return Packet2cf(a_r128);
-}
-
+template<> EIGEN_STRONG_INLINE Packet1cf pcplxflip<Packet1cf>(const Packet1cf& a)
+{ return Packet1cf(vrev64_f32(a.v)); }
template<> EIGEN_STRONG_INLINE Packet2cf pcplxflip<Packet2cf>(const Packet2cf& a)
-{
- return Packet2cf(vrev64q_f32(a.v));
-}
+{ return Packet2cf(vrev64q_f32(a.v)); }
+template<> EIGEN_STRONG_INLINE std::complex<float> predux<Packet1cf>(const Packet1cf& a)
+{
+ std::complex<float> s;
+ vst1_f32((float *)&s, a.v);
+ return s;
+}
template<> EIGEN_STRONG_INLINE std::complex<float> predux<Packet2cf>(const Packet2cf& a)
{
- float32x2_t a1, a2;
std::complex<float> s;
-
- a1 = vget_low_f32(a.v);
- a2 = vget_high_f32(a.v);
- a2 = vadd_f32(a1, a2);
- vst1_f32((float *)&s, a2);
-
+ vst1_f32(reinterpret_cast<float*>(&s), vadd_f32(vget_low_f32(a.v), vget_high_f32(a.v)));
return s;
}
-template<> EIGEN_STRONG_INLINE Packet2cf preduxp<Packet2cf>(const Packet2cf* vecs)
+template<> EIGEN_STRONG_INLINE std::complex<float> predux_mul<Packet1cf>(const Packet1cf& a)
{
- Packet4f sum1, sum2, sum;
-
- // Add the first two 64-bit float32x2_t of vecs[0]
- sum1 = vcombine_f32(vget_low_f32(vecs[0].v), vget_low_f32(vecs[1].v));
- sum2 = vcombine_f32(vget_high_f32(vecs[0].v), vget_high_f32(vecs[1].v));
- sum = vaddq_f32(sum1, sum2);
-
- return Packet2cf(sum);
+ std::complex<float> s;
+ vst1_f32((float *)&s, a.v);
+ return s;
}
-
template<> EIGEN_STRONG_INLINE std::complex<float> predux_mul<Packet2cf>(const Packet2cf& a)
{
float32x2_t a1, a2, v1, v2, prod;
@@ -208,90 +330,67 @@
v1 = vmul_f32(v1, a2);
// Multiply the imag a with b
v2 = vmul_f32(v2, a2);
- // Conjugate v2
+ // Conjugate v2
v2 = vreinterpret_f32_u32(veor_u32(vreinterpret_u32_f32(v2), p2ui_CONJ_XOR()));
// Swap real/imag elements in v2.
v2 = vrev64_f32(v2);
// Add v1, v2
prod = vadd_f32(v1, v2);
- vst1_f32((float *)&s, prod);
+ vst1_f32(reinterpret_cast<float*>(&s), prod);
return s;
}
-template<int Offset>
-struct palign_impl<Offset,Packet2cf>
-{
- EIGEN_STRONG_INLINE static void run(Packet2cf& first, const Packet2cf& second)
- {
- if (Offset==1)
- {
- first.v = vextq_f32(first.v, second.v, 2);
- }
- }
-};
-
-template<> struct conj_helper<Packet2cf, Packet2cf, false,true>
-{
- EIGEN_STRONG_INLINE Packet2cf pmadd(const Packet2cf& x, const Packet2cf& y, const Packet2cf& c) const
- { return padd(pmul(x,y),c); }
-
- EIGEN_STRONG_INLINE Packet2cf pmul(const Packet2cf& a, const Packet2cf& b) const
- {
- return internal::pmul(a, pconj(b));
- }
-};
-
-template<> struct conj_helper<Packet2cf, Packet2cf, true,false>
-{
- EIGEN_STRONG_INLINE Packet2cf pmadd(const Packet2cf& x, const Packet2cf& y, const Packet2cf& c) const
- { return padd(pmul(x,y),c); }
-
- EIGEN_STRONG_INLINE Packet2cf pmul(const Packet2cf& a, const Packet2cf& b) const
- {
- return internal::pmul(pconj(a), b);
- }
-};
-
-template<> struct conj_helper<Packet2cf, Packet2cf, true,true>
-{
- EIGEN_STRONG_INLINE Packet2cf pmadd(const Packet2cf& x, const Packet2cf& y, const Packet2cf& c) const
- { return padd(pmul(x,y),c); }
-
- EIGEN_STRONG_INLINE Packet2cf pmul(const Packet2cf& a, const Packet2cf& b) const
- {
- return pconj(internal::pmul(a, b));
- }
-};
-
+EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet1cf,Packet2f)
EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet2cf,Packet4f)
+template<> EIGEN_STRONG_INLINE Packet1cf pdiv<Packet1cf>(const Packet1cf& a, const Packet1cf& b)
+{
+ // TODO optimize it for NEON
+ Packet1cf res = pmul(a, pconj(b));
+ Packet2f s, rev_s;
+
+ // this computes the norm
+ s = vmul_f32(b.v, b.v);
+ rev_s = vrev64_f32(s);
+
+ return Packet1cf(pdiv<Packet2f>(res.v, vadd_f32(s, rev_s)));
+}
template<> EIGEN_STRONG_INLINE Packet2cf pdiv<Packet2cf>(const Packet2cf& a, const Packet2cf& b)
{
// TODO optimize it for NEON
- Packet2cf res = conj_helper<Packet2cf,Packet2cf,false,true>().pmul(a,b);
+ Packet2cf res = pmul(a,pconj(b));
Packet4f s, rev_s;
// this computes the norm
s = vmulq_f32(b.v, b.v);
rev_s = vrev64q_f32(s);
- return Packet2cf(pdiv<Packet4f>(res.v, vaddq_f32(s,rev_s)));
+ return Packet2cf(pdiv<Packet4f>(res.v, vaddq_f32(s, rev_s)));
}
-EIGEN_DEVICE_FUNC inline void
-ptranspose(PacketBlock<Packet2cf,2>& kernel) {
+EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet1cf, 1>& /*kernel*/) {}
+EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet2cf, 2>& kernel)
+{
Packet4f tmp = vcombine_f32(vget_high_f32(kernel.packet[0].v), vget_high_f32(kernel.packet[1].v));
kernel.packet[0].v = vcombine_f32(vget_low_f32(kernel.packet[0].v), vget_low_f32(kernel.packet[1].v));
kernel.packet[1].v = tmp;
}
+template<> EIGEN_STRONG_INLINE Packet1cf psqrt<Packet1cf>(const Packet1cf& a) {
+ return psqrt_complex<Packet1cf>(a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet2cf psqrt<Packet2cf>(const Packet2cf& a) {
+ return psqrt_complex<Packet2cf>(a);
+}
+
//---------- double ----------
#if EIGEN_ARCH_ARM64 && !EIGEN_APPLE_DOUBLE_NEON_BUG
// See bug 1325, clang fails to call vld1q_u64.
-#if EIGEN_COMP_CLANG
+#if EIGEN_COMP_CLANG || EIGEN_COMP_CASTXML
static uint64x2_t p2ul_CONJ_XOR = {0x0, 0x8000000000000000};
#else
const uint64_t p2ul_conj_XOR_DATA[] = { 0x0, 0x8000000000000000 };
@@ -309,7 +408,8 @@
{
typedef Packet1cd type;
typedef Packet1cd half;
- enum {
+ enum
+ {
Vectorizable = 1,
AlignedOnScalar = 0,
size = 1,
@@ -328,24 +428,50 @@
};
};
-template<> struct unpacket_traits<Packet1cd> { typedef std::complex<double> type; enum {size=1, alignment=Aligned16}; typedef Packet1cd half; };
+template<> struct unpacket_traits<Packet1cd>
+{
+ typedef std::complex<double> type;
+ typedef Packet1cd half;
+ typedef Packet2d as_real;
+ enum
+ {
+ size=1,
+ alignment=Aligned16,
+ vectorizable=true,
+ masked_load_available=false,
+ masked_store_available=false
+ };
+};
-template<> EIGEN_STRONG_INLINE Packet1cd pload<Packet1cd>(const std::complex<double>* from) { EIGEN_DEBUG_ALIGNED_LOAD return Packet1cd(pload<Packet2d>((const double*)from)); }
-template<> EIGEN_STRONG_INLINE Packet1cd ploadu<Packet1cd>(const std::complex<double>* from) { EIGEN_DEBUG_UNALIGNED_LOAD return Packet1cd(ploadu<Packet2d>((const double*)from)); }
+template<> EIGEN_STRONG_INLINE Packet1cd pload<Packet1cd>(const std::complex<double>* from)
+{ EIGEN_DEBUG_ALIGNED_LOAD return Packet1cd(pload<Packet2d>(reinterpret_cast<const double*>(from))); }
-template<> EIGEN_STRONG_INLINE Packet1cd pset1<Packet1cd>(const std::complex<double>& from)
-{ /* here we really have to use unaligned loads :( */ return ploadu<Packet1cd>(&from); }
+template<> EIGEN_STRONG_INLINE Packet1cd ploadu<Packet1cd>(const std::complex<double>* from)
+{ EIGEN_DEBUG_UNALIGNED_LOAD return Packet1cd(ploadu<Packet2d>(reinterpret_cast<const double*>(from))); }
-template<> EIGEN_STRONG_INLINE Packet1cd padd<Packet1cd>(const Packet1cd& a, const Packet1cd& b) { return Packet1cd(padd<Packet2d>(a.v,b.v)); }
-template<> EIGEN_STRONG_INLINE Packet1cd psub<Packet1cd>(const Packet1cd& a, const Packet1cd& b) { return Packet1cd(psub<Packet2d>(a.v,b.v)); }
-template<> EIGEN_STRONG_INLINE Packet1cd pnegate(const Packet1cd& a) { return Packet1cd(pnegate<Packet2d>(a.v)); }
-template<> EIGEN_STRONG_INLINE Packet1cd pconj(const Packet1cd& a) { return Packet1cd(vreinterpretq_f64_u64(veorq_u64(vreinterpretq_u64_f64(a.v), p2ul_CONJ_XOR))); }
+template<> EIGEN_STRONG_INLINE Packet1cd pset1<Packet1cd>(const std::complex<double>& from)
+{
+ /* here we really have to use unaligned loads :( */
+ return ploadu<Packet1cd>(&from);
+}
+
+template<> EIGEN_STRONG_INLINE Packet1cd padd<Packet1cd>(const Packet1cd& a, const Packet1cd& b)
+{ return Packet1cd(padd<Packet2d>(a.v, b.v)); }
+
+template<> EIGEN_STRONG_INLINE Packet1cd psub<Packet1cd>(const Packet1cd& a, const Packet1cd& b)
+{ return Packet1cd(psub<Packet2d>(a.v, b.v)); }
+
+template<> EIGEN_STRONG_INLINE Packet1cd pnegate(const Packet1cd& a)
+{ return Packet1cd(pnegate<Packet2d>(a.v)); }
+
+template<> EIGEN_STRONG_INLINE Packet1cd pconj(const Packet1cd& a)
+{ return Packet1cd(vreinterpretq_f64_u64(veorq_u64(vreinterpretq_u64_f64(a.v), p2ul_CONJ_XOR))); }
template<> EIGEN_STRONG_INLINE Packet1cd pmul<Packet1cd>(const Packet1cd& a, const Packet1cd& b)
{
Packet2d v1, v2;
- // Get the real values of a
+ // Get the real values of a
v1 = vdupq_lane_f64(vget_low_f64(a.v), 0);
// Get the imag values of a
v2 = vdupq_lane_f64(vget_high_f64(a.v), 0);
@@ -353,7 +479,7 @@
v1 = vmulq_f64(v1, b.v);
// Multiply the imag a with b
v2 = vmulq_f64(v2, b.v);
- // Conjugate v2
+ // Conjugate v2
v2 = vreinterpretq_f64_u64(veorq_u64(vreinterpretq_u64_f64(v2), p2ul_CONJ_XOR));
// Swap real/imag elements in v2.
v2 = preverse<Packet2d>(v2);
@@ -361,31 +487,44 @@
return Packet1cd(vaddq_f64(v1, v2));
}
-template<> EIGEN_STRONG_INLINE Packet1cd pand <Packet1cd>(const Packet1cd& a, const Packet1cd& b)
+template<> EIGEN_STRONG_INLINE Packet1cd pcmp_eq(const Packet1cd& a, const Packet1cd& b)
{
- return Packet1cd(vreinterpretq_f64_u64(vandq_u64(vreinterpretq_u64_f64(a.v),vreinterpretq_u64_f64(b.v))));
+ // Compare real and imaginary parts of a and b to get the mask vector:
+ // [re(a)==re(b), im(a)==im(b)]
+ Packet2d eq = pcmp_eq<Packet2d>(a.v, b.v);
+ // Swap real/imag elements in the mask in to get:
+ // [im(a)==im(b), re(a)==re(b)]
+ Packet2d eq_swapped = vreinterpretq_f64_u32(vrev64q_u32(vreinterpretq_u32_f64(eq)));
+ // Return re(a)==re(b) & im(a)==im(b) by computing bitwise AND of eq and eq_swapped
+ return Packet1cd(pand<Packet2d>(eq, eq_swapped));
}
-template<> EIGEN_STRONG_INLINE Packet1cd por <Packet1cd>(const Packet1cd& a, const Packet1cd& b)
-{
- return Packet1cd(vreinterpretq_f64_u64(vorrq_u64(vreinterpretq_u64_f64(a.v),vreinterpretq_u64_f64(b.v))));
-}
-template<> EIGEN_STRONG_INLINE Packet1cd pxor <Packet1cd>(const Packet1cd& a, const Packet1cd& b)
-{
- return Packet1cd(vreinterpretq_f64_u64(veorq_u64(vreinterpretq_u64_f64(a.v),vreinterpretq_u64_f64(b.v))));
-}
+
+template<> EIGEN_STRONG_INLINE Packet1cd pand<Packet1cd>(const Packet1cd& a, const Packet1cd& b)
+{ return Packet1cd(vreinterpretq_f64_u64(vandq_u64(vreinterpretq_u64_f64(a.v),vreinterpretq_u64_f64(b.v)))); }
+
+template<> EIGEN_STRONG_INLINE Packet1cd por<Packet1cd>(const Packet1cd& a, const Packet1cd& b)
+{ return Packet1cd(vreinterpretq_f64_u64(vorrq_u64(vreinterpretq_u64_f64(a.v),vreinterpretq_u64_f64(b.v)))); }
+
+template<> EIGEN_STRONG_INLINE Packet1cd pxor<Packet1cd>(const Packet1cd& a, const Packet1cd& b)
+{ return Packet1cd(vreinterpretq_f64_u64(veorq_u64(vreinterpretq_u64_f64(a.v),vreinterpretq_u64_f64(b.v)))); }
+
template<> EIGEN_STRONG_INLINE Packet1cd pandnot<Packet1cd>(const Packet1cd& a, const Packet1cd& b)
-{
- return Packet1cd(vreinterpretq_f64_u64(vbicq_u64(vreinterpretq_u64_f64(a.v),vreinterpretq_u64_f64(b.v))));
-}
+{ return Packet1cd(vreinterpretq_f64_u64(vbicq_u64(vreinterpretq_u64_f64(a.v),vreinterpretq_u64_f64(b.v)))); }
-template<> EIGEN_STRONG_INLINE Packet1cd ploaddup<Packet1cd>(const std::complex<double>* from) { return pset1<Packet1cd>(*from); }
+template<> EIGEN_STRONG_INLINE Packet1cd ploaddup<Packet1cd>(const std::complex<double>* from)
+{ return pset1<Packet1cd>(*from); }
-template<> EIGEN_STRONG_INLINE void pstore <std::complex<double> >(std::complex<double> * to, const Packet1cd& from) { EIGEN_DEBUG_ALIGNED_STORE pstore((double*)to, from.v); }
-template<> EIGEN_STRONG_INLINE void pstoreu<std::complex<double> >(std::complex<double> * to, const Packet1cd& from) { EIGEN_DEBUG_UNALIGNED_STORE pstoreu((double*)to, from.v); }
+template<> EIGEN_STRONG_INLINE void pstore <std::complex<double> >(std::complex<double> *to, const Packet1cd& from)
+{ EIGEN_DEBUG_ALIGNED_STORE pstore(reinterpret_cast<double*>(to), from.v); }
-template<> EIGEN_STRONG_INLINE void prefetch<std::complex<double> >(const std::complex<double> * addr) { EIGEN_ARM_PREFETCH((const double *)addr); }
+template<> EIGEN_STRONG_INLINE void pstoreu<std::complex<double> >(std::complex<double> *to, const Packet1cd& from)
+{ EIGEN_DEBUG_UNALIGNED_STORE pstoreu(reinterpret_cast<double*>(to), from.v); }
-template<> EIGEN_DEVICE_FUNC inline Packet1cd pgather<std::complex<double>, Packet1cd>(const std::complex<double>* from, Index stride)
+template<> EIGEN_STRONG_INLINE void prefetch<std::complex<double> >(const std::complex<double> *addr)
+{ EIGEN_ARM_PREFETCH(reinterpret_cast<const double*>(addr)); }
+
+template<> EIGEN_DEVICE_FUNC inline Packet1cd pgather<std::complex<double>, Packet1cd>(
+ const std::complex<double>* from, Index stride)
{
Packet2d res = pset1<Packet2d>(0.0);
res = vsetq_lane_f64(std::real(from[0*stride]), res, 0);
@@ -393,17 +532,14 @@
return Packet1cd(res);
}
-template<> EIGEN_DEVICE_FUNC inline void pscatter<std::complex<double>, Packet1cd>(std::complex<double>* to, const Packet1cd& from, Index stride)
-{
- to[stride*0] = std::complex<double>(vgetq_lane_f64(from.v, 0), vgetq_lane_f64(from.v, 1));
-}
+template<> EIGEN_DEVICE_FUNC inline void pscatter<std::complex<double>, Packet1cd>(
+ std::complex<double>* to, const Packet1cd& from, Index stride)
+{ to[stride*0] = std::complex<double>(vgetq_lane_f64(from.v, 0), vgetq_lane_f64(from.v, 1)); }
-
-template<> EIGEN_STRONG_INLINE std::complex<double> pfirst<Packet1cd>(const Packet1cd& a)
+template<> EIGEN_STRONG_INLINE std::complex<double> pfirst<Packet1cd>(const Packet1cd& a)
{
- std::complex<double> EIGEN_ALIGN16 res;
+ EIGEN_ALIGN16 std::complex<double> res;
pstore<std::complex<double> >(&res, a);
-
return res;
}
@@ -411,59 +547,14 @@
template<> EIGEN_STRONG_INLINE std::complex<double> predux<Packet1cd>(const Packet1cd& a) { return pfirst(a); }
-template<> EIGEN_STRONG_INLINE Packet1cd preduxp<Packet1cd>(const Packet1cd* vecs) { return vecs[0]; }
-
template<> EIGEN_STRONG_INLINE std::complex<double> predux_mul<Packet1cd>(const Packet1cd& a) { return pfirst(a); }
-template<int Offset>
-struct palign_impl<Offset,Packet1cd>
-{
- static EIGEN_STRONG_INLINE void run(Packet1cd& /*first*/, const Packet1cd& /*second*/)
- {
- // FIXME is it sure we never have to align a Packet1cd?
- // Even though a std::complex<double> has 16 bytes, it is not necessarily aligned on a 16 bytes boundary...
- }
-};
-
-template<> struct conj_helper<Packet1cd, Packet1cd, false,true>
-{
- EIGEN_STRONG_INLINE Packet1cd pmadd(const Packet1cd& x, const Packet1cd& y, const Packet1cd& c) const
- { return padd(pmul(x,y),c); }
-
- EIGEN_STRONG_INLINE Packet1cd pmul(const Packet1cd& a, const Packet1cd& b) const
- {
- return internal::pmul(a, pconj(b));
- }
-};
-
-template<> struct conj_helper<Packet1cd, Packet1cd, true,false>
-{
- EIGEN_STRONG_INLINE Packet1cd pmadd(const Packet1cd& x, const Packet1cd& y, const Packet1cd& c) const
- { return padd(pmul(x,y),c); }
-
- EIGEN_STRONG_INLINE Packet1cd pmul(const Packet1cd& a, const Packet1cd& b) const
- {
- return internal::pmul(pconj(a), b);
- }
-};
-
-template<> struct conj_helper<Packet1cd, Packet1cd, true,true>
-{
- EIGEN_STRONG_INLINE Packet1cd pmadd(const Packet1cd& x, const Packet1cd& y, const Packet1cd& c) const
- { return padd(pmul(x,y),c); }
-
- EIGEN_STRONG_INLINE Packet1cd pmul(const Packet1cd& a, const Packet1cd& b) const
- {
- return pconj(internal::pmul(a, b));
- }
-};
-
EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet1cd,Packet2d)
template<> EIGEN_STRONG_INLINE Packet1cd pdiv<Packet1cd>(const Packet1cd& a, const Packet1cd& b)
{
// TODO optimize it for NEON
- Packet1cd res = conj_helper<Packet1cd,Packet1cd,false,true>().pmul(a,b);
+ Packet1cd res = pmul(a,pconj(b));
Packet2d s = pmul<Packet2d>(b.v, b.v);
Packet2d rev_s = preverse<Packet2d>(s);
@@ -471,9 +562,7 @@
}
EIGEN_STRONG_INLINE Packet1cd pcplxflip/*<Packet1cd>*/(const Packet1cd& x)
-{
- return Packet1cd(preverse(Packet2d(x.v)));
-}
+{ return Packet1cd(preverse(Packet2d(x.v))); }
EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet1cd,2>& kernel)
{
@@ -481,6 +570,11 @@
kernel.packet[0].v = vcombine_f64(vget_low_f64(kernel.packet[0].v), vget_low_f64(kernel.packet[1].v));
kernel.packet[1].v = tmp;
}
+
+template<> EIGEN_STRONG_INLINE Packet1cd psqrt<Packet1cd>(const Packet1cd& a) {
+ return psqrt_complex<Packet1cd>(a);
+}
+
#endif // EIGEN_ARCH_ARM64
} // end namespace internal
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/NEON/GeneralBlockPanelKernel.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/NEON/GeneralBlockPanelKernel.h
new file mode 100644
index 0000000..3481f33
--- /dev/null
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/NEON/GeneralBlockPanelKernel.h
@@ -0,0 +1,183 @@
+namespace Eigen {
+namespace internal {
+
+#if EIGEN_ARCH_ARM && EIGEN_COMP_CLANG
+
+// Clang seems to excessively spill registers in the GEBP kernel on 32-bit arm.
+// Here we specialize gebp_traits to eliminate these register spills.
+// See #2138.
+template<>
+struct gebp_traits <float,float,false,false,Architecture::NEON,GEBPPacketFull>
+ : gebp_traits<float,float,false,false,Architecture::Generic,GEBPPacketFull>
+{
+ EIGEN_STRONG_INLINE void acc(const AccPacket& c, const ResPacket& alpha, ResPacket& r) const
+ {
+ // This volatile inline ASM both acts as a barrier to prevent reordering,
+ // as well as enforces strict register use.
+ asm volatile(
+ "vmla.f32 %q[r], %q[c], %q[alpha]"
+ : [r] "+w" (r)
+ : [c] "w" (c),
+ [alpha] "w" (alpha)
+ : );
+ }
+
+ template <typename LaneIdType>
+ EIGEN_STRONG_INLINE void madd(const Packet4f& a, const Packet4f& b,
+ Packet4f& c, Packet4f& tmp,
+ const LaneIdType&) const {
+ acc(a, b, c);
+ }
+
+ template <typename LaneIdType>
+ EIGEN_STRONG_INLINE void madd(const Packet4f& a, const QuadPacket<Packet4f>& b,
+ Packet4f& c, Packet4f& tmp,
+ const LaneIdType& lane) const {
+ madd(a, b.get(lane), c, tmp, lane);
+ }
+};
+
+#endif // EIGEN_ARCH_ARM && EIGEN_COMP_CLANG
+
+#if EIGEN_ARCH_ARM64
+
+template<>
+struct gebp_traits <float,float,false,false,Architecture::NEON,GEBPPacketFull>
+ : gebp_traits<float,float,false,false,Architecture::Generic,GEBPPacketFull>
+{
+ typedef float RhsPacket;
+ typedef float32x4_t RhsPacketx4;
+
+ EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacket& dest) const
+ {
+ dest = *b;
+ }
+
+ EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacketx4& dest) const
+ {
+ dest = vld1q_f32(b);
+ }
+
+ EIGEN_STRONG_INLINE void updateRhs(const RhsScalar* b, RhsPacket& dest) const
+ {
+ dest = *b;
+ }
+
+ EIGEN_STRONG_INLINE void updateRhs(const RhsScalar*, RhsPacketx4&) const
+ {}
+
+ EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, RhsPacket& dest) const
+ {
+ loadRhs(b,dest);
+ }
+
+ EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<0>&) const
+ {
+ c = vfmaq_n_f32(c, a, b);
+ }
+
+ // NOTE: Template parameter inference failed when compiled with Android NDK:
+ // "candidate template ignored: could not match 'FixedInt<N>' against 'Eigen::internal::FixedInt<0>".
+
+ EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<0>&) const
+ { madd_helper<0>(a, b, c); }
+ EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<1>&) const
+ { madd_helper<1>(a, b, c); }
+ EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<2>&) const
+ { madd_helper<2>(a, b, c); }
+ EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<3>&) const
+ { madd_helper<3>(a, b, c); }
+
+ private:
+ template<int LaneID>
+ EIGEN_STRONG_INLINE void madd_helper(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c) const
+ {
+ #if EIGEN_COMP_GNUC_STRICT && !(EIGEN_GNUC_AT_LEAST(9,0))
+ // workaround gcc issue https://gcc.gnu.org/bugzilla/show_bug.cgi?id=89101
+ // vfmaq_laneq_f32 is implemented through a costly dup
+ if(LaneID==0) asm("fmla %0.4s, %1.4s, %2.s[0]\n" : "+w" (c) : "w" (a), "w" (b) : );
+ else if(LaneID==1) asm("fmla %0.4s, %1.4s, %2.s[1]\n" : "+w" (c) : "w" (a), "w" (b) : );
+ else if(LaneID==2) asm("fmla %0.4s, %1.4s, %2.s[2]\n" : "+w" (c) : "w" (a), "w" (b) : );
+ else if(LaneID==3) asm("fmla %0.4s, %1.4s, %2.s[3]\n" : "+w" (c) : "w" (a), "w" (b) : );
+ #else
+ c = vfmaq_laneq_f32(c, a, b, LaneID);
+ #endif
+ }
+};
+
+
+template<>
+struct gebp_traits <double,double,false,false,Architecture::NEON>
+ : gebp_traits<double,double,false,false,Architecture::Generic>
+{
+ typedef double RhsPacket;
+
+ struct RhsPacketx4 {
+ float64x2_t B_0, B_1;
+ };
+
+ EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacket& dest) const
+ {
+ dest = *b;
+ }
+
+ EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacketx4& dest) const
+ {
+ dest.B_0 = vld1q_f64(b);
+ dest.B_1 = vld1q_f64(b+2);
+ }
+
+ EIGEN_STRONG_INLINE void updateRhs(const RhsScalar* b, RhsPacket& dest) const
+ {
+ loadRhs(b,dest);
+ }
+
+ EIGEN_STRONG_INLINE void updateRhs(const RhsScalar*, RhsPacketx4&) const
+ {}
+
+ EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, RhsPacket& dest) const
+ {
+ loadRhs(b,dest);
+ }
+
+ EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<0>&) const
+ {
+ c = vfmaq_n_f64(c, a, b);
+ }
+
+ // NOTE: Template parameter inference failed when compiled with Android NDK:
+ // "candidate template ignored: could not match 'FixedInt<N>' against 'Eigen::internal::FixedInt<0>".
+
+ EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<0>&) const
+ { madd_helper<0>(a, b, c); }
+ EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<1>&) const
+ { madd_helper<1>(a, b, c); }
+ EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<2>&) const
+ { madd_helper<2>(a, b, c); }
+ EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<3>&) const
+ { madd_helper<3>(a, b, c); }
+
+ private:
+ template <int LaneID>
+ EIGEN_STRONG_INLINE void madd_helper(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c) const
+ {
+ #if EIGEN_COMP_GNUC_STRICT && !(EIGEN_GNUC_AT_LEAST(9,0))
+ // workaround gcc issue https://gcc.gnu.org/bugzilla/show_bug.cgi?id=89101
+ // vfmaq_laneq_f64 is implemented through a costly dup
+ if(LaneID==0) asm("fmla %0.2d, %1.2d, %2.d[0]\n" : "+w" (c) : "w" (a), "w" (b.B_0) : );
+ else if(LaneID==1) asm("fmla %0.2d, %1.2d, %2.d[1]\n" : "+w" (c) : "w" (a), "w" (b.B_0) : );
+ else if(LaneID==2) asm("fmla %0.2d, %1.2d, %2.d[0]\n" : "+w" (c) : "w" (a), "w" (b.B_1) : );
+ else if(LaneID==3) asm("fmla %0.2d, %1.2d, %2.d[1]\n" : "+w" (c) : "w" (a), "w" (b.B_1) : );
+ #else
+ if(LaneID==0) c = vfmaq_laneq_f64(c, a, b.B_0, 0);
+ else if(LaneID==1) c = vfmaq_laneq_f64(c, a, b.B_0, 1);
+ else if(LaneID==2) c = vfmaq_laneq_f64(c, a, b.B_1, 0);
+ else if(LaneID==3) c = vfmaq_laneq_f64(c, a, b.B_1, 1);
+ #endif
+ }
+};
+
+#endif // EIGEN_ARCH_ARM64
+
+} // namespace internal
+} // namespace Eigen
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/NEON/MathFunctions.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/NEON/MathFunctions.h
index 6bb05bb..fa6615a 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/NEON/MathFunctions.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/NEON/MathFunctions.h
@@ -5,10 +5,6 @@
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
-/* The sin, cos, exp, and log functions of this file come from
- * Julien Pommier's sse math library: http://gruntthepeon.free.fr/ssemath/
- */
-
#ifndef EIGEN_MATH_FUNCTIONS_NEON_H
#define EIGEN_MATH_FUNCTIONS_NEON_H
@@ -16,74 +12,62 @@
namespace internal {
-template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
-Packet4f pexp<Packet4f>(const Packet4f& _x)
-{
- Packet4f x = _x;
- Packet4f tmp, fx;
+template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet2f pexp<Packet2f>(const Packet2f& x)
+{ return pexp_float(x); }
+template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4f pexp<Packet4f>(const Packet4f& x)
+{ return pexp_float(x); }
- _EIGEN_DECLARE_CONST_Packet4f(1 , 1.0f);
- _EIGEN_DECLARE_CONST_Packet4f(half, 0.5f);
- _EIGEN_DECLARE_CONST_Packet4i(0x7f, 0x7f);
- _EIGEN_DECLARE_CONST_Packet4f(exp_hi, 88.3762626647950f);
- _EIGEN_DECLARE_CONST_Packet4f(exp_lo, -88.3762626647949f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_LOG2EF, 1.44269504088896341f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_C1, 0.693359375f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_C2, -2.12194440e-4f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p0, 1.9875691500E-4f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p1, 1.3981999507E-3f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p2, 8.3334519073E-3f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p3, 4.1665795894E-2f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p4, 1.6666665459E-1f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p5, 5.0000001201E-1f);
+template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet2f plog<Packet2f>(const Packet2f& x)
+{ return plog_float(x); }
+template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4f plog<Packet4f>(const Packet4f& x)
+{ return plog_float(x); }
- x = vminq_f32(x, p4f_exp_hi);
- x = vmaxq_f32(x, p4f_exp_lo);
+template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet2f psin<Packet2f>(const Packet2f& x)
+{ return psin_float(x); }
+template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4f psin<Packet4f>(const Packet4f& x)
+{ return psin_float(x); }
- /* express exp(x) as exp(g + n*log(2)) */
- fx = vmlaq_f32(p4f_half, x, p4f_cephes_LOG2EF);
+template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet2f pcos<Packet2f>(const Packet2f& x)
+{ return pcos_float(x); }
+template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4f pcos<Packet4f>(const Packet4f& x)
+{ return pcos_float(x); }
- /* perform a floorf */
- tmp = vcvtq_f32_s32(vcvtq_s32_f32(fx));
+// Hyperbolic Tangent function.
+template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet2f ptanh<Packet2f>(const Packet2f& x)
+{ return internal::generic_fast_tanh_float(x); }
+template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4f ptanh<Packet4f>(const Packet4f& x)
+{ return internal::generic_fast_tanh_float(x); }
- /* if greater, substract 1 */
- Packet4ui mask = vcgtq_f32(tmp, fx);
- mask = vandq_u32(mask, vreinterpretq_u32_f32(p4f_1));
+BF16_PACKET_FUNCTION(Packet4f, Packet4bf, psin)
+BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pcos)
+BF16_PACKET_FUNCTION(Packet4f, Packet4bf, plog)
+BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pexp)
+BF16_PACKET_FUNCTION(Packet4f, Packet4bf, ptanh)
- fx = vsubq_f32(tmp, vreinterpretq_f32_u32(mask));
-
- tmp = vmulq_f32(fx, p4f_cephes_exp_C1);
- Packet4f z = vmulq_f32(fx, p4f_cephes_exp_C2);
- x = vsubq_f32(x, tmp);
- x = vsubq_f32(x, z);
-
- Packet4f y = vmulq_f32(p4f_cephes_exp_p0, x);
- z = vmulq_f32(x, x);
- y = vaddq_f32(y, p4f_cephes_exp_p1);
- y = vmulq_f32(y, x);
- y = vaddq_f32(y, p4f_cephes_exp_p2);
- y = vmulq_f32(y, x);
- y = vaddq_f32(y, p4f_cephes_exp_p3);
- y = vmulq_f32(y, x);
- y = vaddq_f32(y, p4f_cephes_exp_p4);
- y = vmulq_f32(y, x);
- y = vaddq_f32(y, p4f_cephes_exp_p5);
-
- y = vmulq_f32(y, z);
- y = vaddq_f32(y, x);
- y = vaddq_f32(y, p4f_1);
-
- /* build 2^n */
- int32x4_t mm;
- mm = vcvtq_s32_f32(fx);
- mm = vaddq_s32(mm, p4i_0x7f);
- mm = vshlq_n_s32(mm, 23);
- Packet4f pow2n = vreinterpretq_f32_s32(mm);
-
- y = vmulq_f32(y, pow2n);
- return y;
+template <>
+EIGEN_STRONG_INLINE Packet4bf pfrexp(const Packet4bf& a, Packet4bf& exponent) {
+ Packet4f fexponent;
+ const Packet4bf out = F32ToBf16(pfrexp<Packet4f>(Bf16ToF32(a), fexponent));
+ exponent = F32ToBf16(fexponent);
+ return out;
}
+template <>
+EIGEN_STRONG_INLINE Packet4bf pldexp(const Packet4bf& a, const Packet4bf& exponent) {
+ return F32ToBf16(pldexp<Packet4f>(Bf16ToF32(a), Bf16ToF32(exponent)));
+}
+
+//---------- double ----------
+
+#if EIGEN_ARCH_ARM64 && !EIGEN_APPLE_DOUBLE_NEON_BUG
+template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet2d pexp<Packet2d>(const Packet2d& x)
+{ return pexp_double(x); }
+
+template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet2d plog<Packet2d>(const Packet2d& x)
+{ return plog_double(x); }
+
+#endif
+
} // end namespace internal
} // end namespace Eigen
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/NEON/PacketMath.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/NEON/PacketMath.h
index 3d5ed0d..d2aeef4 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/NEON/PacketMath.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/NEON/PacketMath.h
@@ -24,54 +24,118 @@
#define EIGEN_HAS_SINGLE_INSTRUCTION_MADD
#endif
-#ifndef EIGEN_HAS_SINGLE_INSTRUCTION_CJMADD
-#define EIGEN_HAS_SINGLE_INSTRUCTION_CJMADD
-#endif
-
#ifndef EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS
#if EIGEN_ARCH_ARM64
#define EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS 32
#else
-#define EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS 16
+#define EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS 16
#endif
#endif
-#if EIGEN_COMP_MSVC
+#if EIGEN_COMP_MSVC_STRICT
// In MSVC's arm_neon.h header file, all NEON vector types
// are aliases to the same underlying type __n128.
// We thus have to wrap them to make them different C++ types.
// (See also bug 1428)
-
-template<typename T,int unique_id>
-struct eigen_packet_wrapper
-{
- operator T&() { return m_val; }
- operator const T&() const { return m_val; }
- eigen_packet_wrapper() {}
- eigen_packet_wrapper(const T &v) : m_val(v) {}
- eigen_packet_wrapper& operator=(const T &v) {
- m_val = v;
- return *this;
- }
-
- T m_val;
-};
-typedef eigen_packet_wrapper<float32x2_t,0> Packet2f;
-typedef eigen_packet_wrapper<float32x4_t,1> Packet4f;
-typedef eigen_packet_wrapper<int32x4_t ,2> Packet4i;
-typedef eigen_packet_wrapper<int32x2_t ,3> Packet2i;
-typedef eigen_packet_wrapper<uint32x4_t ,4> Packet4ui;
+typedef eigen_packet_wrapper<float32x2_t,0> Packet2f;
+typedef eigen_packet_wrapper<float32x4_t,1> Packet4f;
+typedef eigen_packet_wrapper<int32_t ,2> Packet4c;
+typedef eigen_packet_wrapper<int8x8_t ,3> Packet8c;
+typedef eigen_packet_wrapper<int8x16_t ,4> Packet16c;
+typedef eigen_packet_wrapper<uint32_t ,5> Packet4uc;
+typedef eigen_packet_wrapper<uint8x8_t ,6> Packet8uc;
+typedef eigen_packet_wrapper<uint8x16_t ,7> Packet16uc;
+typedef eigen_packet_wrapper<int16x4_t ,8> Packet4s;
+typedef eigen_packet_wrapper<int16x8_t ,9> Packet8s;
+typedef eigen_packet_wrapper<uint16x4_t ,10> Packet4us;
+typedef eigen_packet_wrapper<uint16x8_t ,11> Packet8us;
+typedef eigen_packet_wrapper<int32x2_t ,12> Packet2i;
+typedef eigen_packet_wrapper<int32x4_t ,13> Packet4i;
+typedef eigen_packet_wrapper<uint32x2_t ,14> Packet2ui;
+typedef eigen_packet_wrapper<uint32x4_t ,15> Packet4ui;
+typedef eigen_packet_wrapper<int64x2_t ,16> Packet2l;
+typedef eigen_packet_wrapper<uint64x2_t ,17> Packet2ul;
#else
-typedef float32x2_t Packet2f;
-typedef float32x4_t Packet4f;
-typedef int32x4_t Packet4i;
-typedef int32x2_t Packet2i;
-typedef uint32x4_t Packet4ui;
+typedef float32x2_t Packet2f;
+typedef float32x4_t Packet4f;
+typedef eigen_packet_wrapper<int32_t ,2> Packet4c;
+typedef int8x8_t Packet8c;
+typedef int8x16_t Packet16c;
+typedef eigen_packet_wrapper<uint32_t ,5> Packet4uc;
+typedef uint8x8_t Packet8uc;
+typedef uint8x16_t Packet16uc;
+typedef int16x4_t Packet4s;
+typedef int16x8_t Packet8s;
+typedef uint16x4_t Packet4us;
+typedef uint16x8_t Packet8us;
+typedef int32x2_t Packet2i;
+typedef int32x4_t Packet4i;
+typedef uint32x2_t Packet2ui;
+typedef uint32x4_t Packet4ui;
+typedef int64x2_t Packet2l;
+typedef uint64x2_t Packet2ul;
-#endif // EIGEN_COMP_MSVC
+#endif // EIGEN_COMP_MSVC_STRICT
+
+EIGEN_STRONG_INLINE Packet4f shuffle1(const Packet4f& m, int mask){
+ const float* a = reinterpret_cast<const float*>(&m);
+ Packet4f res = {*(a + (mask & 3)), *(a + ((mask >> 2) & 3)), *(a + ((mask >> 4) & 3 )), *(a + ((mask >> 6) & 3))};
+ return res;
+}
+
+// fuctionally equivalent to _mm_shuffle_ps in SSE when interleave
+// == false (i.e. shuffle<false>(m, n, mask) equals _mm_shuffle_ps(m, n, mask)),
+// interleave m and n when interleave == true. Currently used in LU/arch/InverseSize4.h
+// to enable a shared implementation for fast inversion of matrices of size 4.
+template<bool interleave>
+EIGEN_STRONG_INLINE Packet4f shuffle2(const Packet4f &m, const Packet4f &n, int mask)
+{
+ const float* a = reinterpret_cast<const float*>(&m);
+ const float* b = reinterpret_cast<const float*>(&n);
+ Packet4f res = {*(a + (mask & 3)), *(a + ((mask >> 2) & 3)), *(b + ((mask >> 4) & 3)), *(b + ((mask >> 6) & 3))};
+ return res;
+}
+
+template<>
+EIGEN_STRONG_INLINE Packet4f shuffle2<true>(const Packet4f &m, const Packet4f &n, int mask)
+{
+ const float* a = reinterpret_cast<const float*>(&m);
+ const float* b = reinterpret_cast<const float*>(&n);
+ Packet4f res = {*(a + (mask & 3)), *(b + ((mask >> 2) & 3)), *(a + ((mask >> 4) & 3)), *(b + ((mask >> 6) & 3))};
+ return res;
+}
+
+EIGEN_STRONG_INLINE static int eigen_neon_shuffle_mask(int p, int q, int r, int s) {return ((s)<<6|(r)<<4|(q)<<2|(p));}
+
+EIGEN_STRONG_INLINE Packet4f vec4f_swizzle1(const Packet4f& a, int p, int q, int r, int s)
+{
+ return shuffle1(a, eigen_neon_shuffle_mask(p, q, r, s));
+}
+EIGEN_STRONG_INLINE Packet4f vec4f_swizzle2(const Packet4f& a, const Packet4f& b, int p, int q, int r, int s)
+{
+ return shuffle2<false>(a,b,eigen_neon_shuffle_mask(p, q, r, s));
+}
+EIGEN_STRONG_INLINE Packet4f vec4f_movelh(const Packet4f& a, const Packet4f& b)
+{
+ return shuffle2<false>(a,b,eigen_neon_shuffle_mask(0, 1, 0, 1));
+}
+EIGEN_STRONG_INLINE Packet4f vec4f_movehl(const Packet4f& a, const Packet4f& b)
+{
+ return shuffle2<false>(b,a,eigen_neon_shuffle_mask(2, 3, 2, 3));
+}
+EIGEN_STRONG_INLINE Packet4f vec4f_unpacklo(const Packet4f& a, const Packet4f& b)
+{
+ return shuffle2<true>(a,b,eigen_neon_shuffle_mask(0, 0, 1, 1));
+}
+EIGEN_STRONG_INLINE Packet4f vec4f_unpackhi(const Packet4f& a, const Packet4f& b)
+{
+ return shuffle2<true>(a,b,eigen_neon_shuffle_mask(2, 2, 3, 3));
+}
+#define vec4f_duplane(a, p) \
+ vdupq_lane_f32(vget_low_f32(a), p)
#define _EIGEN_DECLARE_CONST_Packet4f(NAME,X) \
const Packet4f p4f_##NAME = pset1<Packet4f>(X)
@@ -98,81 +162,816 @@
#define EIGEN_ARM_PREFETCH(ADDR)
#endif
-template<> struct packet_traits<float> : default_packet_traits
+template <>
+struct packet_traits<float> : default_packet_traits
{
typedef Packet4f type;
- typedef Packet4f half; // Packet2f intrinsics not implemented yet
- enum {
+ typedef Packet2f half;
+ enum
+ {
Vectorizable = 1,
AlignedOnScalar = 1,
size = 4,
- HasHalfPacket=0, // Packet2f intrinsics not implemented yet
-
- HasDiv = 1,
- // FIXME check the Has*
- HasSin = 0,
- HasCos = 0,
- HasLog = 0,
+ HasHalfPacket = 1,
+
+ HasAdd = 1,
+ HasSub = 1,
+ HasShift = 1,
+ HasMul = 1,
+ HasNegate = 1,
+ HasAbs = 1,
+ HasArg = 0,
+ HasAbs2 = 1,
+ HasAbsDiff = 1,
+ HasMin = 1,
+ HasMax = 1,
+ HasConj = 1,
+ HasSetLinear = 0,
+ HasBlend = 0,
+
+ HasDiv = 1,
+ HasFloor = 1,
+ HasCeil = 1,
+ HasRint = 1,
+
+ HasSin = EIGEN_FAST_MATH,
+ HasCos = EIGEN_FAST_MATH,
+ HasLog = 1,
HasExp = 1,
- HasSqrt = 0
+ HasSqrt = 1,
+ HasRsqrt = 1,
+ HasTanh = EIGEN_FAST_MATH,
+ HasErf = EIGEN_FAST_MATH,
+ HasBessel = 0, // Issues with accuracy.
+ HasNdtri = 0
};
};
-template<> struct packet_traits<int32_t> : default_packet_traits
+
+template <>
+struct packet_traits<int8_t> : default_packet_traits
{
- typedef Packet4i type;
- typedef Packet4i half; // Packet2i intrinsics not implemented yet
- enum {
+ typedef Packet16c type;
+ typedef Packet8c half;
+ enum
+ {
Vectorizable = 1,
AlignedOnScalar = 1,
- size=4,
- HasHalfPacket=0 // Packet2i intrinsics not implemented yet
- // FIXME check the Has*
+ size = 16,
+ HasHalfPacket = 1,
+
+ HasAdd = 1,
+ HasSub = 1,
+ HasShift = 1,
+ HasMul = 1,
+ HasNegate = 1,
+ HasAbs = 1,
+ HasAbsDiff = 1,
+ HasArg = 0,
+ HasAbs2 = 1,
+ HasMin = 1,
+ HasMax = 1,
+ HasConj = 1,
+ HasSetLinear = 0,
+ HasBlend = 0
};
};
-#if EIGEN_GNUC_AT_MOST(4,4) && !EIGEN_COMP_LLVM
-// workaround gcc 4.2, 4.3 and 4.4 compilatin issue
+template <>
+struct packet_traits<uint8_t> : default_packet_traits
+{
+ typedef Packet16uc type;
+ typedef Packet8uc half;
+ enum
+ {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = 16,
+ HasHalfPacket = 1,
+
+ HasAdd = 1,
+ HasSub = 1,
+ HasShift = 1,
+ HasMul = 1,
+ HasNegate = 0,
+ HasAbs = 1,
+ HasAbsDiff = 1,
+ HasArg = 0,
+ HasAbs2 = 1,
+ HasMin = 1,
+ HasMax = 1,
+ HasConj = 1,
+ HasSetLinear = 0,
+ HasBlend = 0,
+
+ HasSqrt = 1
+ };
+};
+
+template <>
+struct packet_traits<int16_t> : default_packet_traits
+{
+ typedef Packet8s type;
+ typedef Packet4s half;
+ enum
+ {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = 8,
+ HasHalfPacket = 1,
+
+ HasAdd = 1,
+ HasSub = 1,
+ HasShift = 1,
+ HasMul = 1,
+ HasNegate = 1,
+ HasAbs = 1,
+ HasAbsDiff = 1,
+ HasArg = 0,
+ HasAbs2 = 1,
+ HasMin = 1,
+ HasMax = 1,
+ HasConj = 1,
+ HasSetLinear = 0,
+ HasBlend = 0
+ };
+};
+
+template <>
+struct packet_traits<uint16_t> : default_packet_traits
+{
+ typedef Packet8us type;
+ typedef Packet4us half;
+ enum
+ {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = 8,
+ HasHalfPacket = 1,
+
+ HasAdd = 1,
+ HasSub = 1,
+ HasShift = 1,
+ HasMul = 1,
+ HasNegate = 0,
+ HasAbs = 0,
+ HasAbsDiff = 1,
+ HasArg = 0,
+ HasAbs2 = 1,
+ HasMin = 1,
+ HasMax = 1,
+ HasConj = 1,
+ HasSetLinear = 0,
+ HasBlend = 0,
+ HasSqrt = 1
+ };
+};
+
+template <>
+struct packet_traits<int32_t> : default_packet_traits
+{
+ typedef Packet4i type;
+ typedef Packet2i half;
+ enum
+ {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = 4,
+ HasHalfPacket = 1,
+
+ HasAdd = 1,
+ HasSub = 1,
+ HasShift = 1,
+ HasMul = 1,
+ HasNegate = 1,
+ HasAbs = 1,
+ HasArg = 0,
+ HasAbs2 = 1,
+ HasAbsDiff = 1,
+ HasMin = 1,
+ HasMax = 1,
+ HasConj = 1,
+ HasSetLinear = 0,
+ HasBlend = 0
+ };
+};
+
+template <>
+struct packet_traits<uint32_t> : default_packet_traits
+{
+ typedef Packet4ui type;
+ typedef Packet2ui half;
+ enum
+ {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = 4,
+ HasHalfPacket = 1,
+
+ HasAdd = 1,
+ HasSub = 1,
+ HasShift = 1,
+ HasMul = 1,
+ HasNegate = 0,
+ HasAbs = 0,
+ HasArg = 0,
+ HasAbs2 = 1,
+ HasAbsDiff = 1,
+ HasMin = 1,
+ HasMax = 1,
+ HasConj = 1,
+ HasSetLinear = 0,
+ HasBlend = 0,
+
+ HasSqrt = 1
+ };
+};
+
+template <>
+struct packet_traits<int64_t> : default_packet_traits
+{
+ typedef Packet2l type;
+ typedef Packet2l half;
+ enum
+ {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = 2,
+ HasHalfPacket = 0,
+
+ HasCmp = 1,
+ HasAdd = 1,
+ HasSub = 1,
+ HasShift = 1,
+ HasMul = 1,
+ HasNegate = 1,
+ HasAbs = 1,
+ HasArg = 0,
+ HasAbs2 = 1,
+ HasAbsDiff = 1,
+ HasMin = 1,
+ HasMax = 1,
+ HasConj = 1,
+ HasSetLinear = 0,
+ HasBlend = 0
+ };
+};
+
+template <>
+struct packet_traits<uint64_t> : default_packet_traits
+{
+ typedef Packet2ul type;
+ typedef Packet2ul half;
+ enum
+ {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = 2,
+ HasHalfPacket = 0,
+
+ HasCmp = 1,
+ HasAdd = 1,
+ HasSub = 1,
+ HasShift = 1,
+ HasMul = 1,
+ HasNegate = 0,
+ HasAbs = 0,
+ HasArg = 0,
+ HasAbs2 = 1,
+ HasAbsDiff = 1,
+ HasMin = 1,
+ HasMax = 1,
+ HasConj = 1,
+ HasSetLinear = 0,
+ HasBlend = 0
+ };
+};
+
+#if EIGEN_GNUC_AT_MOST(4, 4) && !EIGEN_COMP_LLVM
+// workaround gcc 4.2, 4.3 and 4.4 compilation issue
EIGEN_STRONG_INLINE float32x4_t vld1q_f32(const float* x) { return ::vld1q_f32((const float32_t*)x); }
-EIGEN_STRONG_INLINE float32x2_t vld1_f32 (const float* x) { return ::vld1_f32 ((const float32_t*)x); }
-EIGEN_STRONG_INLINE float32x2_t vld1_dup_f32 (const float* x) { return ::vld1_dup_f32 ((const float32_t*)x); }
-EIGEN_STRONG_INLINE void vst1q_f32(float* to, float32x4_t from) { ::vst1q_f32((float32_t*)to,from); }
-EIGEN_STRONG_INLINE void vst1_f32 (float* to, float32x2_t from) { ::vst1_f32 ((float32_t*)to,from); }
+EIGEN_STRONG_INLINE float32x2_t vld1_f32(const float* x) { return ::vld1_f32 ((const float32_t*)x); }
+EIGEN_STRONG_INLINE float32x2_t vld1_dup_f32(const float* x) { return ::vld1_dup_f32 ((const float32_t*)x); }
+EIGEN_STRONG_INLINE void vst1q_f32(float* to, float32x4_t from) { ::vst1q_f32((float32_t*)to,from); }
+EIGEN_STRONG_INLINE void vst1_f32 (float* to, float32x2_t from) { ::vst1_f32 ((float32_t*)to,from); }
#endif
-template<> struct unpacket_traits<Packet4f> { typedef float type; enum {size=4, alignment=Aligned16}; typedef Packet4f half; };
-template<> struct unpacket_traits<Packet4i> { typedef int32_t type; enum {size=4, alignment=Aligned16}; typedef Packet4i half; };
+template<> struct unpacket_traits<Packet2f>
+{
+ typedef float type;
+ typedef Packet2f half;
+ typedef Packet2i integer_packet;
+ enum
+ {
+ size = 2,
+ alignment = Aligned16,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false
+ };
+};
+template<> struct unpacket_traits<Packet4f>
+{
+ typedef float type;
+ typedef Packet2f half;
+ typedef Packet4i integer_packet;
+ enum
+ {
+ size = 4,
+ alignment = Aligned16,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false
+ };
+};
+template<> struct unpacket_traits<Packet4c>
+{
+ typedef int8_t type;
+ typedef Packet4c half;
+ enum
+ {
+ size = 4,
+ alignment = Unaligned,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false
+ };
+};
+template<> struct unpacket_traits<Packet8c>
+{
+ typedef int8_t type;
+ typedef Packet4c half;
+ enum
+ {
+ size = 8,
+ alignment = Aligned16,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false
+ };
+};
+template<> struct unpacket_traits<Packet16c>
+{
+ typedef int8_t type;
+ typedef Packet8c half;
+ enum
+ {
+ size = 16,
+ alignment = Aligned16,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false
+ };
+};
+template<> struct unpacket_traits<Packet4uc>
+{
+ typedef uint8_t type;
+ typedef Packet4uc half;
+ enum
+ {
+ size = 4,
+ alignment = Unaligned,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false
+ };
+};
+template<> struct unpacket_traits<Packet8uc>
+{
+ typedef uint8_t type;
+ typedef Packet4uc half;
+ enum
+ {
+ size = 8,
+ alignment = Aligned16,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false
+ };
+};
+template<> struct unpacket_traits<Packet16uc>
+{
+ typedef uint8_t type;
+ typedef Packet8uc half;
+ enum
+ {
+ size = 16,
+ alignment = Aligned16,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false};
+};
+template<> struct unpacket_traits<Packet4s>
+{
+ typedef int16_t type;
+ typedef Packet4s half;
+ enum
+ {
+ size = 4,
+ alignment = Aligned16,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false
+ };
+};
+template<> struct unpacket_traits<Packet8s>
+{
+ typedef int16_t type;
+ typedef Packet4s half;
+ enum
+ {
+ size = 8,
+ alignment = Aligned16,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false
+ };
+};
+template<> struct unpacket_traits<Packet4us>
+{
+ typedef uint16_t type;
+ typedef Packet4us half;
+ enum
+ {
+ size = 4,
+ alignment = Aligned16,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false
+ };
+};
+template<> struct unpacket_traits<Packet8us>
+{
+ typedef uint16_t type;
+ typedef Packet4us half;
+ enum
+ {
+ size = 8,
+ alignment = Aligned16,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false
+ };
+};
+template<> struct unpacket_traits<Packet2i>
+{
+ typedef int32_t type;
+ typedef Packet2i half;
+ enum
+ {
+ size = 2,
+ alignment = Aligned16,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false
+ };
+};
+template<> struct unpacket_traits<Packet4i>
+{
+ typedef int32_t type;
+ typedef Packet2i half;
+ enum
+ {
+ size = 4,
+ alignment = Aligned16,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false
+ };
+};
+template<> struct unpacket_traits<Packet2ui>
+{
+ typedef uint32_t type;
+ typedef Packet2ui half;
+ enum
+ {
+ size = 2,
+ alignment = Aligned16,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false
+ };
+};
+template<> struct unpacket_traits<Packet4ui>
+{
+ typedef uint32_t type;
+ typedef Packet2ui half;
+ enum
+ {
+ size = 4,
+ alignment = Aligned16,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false
+ };
+};
+template<> struct unpacket_traits<Packet2l>
+{
+ typedef int64_t type;
+ typedef Packet2l half;
+ enum
+ {
+ size = 2,
+ alignment = Aligned16,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false
+ };
+};
+template<> struct unpacket_traits<Packet2ul>
+{
+ typedef uint64_t type;
+ typedef Packet2ul half;
+ enum
+ {
+ size = 2,
+ alignment = Aligned16,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false
+ };
+};
-template<> EIGEN_STRONG_INLINE Packet4f pset1<Packet4f>(const float& from) { return vdupq_n_f32(from); }
-template<> EIGEN_STRONG_INLINE Packet4i pset1<Packet4i>(const int32_t& from) { return vdupq_n_s32(from); }
+template<> EIGEN_STRONG_INLINE Packet2f pset1<Packet2f>(const float& from) { return vdup_n_f32(from); }
+template<> EIGEN_STRONG_INLINE Packet4f pset1<Packet4f>(const float& from) { return vdupq_n_f32(from); }
+template<> EIGEN_STRONG_INLINE Packet4c pset1<Packet4c>(const int8_t& from)
+{ return vget_lane_s32(vreinterpret_s32_s8(vdup_n_s8(from)), 0); }
+template<> EIGEN_STRONG_INLINE Packet8c pset1<Packet8c>(const int8_t& from) { return vdup_n_s8(from); }
+template<> EIGEN_STRONG_INLINE Packet16c pset1<Packet16c>(const int8_t& from) { return vdupq_n_s8(from); }
+template<> EIGEN_STRONG_INLINE Packet4uc pset1<Packet4uc>(const uint8_t& from)
+{ return vget_lane_u32(vreinterpret_u32_u8(vdup_n_u8(from)), 0); }
+template<> EIGEN_STRONG_INLINE Packet8uc pset1<Packet8uc>(const uint8_t& from) { return vdup_n_u8(from); }
+template<> EIGEN_STRONG_INLINE Packet16uc pset1<Packet16uc>(const uint8_t& from) { return vdupq_n_u8(from); }
+template<> EIGEN_STRONG_INLINE Packet4s pset1<Packet4s>(const int16_t& from) { return vdup_n_s16(from); }
+template<> EIGEN_STRONG_INLINE Packet8s pset1<Packet8s>(const int16_t& from) { return vdupq_n_s16(from); }
+template<> EIGEN_STRONG_INLINE Packet4us pset1<Packet4us>(const uint16_t& from) { return vdup_n_u16(from); }
+template<> EIGEN_STRONG_INLINE Packet8us pset1<Packet8us>(const uint16_t& from) { return vdupq_n_u16(from); }
+template<> EIGEN_STRONG_INLINE Packet2i pset1<Packet2i>(const int32_t& from) { return vdup_n_s32(from); }
+template<> EIGEN_STRONG_INLINE Packet4i pset1<Packet4i>(const int32_t& from) { return vdupq_n_s32(from); }
+template<> EIGEN_STRONG_INLINE Packet2ui pset1<Packet2ui>(const uint32_t& from) { return vdup_n_u32(from); }
+template<> EIGEN_STRONG_INLINE Packet4ui pset1<Packet4ui>(const uint32_t& from) { return vdupq_n_u32(from); }
+template<> EIGEN_STRONG_INLINE Packet2l pset1<Packet2l>(const int64_t& from) { return vdupq_n_s64(from); }
+template<> EIGEN_STRONG_INLINE Packet2ul pset1<Packet2ul>(const uint64_t& from) { return vdupq_n_u64(from); }
+template<> EIGEN_STRONG_INLINE Packet2f pset1frombits<Packet2f>(unsigned int from)
+{ return vreinterpret_f32_u32(vdup_n_u32(from)); }
+template<> EIGEN_STRONG_INLINE Packet4f pset1frombits<Packet4f>(unsigned int from)
+{ return vreinterpretq_f32_u32(vdupq_n_u32(from)); }
+
+template<> EIGEN_STRONG_INLINE Packet2f plset<Packet2f>(const float& a)
+{
+ const float c[] = {0.0f,1.0f};
+ return vadd_f32(pset1<Packet2f>(a), vld1_f32(c));
+}
template<> EIGEN_STRONG_INLINE Packet4f plset<Packet4f>(const float& a)
{
- const float f[] = {0, 1, 2, 3};
- Packet4f countdown = vld1q_f32(f);
- return vaddq_f32(pset1<Packet4f>(a), countdown);
+ const float c[] = {0.0f,1.0f,2.0f,3.0f};
+ return vaddq_f32(pset1<Packet4f>(a), vld1q_f32(c));
+}
+template<> EIGEN_STRONG_INLINE Packet4c plset<Packet4c>(const int8_t& a)
+{ return vget_lane_s32(vreinterpret_s32_s8(vadd_s8(vreinterpret_s8_u32(vdup_n_u32(0x03020100)), vdup_n_s8(a))), 0); }
+template<> EIGEN_STRONG_INLINE Packet8c plset<Packet8c>(const int8_t& a)
+{
+ const int8_t c[] = {0,1,2,3,4,5,6,7};
+ return vadd_s8(pset1<Packet8c>(a), vld1_s8(c));
+}
+template<> EIGEN_STRONG_INLINE Packet16c plset<Packet16c>(const int8_t& a)
+{
+ const int8_t c[] = {0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15};
+ return vaddq_s8(pset1<Packet16c>(a), vld1q_s8(c));
+}
+template<> EIGEN_STRONG_INLINE Packet4uc plset<Packet4uc>(const uint8_t& a)
+{ return vget_lane_u32(vreinterpret_u32_u8(vadd_u8(vreinterpret_u8_u32(vdup_n_u32(0x03020100)), vdup_n_u8(a))), 0); }
+template<> EIGEN_STRONG_INLINE Packet8uc plset<Packet8uc>(const uint8_t& a)
+{
+ const uint8_t c[] = {0,1,2,3,4,5,6,7};
+ return vadd_u8(pset1<Packet8uc>(a), vld1_u8(c));
+}
+template<> EIGEN_STRONG_INLINE Packet16uc plset<Packet16uc>(const uint8_t& a)
+{
+ const uint8_t c[] = {0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15};
+ return vaddq_u8(pset1<Packet16uc>(a), vld1q_u8(c));
+}
+template<> EIGEN_STRONG_INLINE Packet4s plset<Packet4s>(const int16_t& a)
+{
+ const int16_t c[] = {0,1,2,3};
+ return vadd_s16(pset1<Packet4s>(a), vld1_s16(c));
+}
+template<> EIGEN_STRONG_INLINE Packet4us plset<Packet4us>(const uint16_t& a)
+{
+ const uint16_t c[] = {0,1,2,3};
+ return vadd_u16(pset1<Packet4us>(a), vld1_u16(c));
+}
+template<> EIGEN_STRONG_INLINE Packet8s plset<Packet8s>(const int16_t& a)
+{
+ const int16_t c[] = {0,1,2,3,4,5,6,7};
+ return vaddq_s16(pset1<Packet8s>(a), vld1q_s16(c));
+}
+template<> EIGEN_STRONG_INLINE Packet8us plset<Packet8us>(const uint16_t& a)
+{
+ const uint16_t c[] = {0,1,2,3,4,5,6,7};
+ return vaddq_u16(pset1<Packet8us>(a), vld1q_u16(c));
+}
+template<> EIGEN_STRONG_INLINE Packet2i plset<Packet2i>(const int32_t& a)
+{
+ const int32_t c[] = {0,1};
+ return vadd_s32(pset1<Packet2i>(a), vld1_s32(c));
}
template<> EIGEN_STRONG_INLINE Packet4i plset<Packet4i>(const int32_t& a)
{
- const int32_t i[] = {0, 1, 2, 3};
- Packet4i countdown = vld1q_s32(i);
- return vaddq_s32(pset1<Packet4i>(a), countdown);
+ const int32_t c[] = {0,1,2,3};
+ return vaddq_s32(pset1<Packet4i>(a), vld1q_s32(c));
+}
+template<> EIGEN_STRONG_INLINE Packet2ui plset<Packet2ui>(const uint32_t& a)
+{
+ const uint32_t c[] = {0,1};
+ return vadd_u32(pset1<Packet2ui>(a), vld1_u32(c));
+}
+template<> EIGEN_STRONG_INLINE Packet4ui plset<Packet4ui>(const uint32_t& a)
+{
+ const uint32_t c[] = {0,1,2,3};
+ return vaddq_u32(pset1<Packet4ui>(a), vld1q_u32(c));
+}
+template<> EIGEN_STRONG_INLINE Packet2l plset<Packet2l>(const int64_t& a)
+{
+ const int64_t c[] = {0,1};
+ return vaddq_s64(pset1<Packet2l>(a), vld1q_s64(c));
+}
+template<> EIGEN_STRONG_INLINE Packet2ul plset<Packet2ul>(const uint64_t& a)
+{
+ const uint64_t c[] = {0,1};
+ return vaddq_u64(pset1<Packet2ul>(a), vld1q_u64(c));
}
+template<> EIGEN_STRONG_INLINE Packet2f padd<Packet2f>(const Packet2f& a, const Packet2f& b) { return vadd_f32(a,b); }
template<> EIGEN_STRONG_INLINE Packet4f padd<Packet4f>(const Packet4f& a, const Packet4f& b) { return vaddq_f32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4c padd<Packet4c>(const Packet4c& a, const Packet4c& b)
+{
+ return vget_lane_s32(vreinterpret_s32_s8(vadd_s8(
+ vreinterpret_s8_s32(vdup_n_s32(a)),
+ vreinterpret_s8_s32(vdup_n_s32(b)))), 0);
+}
+template<> EIGEN_STRONG_INLINE Packet8c padd<Packet8c>(const Packet8c& a, const Packet8c& b) { return vadd_s8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16c padd<Packet16c>(const Packet16c& a, const Packet16c& b) { return vaddq_s8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4uc padd<Packet4uc>(const Packet4uc& a, const Packet4uc& b)
+{
+ return vget_lane_u32(vreinterpret_u32_u8(vadd_u8(
+ vreinterpret_u8_u32(vdup_n_u32(a)),
+ vreinterpret_u8_u32(vdup_n_u32(b)))), 0);
+}
+template<> EIGEN_STRONG_INLINE Packet8uc padd<Packet8uc>(const Packet8uc& a, const Packet8uc& b) { return vadd_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16uc padd<Packet16uc>(const Packet16uc& a, const Packet16uc& b) { return vaddq_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4s padd<Packet4s>(const Packet4s& a, const Packet4s& b) { return vadd_s16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8s padd<Packet8s>(const Packet8s& a, const Packet8s& b) { return vaddq_s16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4us padd<Packet4us>(const Packet4us& a, const Packet4us& b) { return vadd_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8us padd<Packet8us>(const Packet8us& a, const Packet8us& b) { return vaddq_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2i padd<Packet2i>(const Packet2i& a, const Packet2i& b) { return vadd_s32(a,b); }
template<> EIGEN_STRONG_INLINE Packet4i padd<Packet4i>(const Packet4i& a, const Packet4i& b) { return vaddq_s32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2ui padd<Packet2ui>(const Packet2ui& a, const Packet2ui& b) { return vadd_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4ui padd<Packet4ui>(const Packet4ui& a, const Packet4ui& b) { return vaddq_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2l padd<Packet2l>(const Packet2l& a, const Packet2l& b) { return vaddq_s64(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2ul padd<Packet2ul>(const Packet2ul& a, const Packet2ul& b) { return vaddq_u64(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2f psub<Packet2f>(const Packet2f& a, const Packet2f& b) { return vsub_f32(a,b); }
template<> EIGEN_STRONG_INLINE Packet4f psub<Packet4f>(const Packet4f& a, const Packet4f& b) { return vsubq_f32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4c psub<Packet4c>(const Packet4c& a, const Packet4c& b)
+{
+ return vget_lane_s32(vreinterpret_s32_s8(vsub_s8(
+ vreinterpret_s8_s32(vdup_n_s32(a)),
+ vreinterpret_s8_s32(vdup_n_s32(b)))), 0);
+}
+template<> EIGEN_STRONG_INLINE Packet8c psub<Packet8c>(const Packet8c& a, const Packet8c& b) { return vsub_s8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16c psub<Packet16c>(const Packet16c& a, const Packet16c& b) { return vsubq_s8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4uc psub<Packet4uc>(const Packet4uc& a, const Packet4uc& b)
+{
+ return vget_lane_u32(vreinterpret_u32_u8(vsub_u8(
+ vreinterpret_u8_u32(vdup_n_u32(a)),
+ vreinterpret_u8_u32(vdup_n_u32(b)))), 0);
+}
+template<> EIGEN_STRONG_INLINE Packet8uc psub<Packet8uc>(const Packet8uc& a, const Packet8uc& b) { return vsub_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16uc psub<Packet16uc>(const Packet16uc& a, const Packet16uc& b) { return vsubq_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4s psub<Packet4s>(const Packet4s& a, const Packet4s& b) { return vsub_s16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8s psub<Packet8s>(const Packet8s& a, const Packet8s& b) { return vsubq_s16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4us psub<Packet4us>(const Packet4us& a, const Packet4us& b) { return vsub_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8us psub<Packet8us>(const Packet8us& a, const Packet8us& b) { return vsubq_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2i psub<Packet2i>(const Packet2i& a, const Packet2i& b) { return vsub_s32(a,b); }
template<> EIGEN_STRONG_INLINE Packet4i psub<Packet4i>(const Packet4i& a, const Packet4i& b) { return vsubq_s32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2ui psub<Packet2ui>(const Packet2ui& a, const Packet2ui& b) { return vsub_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4ui psub<Packet4ui>(const Packet4ui& a, const Packet4ui& b) { return vsubq_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2l psub<Packet2l>(const Packet2l& a, const Packet2l& b) { return vsubq_s64(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2ul psub<Packet2ul>(const Packet2ul& a, const Packet2ul& b) { return vsubq_u64(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2f pxor<Packet2f>(const Packet2f& a, const Packet2f& b);
+template<> EIGEN_STRONG_INLINE Packet2f paddsub<Packet2f>(const Packet2f& a, const Packet2f & b) {
+ Packet2f mask = {numext::bit_cast<float>(0x80000000u), 0.0f};
+ return padd(a, pxor(mask, b));
+}
+template<> EIGEN_STRONG_INLINE Packet4f pxor<Packet4f>(const Packet4f& a, const Packet4f& b);
+template<> EIGEN_STRONG_INLINE Packet4f paddsub<Packet4f>(const Packet4f& a, const Packet4f& b) {
+ Packet4f mask = {numext::bit_cast<float>(0x80000000u), 0.0f, numext::bit_cast<float>(0x80000000u), 0.0f};
+ return padd(a, pxor(mask, b));
+}
+
+template<> EIGEN_STRONG_INLINE Packet2f pnegate(const Packet2f& a) { return vneg_f32(a); }
template<> EIGEN_STRONG_INLINE Packet4f pnegate(const Packet4f& a) { return vnegq_f32(a); }
+template<> EIGEN_STRONG_INLINE Packet4c pnegate(const Packet4c& a)
+{ return vget_lane_s32(vreinterpret_s32_s8(vneg_s8(vreinterpret_s8_s32(vdup_n_s32(a)))), 0); }
+template<> EIGEN_STRONG_INLINE Packet8c pnegate(const Packet8c& a) { return vneg_s8(a); }
+template<> EIGEN_STRONG_INLINE Packet16c pnegate(const Packet16c& a) { return vnegq_s8(a); }
+template<> EIGEN_STRONG_INLINE Packet4s pnegate(const Packet4s& a) { return vneg_s16(a); }
+template<> EIGEN_STRONG_INLINE Packet8s pnegate(const Packet8s& a) { return vnegq_s16(a); }
+template<> EIGEN_STRONG_INLINE Packet2i pnegate(const Packet2i& a) { return vneg_s32(a); }
template<> EIGEN_STRONG_INLINE Packet4i pnegate(const Packet4i& a) { return vnegq_s32(a); }
+template<> EIGEN_STRONG_INLINE Packet2l pnegate(const Packet2l& a) {
+#if EIGEN_ARCH_ARM64
+ return vnegq_s64(a);
+#else
+ return vcombine_s64(
+ vdup_n_s64(-vgetq_lane_s64(a, 0)),
+ vdup_n_s64(-vgetq_lane_s64(a, 1)));
+#endif
+}
+template<> EIGEN_STRONG_INLINE Packet2f pconj(const Packet2f& a) { return a; }
template<> EIGEN_STRONG_INLINE Packet4f pconj(const Packet4f& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet4c pconj(const Packet4c& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet8c pconj(const Packet8c& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet16c pconj(const Packet16c& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet4uc pconj(const Packet4uc& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet8uc pconj(const Packet8uc& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet16uc pconj(const Packet16uc& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet4s pconj(const Packet4s& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet8s pconj(const Packet8s& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet4us pconj(const Packet4us& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet8us pconj(const Packet8us& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet2i pconj(const Packet2i& a) { return a; }
template<> EIGEN_STRONG_INLINE Packet4i pconj(const Packet4i& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet2ui pconj(const Packet2ui& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet4ui pconj(const Packet4ui& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet2l pconj(const Packet2l& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet2ul pconj(const Packet2ul& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet2f pmul<Packet2f>(const Packet2f& a, const Packet2f& b) { return vmul_f32(a,b); }
template<> EIGEN_STRONG_INLINE Packet4f pmul<Packet4f>(const Packet4f& a, const Packet4f& b) { return vmulq_f32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4c pmul<Packet4c>(const Packet4c& a, const Packet4c& b)
+{
+ return vget_lane_s32(vreinterpret_s32_s8(vmul_s8(
+ vreinterpret_s8_s32(vdup_n_s32(a)),
+ vreinterpret_s8_s32(vdup_n_s32(b)))), 0);
+}
+template<> EIGEN_STRONG_INLINE Packet8c pmul<Packet8c>(const Packet8c& a, const Packet8c& b) { return vmul_s8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16c pmul<Packet16c>(const Packet16c& a, const Packet16c& b) { return vmulq_s8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4uc pmul<Packet4uc>(const Packet4uc& a, const Packet4uc& b)
+{
+ return vget_lane_u32(vreinterpret_u32_u8(vmul_u8(
+ vreinterpret_u8_u32(vdup_n_u32(a)),
+ vreinterpret_u8_u32(vdup_n_u32(b)))), 0);
+}
+template<> EIGEN_STRONG_INLINE Packet8uc pmul<Packet8uc>(const Packet8uc& a, const Packet8uc& b) { return vmul_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16uc pmul<Packet16uc>(const Packet16uc& a, const Packet16uc& b) { return vmulq_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4s pmul<Packet4s>(const Packet4s& a, const Packet4s& b) { return vmul_s16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8s pmul<Packet8s>(const Packet8s& a, const Packet8s& b) { return vmulq_s16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4us pmul<Packet4us>(const Packet4us& a, const Packet4us& b) { return vmul_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8us pmul<Packet8us>(const Packet8us& a, const Packet8us& b) { return vmulq_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2i pmul<Packet2i>(const Packet2i& a, const Packet2i& b) { return vmul_s32(a,b); }
template<> EIGEN_STRONG_INLINE Packet4i pmul<Packet4i>(const Packet4i& a, const Packet4i& b) { return vmulq_s32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2ui pmul<Packet2ui>(const Packet2ui& a, const Packet2ui& b) { return vmul_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4ui pmul<Packet4ui>(const Packet4ui& a, const Packet4ui& b) { return vmulq_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2l pmul<Packet2l>(const Packet2l& a, const Packet2l& b) {
+ return vcombine_s64(
+ vdup_n_s64(vgetq_lane_s64(a, 0)*vgetq_lane_s64(b, 0)),
+ vdup_n_s64(vgetq_lane_s64(a, 1)*vgetq_lane_s64(b, 1)));
+}
+template<> EIGEN_STRONG_INLINE Packet2ul pmul<Packet2ul>(const Packet2ul& a, const Packet2ul& b) {
+ return vcombine_u64(
+ vdup_n_u64(vgetq_lane_u64(a, 0)*vgetq_lane_u64(b, 0)),
+ vdup_n_u64(vgetq_lane_u64(a, 1)*vgetq_lane_u64(b, 1)));
+}
+template<> EIGEN_STRONG_INLINE Packet2f pdiv<Packet2f>(const Packet2f& a, const Packet2f& b)
+{
+#if EIGEN_ARCH_ARM64
+ return vdiv_f32(a,b);
+#else
+ Packet2f inv, restep, div;
+
+ // NEON does not offer a divide instruction, we have to do a reciprocal approximation
+ // However NEON in contrast to other SIMD engines (AltiVec/SSE), offers
+ // a reciprocal estimate AND a reciprocal step -which saves a few instructions
+ // vrecpeq_f32() returns an estimate to 1/b, which we will finetune with
+ // Newton-Raphson and vrecpsq_f32()
+ inv = vrecpe_f32(b);
+
+ // This returns a differential, by which we will have to multiply inv to get a better
+ // approximation of 1/b.
+ restep = vrecps_f32(b, inv);
+ inv = vmul_f32(restep, inv);
+
+ // Finally, multiply a by 1/b and get the wanted result of the division.
+ div = vmul_f32(a, inv);
+
+ return div;
+#endif
+}
template<> EIGEN_STRONG_INLINE Packet4f pdiv<Packet4f>(const Packet4f& a, const Packet4f& b)
{
#if EIGEN_ARCH_ARM64
@@ -199,357 +998,2629 @@
#endif
}
+template<> EIGEN_STRONG_INLINE Packet4c pdiv<Packet4c>(const Packet4c& /*a*/, const Packet4c& /*b*/)
+{
+ eigen_assert(false && "packet integer division are not supported by NEON");
+ return pset1<Packet4c>(0);
+}
+template<> EIGEN_STRONG_INLINE Packet8c pdiv<Packet8c>(const Packet8c& /*a*/, const Packet8c& /*b*/)
+{
+ eigen_assert(false && "packet integer division are not supported by NEON");
+ return pset1<Packet8c>(0);
+}
+template<> EIGEN_STRONG_INLINE Packet16c pdiv<Packet16c>(const Packet16c& /*a*/, const Packet16c& /*b*/)
+{
+ eigen_assert(false && "packet integer division are not supported by NEON");
+ return pset1<Packet16c>(0);
+}
+template<> EIGEN_STRONG_INLINE Packet4uc pdiv<Packet4uc>(const Packet4uc& /*a*/, const Packet4uc& /*b*/)
+{
+ eigen_assert(false && "packet integer division are not supported by NEON");
+ return pset1<Packet4uc>(0);
+}
+template<> EIGEN_STRONG_INLINE Packet8uc pdiv<Packet8uc>(const Packet8uc& /*a*/, const Packet8uc& /*b*/)
+{
+ eigen_assert(false && "packet integer division are not supported by NEON");
+ return pset1<Packet8uc>(0);
+}
+template<> EIGEN_STRONG_INLINE Packet16uc pdiv<Packet16uc>(const Packet16uc& /*a*/, const Packet16uc& /*b*/)
+{
+ eigen_assert(false && "packet integer division are not supported by NEON");
+ return pset1<Packet16uc>(0);
+}
+template<> EIGEN_STRONG_INLINE Packet4s pdiv<Packet4s>(const Packet4s& /*a*/, const Packet4s& /*b*/)
+{
+ eigen_assert(false && "packet integer division are not supported by NEON");
+ return pset1<Packet4s>(0);
+}
+template<> EIGEN_STRONG_INLINE Packet8s pdiv<Packet8s>(const Packet8s& /*a*/, const Packet8s& /*b*/)
+{
+ eigen_assert(false && "packet integer division are not supported by NEON");
+ return pset1<Packet8s>(0);
+}
+template<> EIGEN_STRONG_INLINE Packet4us pdiv<Packet4us>(const Packet4us& /*a*/, const Packet4us& /*b*/)
+{
+ eigen_assert(false && "packet integer division are not supported by NEON");
+ return pset1<Packet4us>(0);
+}
+template<> EIGEN_STRONG_INLINE Packet8us pdiv<Packet8us>(const Packet8us& /*a*/, const Packet8us& /*b*/)
+{
+ eigen_assert(false && "packet integer division are not supported by NEON");
+ return pset1<Packet8us>(0);
+}
+template<> EIGEN_STRONG_INLINE Packet2i pdiv<Packet2i>(const Packet2i& /*a*/, const Packet2i& /*b*/)
+{
+ eigen_assert(false && "packet integer division are not supported by NEON");
+ return pset1<Packet2i>(0);
+}
template<> EIGEN_STRONG_INLINE Packet4i pdiv<Packet4i>(const Packet4i& /*a*/, const Packet4i& /*b*/)
-{ eigen_assert(false && "packet integer division are not supported by NEON");
+{
+ eigen_assert(false && "packet integer division are not supported by NEON");
return pset1<Packet4i>(0);
}
+template<> EIGEN_STRONG_INLINE Packet2ui pdiv<Packet2ui>(const Packet2ui& /*a*/, const Packet2ui& /*b*/)
+{
+ eigen_assert(false && "packet integer division are not supported by NEON");
+ return pset1<Packet2ui>(0);
+}
+template<> EIGEN_STRONG_INLINE Packet4ui pdiv<Packet4ui>(const Packet4ui& /*a*/, const Packet4ui& /*b*/)
+{
+ eigen_assert(false && "packet integer division are not supported by NEON");
+ return pset1<Packet4ui>(0);
+}
+template<> EIGEN_STRONG_INLINE Packet2l pdiv<Packet2l>(const Packet2l& /*a*/, const Packet2l& /*b*/)
+{
+ eigen_assert(false && "packet integer division are not supported by NEON");
+ return pset1<Packet2l>(0LL);
+}
+template<> EIGEN_STRONG_INLINE Packet2ul pdiv<Packet2ul>(const Packet2ul& /*a*/, const Packet2ul& /*b*/)
+{
+ eigen_assert(false && "packet integer division are not supported by NEON");
+ return pset1<Packet2ul>(0ULL);
+}
-// Clang/ARM wrongly advertises __ARM_FEATURE_FMA even when it's not available,
-// then implements a slow software scalar fallback calling fmaf()!
-// Filed LLVM bug:
-// https://llvm.org/bugs/show_bug.cgi?id=27216
-#if (defined __ARM_FEATURE_FMA) && !(EIGEN_COMP_CLANG && EIGEN_ARCH_ARM)
-// See bug 936.
-// FMA is available on VFPv4 i.e. when compiling with -mfpu=neon-vfpv4.
-// FMA is a true fused multiply-add i.e. only 1 rounding at the end, no intermediate rounding.
-// MLA is not fused i.e. does 2 roundings.
-// In addition to giving better accuracy, FMA also gives better performance here on a Krait (Nexus 4):
-// MLA: 10 GFlop/s ; FMA: 12 GFlops/s.
-template<> EIGEN_STRONG_INLINE Packet4f pmadd(const Packet4f& a, const Packet4f& b, const Packet4f& c) { return vfmaq_f32(c,a,b); }
+
+#ifdef __ARM_FEATURE_FMA
+template<> EIGEN_STRONG_INLINE Packet4f pmadd(const Packet4f& a, const Packet4f& b, const Packet4f& c)
+{ return vfmaq_f32(c,a,b); }
+template<> EIGEN_STRONG_INLINE Packet2f pmadd(const Packet2f& a, const Packet2f& b, const Packet2f& c)
+{ return vfma_f32(c,a,b); }
#else
-template<> EIGEN_STRONG_INLINE Packet4f pmadd(const Packet4f& a, const Packet4f& b, const Packet4f& c) {
-#if EIGEN_COMP_CLANG && EIGEN_ARCH_ARM
- // Clang/ARM will replace VMLA by VMUL+VADD at least for some values of -mcpu,
- // at least -mcpu=cortex-a8 and -mcpu=cortex-a7. Since the former is the default on
- // -march=armv7-a, that is a very common case.
- // See e.g. this thread:
- // http://lists.llvm.org/pipermail/llvm-dev/2013-December/068806.html
- // Filed LLVM bug:
- // https://llvm.org/bugs/show_bug.cgi?id=27219
- Packet4f r = c;
- asm volatile(
- "vmla.f32 %q[r], %q[a], %q[b]"
- : [r] "+w" (r)
- : [a] "w" (a),
- [b] "w" (b)
- : );
- return r;
-#else
+template<> EIGEN_STRONG_INLINE Packet4f pmadd(const Packet4f& a, const Packet4f& b, const Packet4f& c)
+{
return vmlaq_f32(c,a,b);
-#endif
+}
+template<> EIGEN_STRONG_INLINE Packet2f pmadd(const Packet2f& a, const Packet2f& b, const Packet2f& c)
+{
+ return vmla_f32(c,a,b);
}
#endif
// No FMA instruction for int, so use MLA unconditionally.
-template<> EIGEN_STRONG_INLINE Packet4i pmadd(const Packet4i& a, const Packet4i& b, const Packet4i& c) { return vmlaq_s32(c,a,b); }
+template<> EIGEN_STRONG_INLINE Packet4c pmadd(const Packet4c& a, const Packet4c& b, const Packet4c& c)
+{
+ return vget_lane_s32(vreinterpret_s32_s8(vmla_s8(
+ vreinterpret_s8_s32(vdup_n_s32(c)),
+ vreinterpret_s8_s32(vdup_n_s32(a)),
+ vreinterpret_s8_s32(vdup_n_s32(b)))), 0);
+}
+template<> EIGEN_STRONG_INLINE Packet8c pmadd(const Packet8c& a, const Packet8c& b, const Packet8c& c)
+{ return vmla_s8(c,a,b); }
+template<> EIGEN_STRONG_INLINE Packet16c pmadd(const Packet16c& a, const Packet16c& b, const Packet16c& c)
+{ return vmlaq_s8(c,a,b); }
+template<> EIGEN_STRONG_INLINE Packet4uc pmadd(const Packet4uc& a, const Packet4uc& b, const Packet4uc& c)
+{
+ return vget_lane_u32(vreinterpret_u32_u8(vmla_u8(
+ vreinterpret_u8_u32(vdup_n_u32(c)),
+ vreinterpret_u8_u32(vdup_n_u32(a)),
+ vreinterpret_u8_u32(vdup_n_u32(b)))), 0);
+}
+template<> EIGEN_STRONG_INLINE Packet8uc pmadd(const Packet8uc& a, const Packet8uc& b, const Packet8uc& c)
+{ return vmla_u8(c,a,b); }
+template<> EIGEN_STRONG_INLINE Packet16uc pmadd(const Packet16uc& a, const Packet16uc& b, const Packet16uc& c)
+{ return vmlaq_u8(c,a,b); }
+template<> EIGEN_STRONG_INLINE Packet4s pmadd(const Packet4s& a, const Packet4s& b, const Packet4s& c)
+{ return vmla_s16(c,a,b); }
+template<> EIGEN_STRONG_INLINE Packet8s pmadd(const Packet8s& a, const Packet8s& b, const Packet8s& c)
+{ return vmlaq_s16(c,a,b); }
+template<> EIGEN_STRONG_INLINE Packet4us pmadd(const Packet4us& a, const Packet4us& b, const Packet4us& c)
+{ return vmla_u16(c,a,b); }
+template<> EIGEN_STRONG_INLINE Packet8us pmadd(const Packet8us& a, const Packet8us& b, const Packet8us& c)
+{ return vmlaq_u16(c,a,b); }
+template<> EIGEN_STRONG_INLINE Packet2i pmadd(const Packet2i& a, const Packet2i& b, const Packet2i& c)
+{ return vmla_s32(c,a,b); }
+template<> EIGEN_STRONG_INLINE Packet4i pmadd(const Packet4i& a, const Packet4i& b, const Packet4i& c)
+{ return vmlaq_s32(c,a,b); }
+template<> EIGEN_STRONG_INLINE Packet2ui pmadd(const Packet2ui& a, const Packet2ui& b, const Packet2ui& c)
+{ return vmla_u32(c,a,b); }
+template<> EIGEN_STRONG_INLINE Packet4ui pmadd(const Packet4ui& a, const Packet4ui& b, const Packet4ui& c)
+{ return vmlaq_u32(c,a,b); }
+template<> EIGEN_STRONG_INLINE Packet2f pabsdiff<Packet2f>(const Packet2f& a, const Packet2f& b)
+{ return vabd_f32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4f pabsdiff<Packet4f>(const Packet4f& a, const Packet4f& b)
+{ return vabdq_f32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4c pabsdiff<Packet4c>(const Packet4c& a, const Packet4c& b)
+{
+ return vget_lane_s32(vreinterpret_s32_s8(vabd_s8(
+ vreinterpret_s8_s32(vdup_n_s32(a)),
+ vreinterpret_s8_s32(vdup_n_s32(b)))), 0);
+}
+template<> EIGEN_STRONG_INLINE Packet8c pabsdiff<Packet8c>(const Packet8c& a, const Packet8c& b)
+{ return vabd_s8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16c pabsdiff<Packet16c>(const Packet16c& a, const Packet16c& b)
+{ return vabdq_s8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4uc pabsdiff<Packet4uc>(const Packet4uc& a, const Packet4uc& b)
+{
+ return vget_lane_u32(vreinterpret_u32_u8(vabd_u8(
+ vreinterpret_u8_u32(vdup_n_u32(a)),
+ vreinterpret_u8_u32(vdup_n_u32(b)))), 0);
+}
+template<> EIGEN_STRONG_INLINE Packet8uc pabsdiff<Packet8uc>(const Packet8uc& a, const Packet8uc& b)
+{ return vabd_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16uc pabsdiff<Packet16uc>(const Packet16uc& a, const Packet16uc& b)
+{ return vabdq_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4s pabsdiff<Packet4s>(const Packet4s& a, const Packet4s& b)
+{ return vabd_s16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8s pabsdiff<Packet8s>(const Packet8s& a, const Packet8s& b)
+{ return vabdq_s16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4us pabsdiff<Packet4us>(const Packet4us& a, const Packet4us& b)
+{ return vabd_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8us pabsdiff<Packet8us>(const Packet8us& a, const Packet8us& b)
+{ return vabdq_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2i pabsdiff<Packet2i>(const Packet2i& a, const Packet2i& b)
+{ return vabd_s32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4i pabsdiff<Packet4i>(const Packet4i& a, const Packet4i& b)
+{ return vabdq_s32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2ui pabsdiff<Packet2ui>(const Packet2ui& a, const Packet2ui& b)
+{ return vabd_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4ui pabsdiff<Packet4ui>(const Packet4ui& a, const Packet4ui& b)
+{ return vabdq_u32(a,b); }
+
+template<> EIGEN_STRONG_INLINE Packet2f pmin<Packet2f>(const Packet2f& a, const Packet2f& b) { return vmin_f32(a,b); }
template<> EIGEN_STRONG_INLINE Packet4f pmin<Packet4f>(const Packet4f& a, const Packet4f& b) { return vminq_f32(a,b); }
-template<> EIGEN_STRONG_INLINE Packet4i pmin<Packet4i>(const Packet4i& a, const Packet4i& b) { return vminq_s32(a,b); }
+#ifdef __ARM_FEATURE_NUMERIC_MAXMIN
+// numeric max and min are only available if ARM_FEATURE_NUMERIC_MAXMIN is defined (which can only be the case for Armv8 systems).
+template<> EIGEN_STRONG_INLINE Packet4f pmin<PropagateNumbers, Packet4f>(const Packet4f& a, const Packet4f& b) { return vminnmq_f32(a, b); }
+template<> EIGEN_STRONG_INLINE Packet2f pmin<PropagateNumbers, Packet2f>(const Packet2f& a, const Packet2f& b) { return vminnm_f32(a, b); }
+#endif
+
+template<> EIGEN_STRONG_INLINE Packet4f pmin<PropagateNaN, Packet4f>(const Packet4f& a, const Packet4f& b) { return pmin<Packet4f>(a, b); }
+
+template<> EIGEN_STRONG_INLINE Packet2f pmin<PropagateNaN, Packet2f>(const Packet2f& a, const Packet2f& b) { return pmin<Packet2f>(a, b); }
+
+template<> EIGEN_STRONG_INLINE Packet4c pmin<Packet4c>(const Packet4c& a, const Packet4c& b)
+{
+ return vget_lane_s32(vreinterpret_s32_s8(vmin_s8(
+ vreinterpret_s8_s32(vdup_n_s32(a)),
+ vreinterpret_s8_s32(vdup_n_s32(b)))), 0);
+}
+template<> EIGEN_STRONG_INLINE Packet8c pmin<Packet8c>(const Packet8c& a, const Packet8c& b) { return vmin_s8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16c pmin<Packet16c>(const Packet16c& a, const Packet16c& b) { return vminq_s8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4uc pmin<Packet4uc>(const Packet4uc& a, const Packet4uc& b)
+{
+ return vget_lane_u32(vreinterpret_u32_u8(vmin_u8(
+ vreinterpret_u8_u32(vdup_n_u32(a)),
+ vreinterpret_u8_u32(vdup_n_u32(b)))), 0);
+}
+template<> EIGEN_STRONG_INLINE Packet8uc pmin<Packet8uc>(const Packet8uc& a, const Packet8uc& b) { return vmin_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16uc pmin<Packet16uc>(const Packet16uc& a, const Packet16uc& b) { return vminq_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4s pmin<Packet4s>(const Packet4s& a, const Packet4s& b) { return vmin_s16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8s pmin<Packet8s>(const Packet8s& a, const Packet8s& b) { return vminq_s16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4us pmin<Packet4us>(const Packet4us& a, const Packet4us& b) { return vmin_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8us pmin<Packet8us>(const Packet8us& a, const Packet8us& b) { return vminq_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2i pmin<Packet2i>(const Packet2i& a, const Packet2i& b) { return vmin_s32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4i pmin<Packet4i>(const Packet4i& a, const Packet4i& b) { return vminq_s32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2ui pmin<Packet2ui>(const Packet2ui& a, const Packet2ui& b) { return vmin_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4ui pmin<Packet4ui>(const Packet4ui& a, const Packet4ui& b) { return vminq_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2l pmin<Packet2l>(const Packet2l& a, const Packet2l& b) {
+ return vcombine_s64(
+ vdup_n_s64((std::min)(vgetq_lane_s64(a, 0), vgetq_lane_s64(b, 0))),
+ vdup_n_s64((std::min)(vgetq_lane_s64(a, 1), vgetq_lane_s64(b, 1))));
+}
+template<> EIGEN_STRONG_INLINE Packet2ul pmin<Packet2ul>(const Packet2ul& a, const Packet2ul& b) {
+ return vcombine_u64(
+ vdup_n_u64((std::min)(vgetq_lane_u64(a, 0), vgetq_lane_u64(b, 0))),
+ vdup_n_u64((std::min)(vgetq_lane_u64(a, 1), vgetq_lane_u64(b, 1))));
+}
+
+template<> EIGEN_STRONG_INLINE Packet2f pmax<Packet2f>(const Packet2f& a, const Packet2f& b) { return vmax_f32(a,b); }
template<> EIGEN_STRONG_INLINE Packet4f pmax<Packet4f>(const Packet4f& a, const Packet4f& b) { return vmaxq_f32(a,b); }
+
+#ifdef __ARM_FEATURE_NUMERIC_MAXMIN
+// numeric max and min are only available if ARM_FEATURE_NUMERIC_MAXMIN is defined (which can only be the case for Armv8 systems).
+template<> EIGEN_STRONG_INLINE Packet4f pmax<PropagateNumbers, Packet4f>(const Packet4f& a, const Packet4f& b) { return vmaxnmq_f32(a, b); }
+template<> EIGEN_STRONG_INLINE Packet2f pmax<PropagateNumbers, Packet2f>(const Packet2f& a, const Packet2f& b) { return vmaxnm_f32(a, b); }
+#endif
+
+template<> EIGEN_STRONG_INLINE Packet4f pmax<PropagateNaN, Packet4f>(const Packet4f& a, const Packet4f& b) { return pmax<Packet4f>(a, b); }
+
+template<> EIGEN_STRONG_INLINE Packet2f pmax<PropagateNaN, Packet2f>(const Packet2f& a, const Packet2f& b) { return pmax<Packet2f>(a, b); }
+
+template<> EIGEN_STRONG_INLINE Packet4c pmax<Packet4c>(const Packet4c& a, const Packet4c& b)
+{
+ return vget_lane_s32(vreinterpret_s32_s8(vmax_s8(
+ vreinterpret_s8_s32(vdup_n_s32(a)),
+ vreinterpret_s8_s32(vdup_n_s32(b)))), 0);
+}
+template<> EIGEN_STRONG_INLINE Packet8c pmax<Packet8c>(const Packet8c& a, const Packet8c& b) { return vmax_s8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16c pmax<Packet16c>(const Packet16c& a, const Packet16c& b) { return vmaxq_s8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4uc pmax<Packet4uc>(const Packet4uc& a, const Packet4uc& b)
+{
+ return vget_lane_u32(vreinterpret_u32_u8(vmax_u8(
+ vreinterpret_u8_u32(vdup_n_u32(a)),
+ vreinterpret_u8_u32(vdup_n_u32(b)))), 0);
+}
+template<> EIGEN_STRONG_INLINE Packet8uc pmax<Packet8uc>(const Packet8uc& a, const Packet8uc& b) { return vmax_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16uc pmax<Packet16uc>(const Packet16uc& a, const Packet16uc& b) { return vmaxq_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4s pmax<Packet4s>(const Packet4s& a, const Packet4s& b) { return vmax_s16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8s pmax<Packet8s>(const Packet8s& a, const Packet8s& b) { return vmaxq_s16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4us pmax<Packet4us>(const Packet4us& a, const Packet4us& b) { return vmax_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8us pmax<Packet8us>(const Packet8us& a, const Packet8us& b) { return vmaxq_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2i pmax<Packet2i>(const Packet2i& a, const Packet2i& b) { return vmax_s32(a,b); }
template<> EIGEN_STRONG_INLINE Packet4i pmax<Packet4i>(const Packet4i& a, const Packet4i& b) { return vmaxq_s32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2ui pmax<Packet2ui>(const Packet2ui& a, const Packet2ui& b) { return vmax_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4ui pmax<Packet4ui>(const Packet4ui& a, const Packet4ui& b) { return vmaxq_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2l pmax<Packet2l>(const Packet2l& a, const Packet2l& b) {
+ return vcombine_s64(
+ vdup_n_s64((std::max)(vgetq_lane_s64(a, 0), vgetq_lane_s64(b, 0))),
+ vdup_n_s64((std::max)(vgetq_lane_s64(a, 1), vgetq_lane_s64(b, 1))));
+}
+template<> EIGEN_STRONG_INLINE Packet2ul pmax<Packet2ul>(const Packet2ul& a, const Packet2ul& b) {
+ return vcombine_u64(
+ vdup_n_u64((std::max)(vgetq_lane_u64(a, 0), vgetq_lane_u64(b, 0))),
+ vdup_n_u64((std::max)(vgetq_lane_u64(a, 1), vgetq_lane_u64(b, 1))));
+}
+
+template<> EIGEN_STRONG_INLINE Packet2f pcmp_le<Packet2f>(const Packet2f& a, const Packet2f& b)
+{ return vreinterpret_f32_u32(vcle_f32(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet4f pcmp_le<Packet4f>(const Packet4f& a, const Packet4f& b)
+{ return vreinterpretq_f32_u32(vcleq_f32(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet4c pcmp_le<Packet4c>(const Packet4c& a, const Packet4c& b)
+{
+ return vget_lane_s32(vreinterpret_s32_u8(vcle_s8(
+ vreinterpret_s8_s32(vdup_n_s32(a)),
+ vreinterpret_s8_s32(vdup_n_s32(b)))), 0);
+}
+template<> EIGEN_STRONG_INLINE Packet8c pcmp_le<Packet8c>(const Packet8c& a, const Packet8c& b)
+{ return vreinterpret_s8_u8(vcle_s8(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet16c pcmp_le<Packet16c>(const Packet16c& a, const Packet16c& b)
+{ return vreinterpretq_s8_u8(vcleq_s8(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet4uc pcmp_le<Packet4uc>(const Packet4uc& a, const Packet4uc& b)
+{
+ return vget_lane_u32(vreinterpret_u32_u8(vcle_u8(
+ vreinterpret_u8_u32(vdup_n_u32(a)),
+ vreinterpret_u8_u32(vdup_n_u32(b)))), 0);
+}
+template<> EIGEN_STRONG_INLINE Packet8uc pcmp_le<Packet8uc>(const Packet8uc& a, const Packet8uc& b)
+{ return vcle_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16uc pcmp_le<Packet16uc>(const Packet16uc& a, const Packet16uc& b)
+{ return vcleq_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4s pcmp_le<Packet4s>(const Packet4s& a, const Packet4s& b)
+{ return vreinterpret_s16_u16(vcle_s16(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet8s pcmp_le<Packet8s>(const Packet8s& a, const Packet8s& b)
+{ return vreinterpretq_s16_u16(vcleq_s16(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet4us pcmp_le<Packet4us>(const Packet4us& a, const Packet4us& b)
+{ return vcle_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8us pcmp_le<Packet8us>(const Packet8us& a, const Packet8us& b)
+{ return vcleq_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2i pcmp_le<Packet2i>(const Packet2i& a, const Packet2i& b)
+{ return vreinterpret_s32_u32(vcle_s32(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet4i pcmp_le<Packet4i>(const Packet4i& a, const Packet4i& b)
+{ return vreinterpretq_s32_u32(vcleq_s32(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet2ui pcmp_le<Packet2ui>(const Packet2ui& a, const Packet2ui& b)
+{ return vcle_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4ui pcmp_le<Packet4ui>(const Packet4ui& a, const Packet4ui& b)
+{ return vcleq_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2l pcmp_le<Packet2l>(const Packet2l& a, const Packet2l& b)
+{
+#if EIGEN_ARCH_ARM64
+ return vreinterpretq_s64_u64(vcleq_s64(a,b));
+#else
+ return vcombine_s64(
+ vdup_n_s64(vgetq_lane_s64(a, 0) <= vgetq_lane_s64(b, 0) ? numext::int64_t(-1) : 0),
+ vdup_n_s64(vgetq_lane_s64(a, 1) <= vgetq_lane_s64(b, 1) ? numext::int64_t(-1) : 0));
+#endif
+}
+template<> EIGEN_STRONG_INLINE Packet2ul pcmp_le<Packet2ul>(const Packet2ul& a, const Packet2ul& b)
+{
+#if EIGEN_ARCH_ARM64
+ return vcleq_u64(a,b);
+#else
+ return vcombine_u64(
+ vdup_n_u64(vgetq_lane_u64(a, 0) <= vgetq_lane_u64(b, 0) ? numext::uint64_t(-1) : 0),
+ vdup_n_u64(vgetq_lane_u64(a, 1) <= vgetq_lane_u64(b, 1) ? numext::uint64_t(-1) : 0));
+#endif
+}
+
+template<> EIGEN_STRONG_INLINE Packet2f pcmp_lt<Packet2f>(const Packet2f& a, const Packet2f& b)
+{ return vreinterpret_f32_u32(vclt_f32(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet4f pcmp_lt<Packet4f>(const Packet4f& a, const Packet4f& b)
+{ return vreinterpretq_f32_u32(vcltq_f32(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet4c pcmp_lt<Packet4c>(const Packet4c& a, const Packet4c& b)
+{
+ return vget_lane_s32(vreinterpret_s32_u8(vclt_s8(
+ vreinterpret_s8_s32(vdup_n_s32(a)),
+ vreinterpret_s8_s32(vdup_n_s32(b)))), 0);
+}
+template<> EIGEN_STRONG_INLINE Packet8c pcmp_lt<Packet8c>(const Packet8c& a, const Packet8c& b)
+{ return vreinterpret_s8_u8(vclt_s8(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet16c pcmp_lt<Packet16c>(const Packet16c& a, const Packet16c& b)
+{ return vreinterpretq_s8_u8(vcltq_s8(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet4uc pcmp_lt<Packet4uc>(const Packet4uc& a, const Packet4uc& b)
+{
+ return vget_lane_u32(vreinterpret_u32_u8(vclt_u8(
+ vreinterpret_u8_u32(vdup_n_u32(a)),
+ vreinterpret_u8_u32(vdup_n_u32(b)))), 0);
+}
+template<> EIGEN_STRONG_INLINE Packet8uc pcmp_lt<Packet8uc>(const Packet8uc& a, const Packet8uc& b)
+{ return vclt_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16uc pcmp_lt<Packet16uc>(const Packet16uc& a, const Packet16uc& b)
+{ return vcltq_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4s pcmp_lt<Packet4s>(const Packet4s& a, const Packet4s& b)
+{ return vreinterpret_s16_u16(vclt_s16(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet8s pcmp_lt<Packet8s>(const Packet8s& a, const Packet8s& b)
+{ return vreinterpretq_s16_u16(vcltq_s16(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet4us pcmp_lt<Packet4us>(const Packet4us& a, const Packet4us& b)
+{ return vclt_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8us pcmp_lt<Packet8us>(const Packet8us& a, const Packet8us& b)
+{ return vcltq_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2i pcmp_lt<Packet2i>(const Packet2i& a, const Packet2i& b)
+{ return vreinterpret_s32_u32(vclt_s32(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet4i pcmp_lt<Packet4i>(const Packet4i& a, const Packet4i& b)
+{ return vreinterpretq_s32_u32(vcltq_s32(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet2ui pcmp_lt<Packet2ui>(const Packet2ui& a, const Packet2ui& b)
+{ return vclt_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4ui pcmp_lt<Packet4ui>(const Packet4ui& a, const Packet4ui& b)
+{ return vcltq_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2l pcmp_lt<Packet2l>(const Packet2l& a, const Packet2l& b)
+{
+#if EIGEN_ARCH_ARM64
+ return vreinterpretq_s64_u64(vcltq_s64(a,b));
+#else
+ return vcombine_s64(
+ vdup_n_s64(vgetq_lane_s64(a, 0) < vgetq_lane_s64(b, 0) ? numext::int64_t(-1) : 0),
+ vdup_n_s64(vgetq_lane_s64(a, 1) < vgetq_lane_s64(b, 1) ? numext::int64_t(-1) : 0));
+#endif
+}
+template<> EIGEN_STRONG_INLINE Packet2ul pcmp_lt<Packet2ul>(const Packet2ul& a, const Packet2ul& b)
+{
+#if EIGEN_ARCH_ARM64
+ return vcltq_u64(a,b);
+#else
+ return vcombine_u64(
+ vdup_n_u64(vgetq_lane_u64(a, 0) < vgetq_lane_u64(b, 0) ? numext::uint64_t(-1) : 0),
+ vdup_n_u64(vgetq_lane_u64(a, 1) < vgetq_lane_u64(b, 1) ? numext::uint64_t(-1) : 0));
+#endif
+}
+
+template<> EIGEN_STRONG_INLINE Packet2f pcmp_eq<Packet2f>(const Packet2f& a, const Packet2f& b)
+{ return vreinterpret_f32_u32(vceq_f32(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet4f pcmp_eq<Packet4f>(const Packet4f& a, const Packet4f& b)
+{ return vreinterpretq_f32_u32(vceqq_f32(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet4c pcmp_eq<Packet4c>(const Packet4c& a, const Packet4c& b)
+{
+ return vget_lane_s32(vreinterpret_s32_u8(vceq_s8(
+ vreinterpret_s8_s32(vdup_n_s32(a)),
+ vreinterpret_s8_s32(vdup_n_s32(b)))), 0);
+}
+template<> EIGEN_STRONG_INLINE Packet8c pcmp_eq<Packet8c>(const Packet8c& a, const Packet8c& b)
+{ return vreinterpret_s8_u8(vceq_s8(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet16c pcmp_eq<Packet16c>(const Packet16c& a, const Packet16c& b)
+{ return vreinterpretq_s8_u8(vceqq_s8(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet4uc pcmp_eq<Packet4uc>(const Packet4uc& a, const Packet4uc& b)
+{
+ return vget_lane_u32(vreinterpret_u32_u8(vceq_u8(
+ vreinterpret_u8_u32(vdup_n_u32(a)),
+ vreinterpret_u8_u32(vdup_n_u32(b)))), 0);
+}
+template<> EIGEN_STRONG_INLINE Packet8uc pcmp_eq<Packet8uc>(const Packet8uc& a, const Packet8uc& b)
+{ return vceq_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16uc pcmp_eq<Packet16uc>(const Packet16uc& a, const Packet16uc& b)
+{ return vceqq_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4s pcmp_eq<Packet4s>(const Packet4s& a, const Packet4s& b)
+{ return vreinterpret_s16_u16(vceq_s16(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet8s pcmp_eq<Packet8s>(const Packet8s& a, const Packet8s& b)
+{ return vreinterpretq_s16_u16(vceqq_s16(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet4us pcmp_eq<Packet4us>(const Packet4us& a, const Packet4us& b)
+{ return vceq_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8us pcmp_eq<Packet8us>(const Packet8us& a, const Packet8us& b)
+{ return vceqq_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2i pcmp_eq<Packet2i>(const Packet2i& a, const Packet2i& b)
+{ return vreinterpret_s32_u32(vceq_s32(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet4i pcmp_eq<Packet4i>(const Packet4i& a, const Packet4i& b)
+{ return vreinterpretq_s32_u32(vceqq_s32(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet2ui pcmp_eq<Packet2ui>(const Packet2ui& a, const Packet2ui& b)
+{ return vceq_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4ui pcmp_eq<Packet4ui>(const Packet4ui& a, const Packet4ui& b)
+{ return vceqq_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2l pcmp_eq<Packet2l>(const Packet2l& a, const Packet2l& b)
+{
+#if EIGEN_ARCH_ARM64
+ return vreinterpretq_s64_u64(vceqq_s64(a,b));
+#else
+ return vcombine_s64(
+ vdup_n_s64(vgetq_lane_s64(a, 0) == vgetq_lane_s64(b, 0) ? numext::int64_t(-1) : 0),
+ vdup_n_s64(vgetq_lane_s64(a, 1) == vgetq_lane_s64(b, 1) ? numext::int64_t(-1) : 0));
+#endif
+}
+template<> EIGEN_STRONG_INLINE Packet2ul pcmp_eq<Packet2ul>(const Packet2ul& a, const Packet2ul& b)
+{
+#if EIGEN_ARCH_ARM64
+ return vceqq_u64(a,b);
+#else
+ return vcombine_u64(
+ vdup_n_u64(vgetq_lane_u64(a, 0) == vgetq_lane_u64(b, 0) ? numext::uint64_t(-1) : 0),
+ vdup_n_u64(vgetq_lane_u64(a, 1) == vgetq_lane_u64(b, 1) ? numext::uint64_t(-1) : 0));
+#endif
+}
+
+template<> EIGEN_STRONG_INLINE Packet2f pcmp_lt_or_nan<Packet2f>(const Packet2f& a, const Packet2f& b)
+{ return vreinterpret_f32_u32(vmvn_u32(vcge_f32(a,b))); }
+template<> EIGEN_STRONG_INLINE Packet4f pcmp_lt_or_nan<Packet4f>(const Packet4f& a, const Packet4f& b)
+{ return vreinterpretq_f32_u32(vmvnq_u32(vcgeq_f32(a,b))); }
// Logical Operations are not supported for float, so we have to reinterpret casts using NEON intrinsics
+template<> EIGEN_STRONG_INLINE Packet2f pand<Packet2f>(const Packet2f& a, const Packet2f& b)
+{ return vreinterpret_f32_u32(vand_u32(vreinterpret_u32_f32(a),vreinterpret_u32_f32(b))); }
template<> EIGEN_STRONG_INLINE Packet4f pand<Packet4f>(const Packet4f& a, const Packet4f& b)
-{
- return vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(a),vreinterpretq_u32_f32(b)));
-}
+{ return vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(a),vreinterpretq_u32_f32(b))); }
+template<> EIGEN_STRONG_INLINE Packet4c pand<Packet4c>(const Packet4c& a, const Packet4c& b)
+{ return a & b; }
+template<> EIGEN_STRONG_INLINE Packet8c pand<Packet8c>(const Packet8c& a, const Packet8c& b)
+{ return vand_s8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16c pand<Packet16c>(const Packet16c& a, const Packet16c& b)
+{ return vandq_s8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4uc pand<Packet4uc>(const Packet4uc& a, const Packet4uc& b)
+{ return a & b; }
+template<> EIGEN_STRONG_INLINE Packet8uc pand<Packet8uc>(const Packet8uc& a, const Packet8uc& b)
+{ return vand_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16uc pand<Packet16uc>(const Packet16uc& a, const Packet16uc& b)
+{ return vandq_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4s pand<Packet4s>(const Packet4s& a, const Packet4s& b) { return vand_s16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8s pand<Packet8s>(const Packet8s& a, const Packet8s& b) { return vandq_s16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4us pand<Packet4us>(const Packet4us& a, const Packet4us& b)
+{ return vand_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8us pand<Packet8us>(const Packet8us& a, const Packet8us& b)
+{ return vandq_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2i pand<Packet2i>(const Packet2i& a, const Packet2i& b) { return vand_s32(a,b); }
template<> EIGEN_STRONG_INLINE Packet4i pand<Packet4i>(const Packet4i& a, const Packet4i& b) { return vandq_s32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2ui pand<Packet2ui>(const Packet2ui& a, const Packet2ui& b)
+{ return vand_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4ui pand<Packet4ui>(const Packet4ui& a, const Packet4ui& b)
+{ return vandq_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2l pand<Packet2l>(const Packet2l& a, const Packet2l& b) { return vandq_s64(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2ul pand<Packet2ul>(const Packet2ul& a, const Packet2ul& b)
+{ return vandq_u64(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2f por<Packet2f>(const Packet2f& a, const Packet2f& b)
+{ return vreinterpret_f32_u32(vorr_u32(vreinterpret_u32_f32(a),vreinterpret_u32_f32(b))); }
template<> EIGEN_STRONG_INLINE Packet4f por<Packet4f>(const Packet4f& a, const Packet4f& b)
-{
- return vreinterpretq_f32_u32(vorrq_u32(vreinterpretq_u32_f32(a),vreinterpretq_u32_f32(b)));
-}
+{ return vreinterpretq_f32_u32(vorrq_u32(vreinterpretq_u32_f32(a),vreinterpretq_u32_f32(b))); }
+template<> EIGEN_STRONG_INLINE Packet4c por<Packet4c>(const Packet4c& a, const Packet4c& b)
+{ return a | b; }
+template<> EIGEN_STRONG_INLINE Packet8c por<Packet8c>(const Packet8c& a, const Packet8c& b) { return vorr_s8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16c por<Packet16c>(const Packet16c& a, const Packet16c& b)
+{ return vorrq_s8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4uc por<Packet4uc>(const Packet4uc& a, const Packet4uc& b)
+{ return a | b; }
+template<> EIGEN_STRONG_INLINE Packet8uc por<Packet8uc>(const Packet8uc& a, const Packet8uc& b)
+{ return vorr_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16uc por<Packet16uc>(const Packet16uc& a, const Packet16uc& b)
+{ return vorrq_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4s por<Packet4s>(const Packet4s& a, const Packet4s& b)
+{ return vorr_s16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8s por<Packet8s>(const Packet8s& a, const Packet8s& b)
+{ return vorrq_s16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4us por<Packet4us>(const Packet4us& a, const Packet4us& b)
+{ return vorr_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8us por<Packet8us>(const Packet8us& a, const Packet8us& b)
+{ return vorrq_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2i por<Packet2i>(const Packet2i& a, const Packet2i& b) { return vorr_s32(a,b); }
template<> EIGEN_STRONG_INLINE Packet4i por<Packet4i>(const Packet4i& a, const Packet4i& b) { return vorrq_s32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2ui por<Packet2ui>(const Packet2ui& a, const Packet2ui& b)
+{ return vorr_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4ui por<Packet4ui>(const Packet4ui& a, const Packet4ui& b)
+{ return vorrq_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2l por<Packet2l>(const Packet2l& a, const Packet2l& b)
+{ return vorrq_s64(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2ul por<Packet2ul>(const Packet2ul& a, const Packet2ul& b)
+{ return vorrq_u64(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2f pxor<Packet2f>(const Packet2f& a, const Packet2f& b)
+{ return vreinterpret_f32_u32(veor_u32(vreinterpret_u32_f32(a),vreinterpret_u32_f32(b))); }
template<> EIGEN_STRONG_INLINE Packet4f pxor<Packet4f>(const Packet4f& a, const Packet4f& b)
-{
- return vreinterpretq_f32_u32(veorq_u32(vreinterpretq_u32_f32(a),vreinterpretq_u32_f32(b)));
-}
+{ return vreinterpretq_f32_u32(veorq_u32(vreinterpretq_u32_f32(a),vreinterpretq_u32_f32(b))); }
+template<> EIGEN_STRONG_INLINE Packet4c pxor<Packet4c>(const Packet4c& a, const Packet4c& b)
+{ return a ^ b; }
+template<> EIGEN_STRONG_INLINE Packet8c pxor<Packet8c>(const Packet8c& a, const Packet8c& b)
+{ return veor_s8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16c pxor<Packet16c>(const Packet16c& a, const Packet16c& b)
+{ return veorq_s8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4uc pxor<Packet4uc>(const Packet4uc& a, const Packet4uc& b)
+{ return a ^ b; }
+template<> EIGEN_STRONG_INLINE Packet8uc pxor<Packet8uc>(const Packet8uc& a, const Packet8uc& b)
+{ return veor_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16uc pxor<Packet16uc>(const Packet16uc& a, const Packet16uc& b)
+{ return veorq_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4s pxor<Packet4s>(const Packet4s& a, const Packet4s& b) { return veor_s16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8s pxor<Packet8s>(const Packet8s& a, const Packet8s& b) { return veorq_s16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4us pxor<Packet4us>(const Packet4us& a, const Packet4us& b)
+{ return veor_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8us pxor<Packet8us>(const Packet8us& a, const Packet8us& b)
+{ return veorq_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2i pxor<Packet2i>(const Packet2i& a, const Packet2i& b) { return veor_s32(a,b); }
template<> EIGEN_STRONG_INLINE Packet4i pxor<Packet4i>(const Packet4i& a, const Packet4i& b) { return veorq_s32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2ui pxor<Packet2ui>(const Packet2ui& a, const Packet2ui& b)
+{ return veor_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4ui pxor<Packet4ui>(const Packet4ui& a, const Packet4ui& b)
+{ return veorq_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2l pxor<Packet2l>(const Packet2l& a, const Packet2l& b)
+{ return veorq_s64(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2ul pxor<Packet2ul>(const Packet2ul& a, const Packet2ul& b)
+{ return veorq_u64(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2f pandnot<Packet2f>(const Packet2f& a, const Packet2f& b)
+{ return vreinterpret_f32_u32(vbic_u32(vreinterpret_u32_f32(a),vreinterpret_u32_f32(b))); } // a & ~b via NEON bit-clear on the raw float bits
template<> EIGEN_STRONG_INLINE Packet4f pandnot<Packet4f>(const Packet4f& a, const Packet4f& b)
+{ return vreinterpretq_f32_u32(vbicq_u32(vreinterpretq_u32_f32(a),vreinterpretq_u32_f32(b))); }
+template<> EIGEN_STRONG_INLINE Packet4c pandnot<Packet4c>(const Packet4c& a, const Packet4c& b)
+{ return a & ~b; } // NOTE(review): Packet4c is handled as a packed 32-bit value, so scalar & ~ is used
+template<> EIGEN_STRONG_INLINE Packet8c pandnot<Packet8c>(const Packet8c& a, const Packet8c& b) { return vbic_s8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16c pandnot<Packet16c>(const Packet16c& a, const Packet16c& b) { return vbicq_s8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4uc pandnot<Packet4uc>(const Packet4uc& a, const Packet4uc& b)
+{ return a & ~b; }
+template<> EIGEN_STRONG_INLINE Packet8uc pandnot<Packet8uc>(const Packet8uc& a, const Packet8uc& b)
+{ return vbic_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16uc pandnot<Packet16uc>(const Packet16uc& a, const Packet16uc& b)
+{ return vbicq_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4s pandnot<Packet4s>(const Packet4s& a, const Packet4s& b)
+{ return vbic_s16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8s pandnot<Packet8s>(const Packet8s& a, const Packet8s& b)
+{ return vbicq_s16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4us pandnot<Packet4us>(const Packet4us& a, const Packet4us& b)
+{ return vbic_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8us pandnot<Packet8us>(const Packet8us& a, const Packet8us& b)
+{ return vbicq_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2i pandnot<Packet2i>(const Packet2i& a, const Packet2i& b)
+{ return vbic_s32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4i pandnot<Packet4i>(const Packet4i& a, const Packet4i& b)
+{ return vbicq_s32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2ui pandnot<Packet2ui>(const Packet2ui& a, const Packet2ui& b)
+{ return vbic_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4ui pandnot<Packet4ui>(const Packet4ui& a, const Packet4ui& b)
+{ return vbicq_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2l pandnot<Packet2l>(const Packet2l& a, const Packet2l& b)
+{ return vbicq_s64(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2ul pandnot<Packet2ul>(const Packet2ul& a, const Packet2ul& b)
+{ return vbicq_u64(a,b); }
+
+
+template<int N> EIGEN_STRONG_INLINE Packet4c parithmetic_shift_right(Packet4c a) // by value like every other overload; a non-const ref rejected rvalues and const lvalues
+{ return vget_lane_s32(vreinterpret_s32_s8(vshr_n_s8(vreinterpret_s8_s32(vdup_n_s32(a)), N)), 0); } // splat the 4 packed bytes, shift each int8 lane, take lane 0
+template<int N> EIGEN_STRONG_INLINE Packet8c parithmetic_shift_right(Packet8c a) { return vshr_n_s8(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet16c parithmetic_shift_right(Packet16c a) { return vshrq_n_s8(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet4uc parithmetic_shift_right(Packet4uc a) // by value, consistent with the other overloads
+{ return vget_lane_u32(vreinterpret_u32_u8(vshr_n_u8(vreinterpret_u8_u32(vdup_n_u32(a)), N)), 0); } // unsigned lanes: same vshr_n_u8 as Packet8uc
+template<int N> EIGEN_STRONG_INLINE Packet8uc parithmetic_shift_right(Packet8uc a) { return vshr_n_u8(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet16uc parithmetic_shift_right(Packet16uc a) { return vshrq_n_u8(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet4s parithmetic_shift_right(Packet4s a) { return vshr_n_s16(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet8s parithmetic_shift_right(Packet8s a) { return vshrq_n_s16(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet4us parithmetic_shift_right(Packet4us a) { return vshr_n_u16(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet8us parithmetic_shift_right(Packet8us a) { return vshrq_n_u16(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet2i parithmetic_shift_right(Packet2i a) { return vshr_n_s32(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet4i parithmetic_shift_right(Packet4i a) { return vshrq_n_s32(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet2ui parithmetic_shift_right(Packet2ui a) { return vshr_n_u32(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet4ui parithmetic_shift_right(Packet4ui a) { return vshrq_n_u32(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet2l parithmetic_shift_right(Packet2l a) { return vshrq_n_s64(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet2ul parithmetic_shift_right(Packet2ul a) { return vshrq_n_u64(a,N); }
+
+template<int N> EIGEN_STRONG_INLINE Packet4c plogical_shift_right(Packet4c a) // by value like every other overload
+{ return vget_lane_s32(vreinterpret_s32_u8(vshr_n_u8(vreinterpret_u8_s32(vdup_n_s32(a)), N)), 0); } // shift the bytes as u8 so zeros are shifted in
+template<int N> EIGEN_STRONG_INLINE Packet8c plogical_shift_right(Packet8c a)
+{ return vreinterpret_s8_u8(vshr_n_u8(vreinterpret_u8_s8(a),N)); }
+template<int N> EIGEN_STRONG_INLINE Packet16c plogical_shift_right(Packet16c a)
+{ return vreinterpretq_s8_u8(vshrq_n_u8(vreinterpretq_u8_s8(a),N)); }
+template<int N> EIGEN_STRONG_INLINE Packet4uc plogical_shift_right(Packet4uc a) // by value like every other overload
+{ return vget_lane_u32(vreinterpret_u32_u8(vshr_n_u8(vreinterpret_u8_u32(vdup_n_u32(a)), N)), 0); } // must be unsigned vshr_n_u8: signed vshr_n_s8 would sign-fill bytes >= 0x80
+template<int N> EIGEN_STRONG_INLINE Packet8uc plogical_shift_right(Packet8uc a) { return vshr_n_u8(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet16uc plogical_shift_right(Packet16uc a) { return vshrq_n_u8(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet4s plogical_shift_right(Packet4s a)
+{ return vreinterpret_s16_u16(vshr_n_u16(vreinterpret_u16_s16(a),N)); }
+template<int N> EIGEN_STRONG_INLINE Packet8s plogical_shift_right(Packet8s a)
+{ return vreinterpretq_s16_u16(vshrq_n_u16(vreinterpretq_u16_s16(a),N)); }
+template<int N> EIGEN_STRONG_INLINE Packet4us plogical_shift_right(Packet4us a) { return vshr_n_u16(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet8us plogical_shift_right(Packet8us a) { return vshrq_n_u16(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet2i plogical_shift_right(Packet2i a)
+{ return vreinterpret_s32_u32(vshr_n_u32(vreinterpret_u32_s32(a),N)); }
+template<int N> EIGEN_STRONG_INLINE Packet4i plogical_shift_right(Packet4i a)
+{ return vreinterpretq_s32_u32(vshrq_n_u32(vreinterpretq_u32_s32(a),N)); }
+template<int N> EIGEN_STRONG_INLINE Packet2ui plogical_shift_right(Packet2ui a) { return vshr_n_u32(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet4ui plogical_shift_right(Packet4ui a) { return vshrq_n_u32(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet2l plogical_shift_right(Packet2l a)
+{ return vreinterpretq_s64_u64(vshrq_n_u64(vreinterpretq_u64_s64(a),N)); }
+template<int N> EIGEN_STRONG_INLINE Packet2ul plogical_shift_right(Packet2ul a) { return vshrq_n_u64(a,N); }
+
+template<int N> EIGEN_STRONG_INLINE Packet4c plogical_shift_left(Packet4c a) // by value like every other overload; a non-const ref rejected rvalues and const lvalues
+{ return vget_lane_s32(vreinterpret_s32_s8(vshl_n_s8(vreinterpret_s8_s32(vdup_n_s32(a)), N)), 0); } // splat the 4 packed bytes, shift each lane left, take lane 0
+template<int N> EIGEN_STRONG_INLINE Packet8c plogical_shift_left(Packet8c a) { return vshl_n_s8(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet16c plogical_shift_left(Packet16c a) { return vshlq_n_s8(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet4uc plogical_shift_left(Packet4uc a) // by value, consistent with the other overloads
+{ return vget_lane_u32(vreinterpret_u32_u8(vshl_n_u8(vreinterpret_u8_u32(vdup_n_u32(a)), N)), 0); }
+template<int N> EIGEN_STRONG_INLINE Packet8uc plogical_shift_left(Packet8uc a) { return vshl_n_u8(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet16uc plogical_shift_left(Packet16uc a) { return vshlq_n_u8(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet4s plogical_shift_left(Packet4s a) { return vshl_n_s16(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet8s plogical_shift_left(Packet8s a) { return vshlq_n_s16(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet4us plogical_shift_left(Packet4us a) { return vshl_n_u16(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet8us plogical_shift_left(Packet8us a) { return vshlq_n_u16(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet2i plogical_shift_left(Packet2i a) { return vshl_n_s32(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet4i plogical_shift_left(Packet4i a) { return vshlq_n_s32(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet2ui plogical_shift_left(Packet2ui a) { return vshl_n_u32(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet4ui plogical_shift_left(Packet4ui a) { return vshlq_n_u32(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet2l plogical_shift_left(Packet2l a) { return vshlq_n_s64(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet2ul plogical_shift_left(Packet2ul a) { return vshlq_n_u64(a,N); }
+
+template<> EIGEN_STRONG_INLINE Packet2f pload<Packet2f>(const float* from)
+{ EIGEN_DEBUG_ALIGNED_LOAD return vld1_f32(from); }
+template<> EIGEN_STRONG_INLINE Packet4f pload<Packet4f>(const float* from)
+{ EIGEN_DEBUG_ALIGNED_LOAD return vld1q_f32(from); }
+template<> EIGEN_STRONG_INLINE Packet4c pload<Packet4c>(const int8_t* from)
{
- return vreinterpretq_f32_u32(vbicq_u32(vreinterpretq_u32_f32(a),vreinterpretq_u32_f32(b)));
+ Packet4c res;
+ memcpy(&res, from, sizeof(Packet4c)); // 4 packed bytes copied with memcpy: no alignment/aliasing assumption on from
+ return res;
}
-template<> EIGEN_STRONG_INLINE Packet4i pandnot<Packet4i>(const Packet4i& a, const Packet4i& b) { return vbicq_s32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8c pload<Packet8c>(const int8_t* from)
+{ EIGEN_DEBUG_ALIGNED_LOAD return vld1_s8(from); }
+template<> EIGEN_STRONG_INLINE Packet16c pload<Packet16c>(const int8_t* from)
+{ EIGEN_DEBUG_ALIGNED_LOAD return vld1q_s8(from); }
+template<> EIGEN_STRONG_INLINE Packet4uc pload<Packet4uc>(const uint8_t* from)
+{
+ Packet4uc res;
+ memcpy(&res, from, sizeof(Packet4uc)); // same memcpy approach as Packet4c
+ return res;
+}
+template<> EIGEN_STRONG_INLINE Packet8uc pload<Packet8uc>(const uint8_t* from)
+{ EIGEN_DEBUG_ALIGNED_LOAD return vld1_u8(from); }
+template<> EIGEN_STRONG_INLINE Packet16uc pload<Packet16uc>(const uint8_t* from)
+{ EIGEN_DEBUG_ALIGNED_LOAD return vld1q_u8(from); }
+template<> EIGEN_STRONG_INLINE Packet4s pload<Packet4s>(const int16_t* from)
+{ EIGEN_DEBUG_ALIGNED_LOAD return vld1_s16(from); }
+template<> EIGEN_STRONG_INLINE Packet8s pload<Packet8s>(const int16_t* from)
+{ EIGEN_DEBUG_ALIGNED_LOAD return vld1q_s16(from); }
+template<> EIGEN_STRONG_INLINE Packet4us pload<Packet4us>(const uint16_t* from)
+{ EIGEN_DEBUG_ALIGNED_LOAD return vld1_u16(from); }
+template<> EIGEN_STRONG_INLINE Packet8us pload<Packet8us>(const uint16_t* from)
+{ EIGEN_DEBUG_ALIGNED_LOAD return vld1q_u16(from); }
+template<> EIGEN_STRONG_INLINE Packet2i pload<Packet2i>(const int32_t* from)
+{ EIGEN_DEBUG_ALIGNED_LOAD return vld1_s32(from); }
+template<> EIGEN_STRONG_INLINE Packet4i pload<Packet4i>(const int32_t* from)
+{ EIGEN_DEBUG_ALIGNED_LOAD return vld1q_s32(from); }
+template<> EIGEN_STRONG_INLINE Packet2ui pload<Packet2ui>(const uint32_t* from)
+{ EIGEN_DEBUG_ALIGNED_LOAD return vld1_u32(from); }
+template<> EIGEN_STRONG_INLINE Packet4ui pload<Packet4ui>(const uint32_t* from)
+{ EIGEN_DEBUG_ALIGNED_LOAD return vld1q_u32(from); }
+template<> EIGEN_STRONG_INLINE Packet2l pload<Packet2l>(const int64_t* from)
+{ EIGEN_DEBUG_ALIGNED_LOAD return vld1q_s64(from); }
+template<> EIGEN_STRONG_INLINE Packet2ul pload<Packet2ul>(const uint64_t* from)
+{ EIGEN_DEBUG_ALIGNED_LOAD return vld1q_u64(from); }
-template<> EIGEN_STRONG_INLINE Packet4f pload<Packet4f>(const float* from) { EIGEN_DEBUG_ALIGNED_LOAD return vld1q_f32(from); }
-template<> EIGEN_STRONG_INLINE Packet4i pload<Packet4i>(const int32_t* from) { EIGEN_DEBUG_ALIGNED_LOAD return vld1q_s32(from); }
+template<> EIGEN_STRONG_INLINE Packet2f ploadu<Packet2f>(const float* from)
+{ EIGEN_DEBUG_UNALIGNED_LOAD return vld1_f32(from); } // same vld1 instructions as pload; only the debug hook differs
+template<> EIGEN_STRONG_INLINE Packet4f ploadu<Packet4f>(const float* from)
+{ EIGEN_DEBUG_UNALIGNED_LOAD return vld1q_f32(from); }
+template<> EIGEN_STRONG_INLINE Packet4c ploadu<Packet4c>(const int8_t* from)
+{
+ Packet4c res;
+ memcpy(&res, from, sizeof(Packet4c)); // 4 packed bytes copied with memcpy, safe for any alignment
+ return res;
+}
+template<> EIGEN_STRONG_INLINE Packet8c ploadu<Packet8c>(const int8_t* from)
+{ EIGEN_DEBUG_UNALIGNED_LOAD return vld1_s8(from); }
+template<> EIGEN_STRONG_INLINE Packet16c ploadu<Packet16c>(const int8_t* from)
+{ EIGEN_DEBUG_UNALIGNED_LOAD return vld1q_s8(from); }
+template<> EIGEN_STRONG_INLINE Packet4uc ploadu<Packet4uc>(const uint8_t* from)
+{
+ Packet4uc res;
+ memcpy(&res, from, sizeof(Packet4uc));
+ return res;
+}
+template<> EIGEN_STRONG_INLINE Packet8uc ploadu<Packet8uc>(const uint8_t* from)
+{ EIGEN_DEBUG_UNALIGNED_LOAD return vld1_u8(from); }
+template<> EIGEN_STRONG_INLINE Packet16uc ploadu<Packet16uc>(const uint8_t* from)
+{ EIGEN_DEBUG_UNALIGNED_LOAD return vld1q_u8(from); }
+template<> EIGEN_STRONG_INLINE Packet4s ploadu<Packet4s>(const int16_t* from)
+{ EIGEN_DEBUG_UNALIGNED_LOAD return vld1_s16(from); }
+template<> EIGEN_STRONG_INLINE Packet8s ploadu<Packet8s>(const int16_t* from)
+{ EIGEN_DEBUG_UNALIGNED_LOAD return vld1q_s16(from); }
+template<> EIGEN_STRONG_INLINE Packet4us ploadu<Packet4us>(const uint16_t* from)
+{ EIGEN_DEBUG_UNALIGNED_LOAD return vld1_u16(from); }
+template<> EIGEN_STRONG_INLINE Packet8us ploadu<Packet8us>(const uint16_t* from)
+{ EIGEN_DEBUG_UNALIGNED_LOAD return vld1q_u16(from); }
+template<> EIGEN_STRONG_INLINE Packet2i ploadu<Packet2i>(const int32_t* from)
+{ EIGEN_DEBUG_UNALIGNED_LOAD return vld1_s32(from); }
+template<> EIGEN_STRONG_INLINE Packet4i ploadu<Packet4i>(const int32_t* from)
+{ EIGEN_DEBUG_UNALIGNED_LOAD return vld1q_s32(from); }
+template<> EIGEN_STRONG_INLINE Packet2ui ploadu<Packet2ui>(const uint32_t* from)
+{ EIGEN_DEBUG_UNALIGNED_LOAD return vld1_u32(from); }
+template<> EIGEN_STRONG_INLINE Packet4ui ploadu<Packet4ui>(const uint32_t* from)
+{ EIGEN_DEBUG_UNALIGNED_LOAD return vld1q_u32(from); }
+template<> EIGEN_STRONG_INLINE Packet2l ploadu<Packet2l>(const int64_t* from)
+{ EIGEN_DEBUG_UNALIGNED_LOAD return vld1q_s64(from); }
+template<> EIGEN_STRONG_INLINE Packet2ul ploadu<Packet2ul>(const uint64_t* from)
+{ EIGEN_DEBUG_UNALIGNED_LOAD return vld1q_u64(from); }
-template<> EIGEN_STRONG_INLINE Packet4f ploadu<Packet4f>(const float* from) { EIGEN_DEBUG_UNALIGNED_LOAD return vld1q_f32(from); }
-template<> EIGEN_STRONG_INLINE Packet4i ploadu<Packet4i>(const int32_t* from) { EIGEN_DEBUG_UNALIGNED_LOAD return vld1q_s32(from); }
-
+template<> EIGEN_STRONG_INLINE Packet2f ploaddup<Packet2f>(const float* from)
+{ return vld1_dup_f32(from); } // {from[0], from[0]}
template<> EIGEN_STRONG_INLINE Packet4f ploaddup<Packet4f>(const float* from)
+{ return vcombine_f32(vld1_dup_f32(from), vld1_dup_f32(from+1)); } // {f0, f0, f1, f1}
+template<> EIGEN_STRONG_INLINE Packet4c ploaddup<Packet4c>(const int8_t* from)
{
- float32x2_t lo, hi;
- lo = vld1_dup_f32(from);
- hi = vld1_dup_f32(from+1);
- return vcombine_f32(lo, hi);
+ const int8x8_t a = vreinterpret_s8_s32(vdup_n_s32(pload<Packet4c>(from)));
+ return vget_lane_s32(vreinterpret_s32_s8(vzip_s8(a,a).val[0]), 0); // zip with itself gives a0 a0 a1 a1 ...; lane 0 keeps a0 a0 a1 a1
}
+template<> EIGEN_STRONG_INLINE Packet8c ploaddup<Packet8c>(const int8_t* from)
+{
+ const int8x8_t a = vld1_s8(from);
+ return vzip_s8(a,a).val[0]; // a0 a0 a1 a1 a2 a2 a3 a3
+}
+template<> EIGEN_STRONG_INLINE Packet16c ploaddup<Packet16c>(const int8_t* from)
+{
+ const int8x8_t a = vld1_s8(from);
+ const int8x8x2_t b = vzip_s8(a,a);
+ return vcombine_s8(b.val[0], b.val[1]); // val[0]/val[1] are the duplicated low/high halves of a
+}
+template<> EIGEN_STRONG_INLINE Packet4uc ploaddup<Packet4uc>(const uint8_t* from)
+{
+ const uint8x8_t a = vreinterpret_u8_u32(vdup_n_u32(pload<Packet4uc>(from)));
+ return vget_lane_u32(vreinterpret_u32_u8(vzip_u8(a,a).val[0]), 0);
+}
+template<> EIGEN_STRONG_INLINE Packet8uc ploaddup<Packet8uc>(const uint8_t* from)
+{
+ const uint8x8_t a = vld1_u8(from);
+ return vzip_u8(a,a).val[0];
+}
+template<> EIGEN_STRONG_INLINE Packet16uc ploaddup<Packet16uc>(const uint8_t* from)
+{
+ const uint8x8_t a = vld1_u8(from);
+ const uint8x8x2_t b = vzip_u8(a,a);
+ return vcombine_u8(b.val[0], b.val[1]);
+}
+template<> EIGEN_STRONG_INLINE Packet4s ploaddup<Packet4s>(const int16_t* from)
+{
+ return vreinterpret_s16_u32(vzip_u32(vreinterpret_u32_s16(vld1_dup_s16(from)), // zipping 32-bit lanes of two splats yields {f0, f0, f1, f1}
+ vreinterpret_u32_s16(vld1_dup_s16(from+1))).val[0]);
+}
+template<> EIGEN_STRONG_INLINE Packet8s ploaddup<Packet8s>(const int16_t* from)
+{
+ const int16x4_t a = vld1_s16(from);
+ const int16x4x2_t b = vzip_s16(a,a);
+ return vcombine_s16(b.val[0], b.val[1]);
+}
+template<> EIGEN_STRONG_INLINE Packet4us ploaddup<Packet4us>(const uint16_t* from)
+{
+ return vreinterpret_u16_u32(vzip_u32(vreinterpret_u32_u16(vld1_dup_u16(from)),
+ vreinterpret_u32_u16(vld1_dup_u16(from+1))).val[0]);
+}
+template<> EIGEN_STRONG_INLINE Packet8us ploaddup<Packet8us>(const uint16_t* from)
+{
+ const uint16x4_t a = vld1_u16(from);
+ const uint16x4x2_t b = vzip_u16(a,a);
+ return vcombine_u16(b.val[0], b.val[1]);
+}
+template<> EIGEN_STRONG_INLINE Packet2i ploaddup<Packet2i>(const int32_t* from)
+{ return vld1_dup_s32(from); }
template<> EIGEN_STRONG_INLINE Packet4i ploaddup<Packet4i>(const int32_t* from)
+{ return vcombine_s32(vld1_dup_s32(from), vld1_dup_s32(from+1)); }
+template<> EIGEN_STRONG_INLINE Packet2ui ploaddup<Packet2ui>(const uint32_t* from)
+{ return vld1_dup_u32(from); }
+template<> EIGEN_STRONG_INLINE Packet4ui ploaddup<Packet4ui>(const uint32_t* from)
+{ return vcombine_u32(vld1_dup_u32(from), vld1_dup_u32(from+1)); }
+template<> EIGEN_STRONG_INLINE Packet2l ploaddup<Packet2l>(const int64_t* from)
+{ return vld1q_dup_s64(from); } // 2 lanes: duplicating from[0] fills the whole packet
+template<> EIGEN_STRONG_INLINE Packet2ul ploaddup<Packet2ul>(const uint64_t* from)
+{ return vld1q_dup_u64(from); }
+
+template<> EIGEN_STRONG_INLINE Packet4f ploadquad<Packet4f>(const float* from) { return vld1q_dup_f32(from); } // 4 lanes: all from[0]
+template<> EIGEN_STRONG_INLINE Packet4c ploadquad<Packet4c>(const int8_t* from)
+{ return vget_lane_s32(vreinterpret_s32_s8(vld1_dup_s8(from)), 0); } // four copies of from[0], packed into 32 bits
+template<> EIGEN_STRONG_INLINE Packet8c ploadquad<Packet8c>(const int8_t* from)
{
- int32x2_t lo, hi;
- lo = vld1_dup_s32(from);
- hi = vld1_dup_s32(from+1);
- return vcombine_s32(lo, hi);
+ return vreinterpret_s8_u32(vzip_u32( // from[0] x4 followed by from[1] x4
+ vreinterpret_u32_s8(vld1_dup_s8(from)),
+ vreinterpret_u32_s8(vld1_dup_s8(from+1))).val[0]);
}
-
-template<> EIGEN_STRONG_INLINE void pstore<float> (float* to, const Packet4f& from) { EIGEN_DEBUG_ALIGNED_STORE vst1q_f32(to, from); }
-template<> EIGEN_STRONG_INLINE void pstore<int32_t>(int32_t* to, const Packet4i& from) { EIGEN_DEBUG_ALIGNED_STORE vst1q_s32(to, from); }
-
-template<> EIGEN_STRONG_INLINE void pstoreu<float> (float* to, const Packet4f& from) { EIGEN_DEBUG_UNALIGNED_STORE vst1q_f32(to, from); }
-template<> EIGEN_STRONG_INLINE void pstoreu<int32_t>(int32_t* to, const Packet4i& from) { EIGEN_DEBUG_UNALIGNED_STORE vst1q_s32(to, from); }
-
-template<> EIGEN_DEVICE_FUNC inline Packet4f pgather<float, Packet4f>(const float* from, Index stride)
+template<> EIGEN_STRONG_INLINE Packet16c ploadquad<Packet16c>(const int8_t* from)
{
- Packet4f res = pset1<Packet4f>(0.f);
- res = vsetq_lane_f32(from[0*stride], res, 0);
- res = vsetq_lane_f32(from[1*stride], res, 1);
- res = vsetq_lane_f32(from[2*stride], res, 2);
- res = vsetq_lane_f32(from[3*stride], res, 3);
+ const int8x8_t a = vreinterpret_s8_u32(vzip_u32(
+ vreinterpret_u32_s8(vld1_dup_s8(from)),
+ vreinterpret_u32_s8(vld1_dup_s8(from+1))).val[0]);
+ const int8x8_t b = vreinterpret_s8_u32(vzip_u32(
+ vreinterpret_u32_s8(vld1_dup_s8(from+2)),
+ vreinterpret_u32_s8(vld1_dup_s8(from+3))).val[0]);
+ return vcombine_s8(a,b); // from[0..3], each repeated 4 times
+}
+template<> EIGEN_STRONG_INLINE Packet4uc ploadquad<Packet4uc>(const uint8_t* from)
+{ return vget_lane_u32(vreinterpret_u32_u8(vld1_dup_u8(from)), 0); }
+template<> EIGEN_STRONG_INLINE Packet8uc ploadquad<Packet8uc>(const uint8_t* from)
+{
+ return vreinterpret_u8_u32(vzip_u32(
+ vreinterpret_u32_u8(vld1_dup_u8(from)),
+ vreinterpret_u32_u8(vld1_dup_u8(from+1))).val[0]);
+}
+template<> EIGEN_STRONG_INLINE Packet16uc ploadquad<Packet16uc>(const uint8_t* from)
+{
+ const uint8x8_t a = vreinterpret_u8_u32(vzip_u32(
+ vreinterpret_u32_u8(vld1_dup_u8(from)),
+ vreinterpret_u32_u8(vld1_dup_u8(from+1))).val[0]);
+ const uint8x8_t b = vreinterpret_u8_u32(vzip_u32(
+ vreinterpret_u32_u8(vld1_dup_u8(from+2)),
+ vreinterpret_u32_u8(vld1_dup_u8(from+3))).val[0]);
+ return vcombine_u8(a,b);
+}
+template<> EIGEN_STRONG_INLINE Packet8s ploadquad<Packet8s>(const int16_t* from)
+{ return vcombine_s16(vld1_dup_s16(from), vld1_dup_s16(from+1)); } // {f0 x4, f1 x4}
+template<> EIGEN_STRONG_INLINE Packet8us ploadquad<Packet8us>(const uint16_t* from)
+{ return vcombine_u16(vld1_dup_u16(from), vld1_dup_u16(from+1)); }
+template<> EIGEN_STRONG_INLINE Packet4i ploadquad<Packet4i>(const int32_t* from) { return vld1q_dup_s32(from); }
+template<> EIGEN_STRONG_INLINE Packet4ui ploadquad<Packet4ui>(const uint32_t* from) { return vld1q_dup_u32(from); }
+
+template<> EIGEN_STRONG_INLINE void pstore<float>(float* to, const Packet2f& from)
+{ EIGEN_DEBUG_ALIGNED_STORE vst1_f32(to,from); }
+template<> EIGEN_STRONG_INLINE void pstore<float>(float* to, const Packet4f& from)
+{ EIGEN_DEBUG_ALIGNED_STORE vst1q_f32(to,from); }
+template<> EIGEN_STRONG_INLINE void pstore<int8_t>(int8_t* to, const Packet4c& from)
+{ memcpy(to, &from, sizeof(from)); } // Packet4c is written out as its raw bytes
+template<> EIGEN_STRONG_INLINE void pstore<int8_t>(int8_t* to, const Packet8c& from)
+{ EIGEN_DEBUG_ALIGNED_STORE vst1_s8(to,from); }
+template<> EIGEN_STRONG_INLINE void pstore<int8_t>(int8_t* to, const Packet16c& from)
+{ EIGEN_DEBUG_ALIGNED_STORE vst1q_s8(to,from); }
+template<> EIGEN_STRONG_INLINE void pstore<uint8_t>(uint8_t* to, const Packet4uc& from)
+{ memcpy(to, &from, sizeof(from)); } // Packet4uc is written out as its raw bytes
+template<> EIGEN_STRONG_INLINE void pstore<uint8_t>(uint8_t* to, const Packet8uc& from)
+{ EIGEN_DEBUG_ALIGNED_STORE vst1_u8(to,from); }
+template<> EIGEN_STRONG_INLINE void pstore<uint8_t>(uint8_t* to, const Packet16uc& from)
+{ EIGEN_DEBUG_ALIGNED_STORE vst1q_u8(to,from); }
+template<> EIGEN_STRONG_INLINE void pstore<int16_t>(int16_t* to, const Packet4s& from)
+{ EIGEN_DEBUG_ALIGNED_STORE vst1_s16(to,from); }
+template<> EIGEN_STRONG_INLINE void pstore<int16_t>(int16_t* to, const Packet8s& from)
+{ EIGEN_DEBUG_ALIGNED_STORE vst1q_s16(to,from); }
+template<> EIGEN_STRONG_INLINE void pstore<uint16_t>(uint16_t* to, const Packet4us& from)
+{ EIGEN_DEBUG_ALIGNED_STORE vst1_u16(to,from); }
+template<> EIGEN_STRONG_INLINE void pstore<uint16_t>(uint16_t* to, const Packet8us& from)
+{ EIGEN_DEBUG_ALIGNED_STORE vst1q_u16(to,from); }
+template<> EIGEN_STRONG_INLINE void pstore<int32_t>(int32_t* to, const Packet2i& from)
+{ EIGEN_DEBUG_ALIGNED_STORE vst1_s32(to,from); }
+template<> EIGEN_STRONG_INLINE void pstore<int32_t>(int32_t* to, const Packet4i& from)
+{ EIGEN_DEBUG_ALIGNED_STORE vst1q_s32(to,from); }
+template<> EIGEN_STRONG_INLINE void pstore<uint32_t>(uint32_t* to, const Packet2ui& from)
+{ EIGEN_DEBUG_ALIGNED_STORE vst1_u32(to,from); }
+template<> EIGEN_STRONG_INLINE void pstore<uint32_t>(uint32_t* to, const Packet4ui& from)
+{ EIGEN_DEBUG_ALIGNED_STORE vst1q_u32(to,from); }
+template<> EIGEN_STRONG_INLINE void pstore<int64_t>(int64_t* to, const Packet2l& from)
+{ EIGEN_DEBUG_ALIGNED_STORE vst1q_s64(to,from); }
+template<> EIGEN_STRONG_INLINE void pstore<uint64_t>(uint64_t* to, const Packet2ul& from)
+{ EIGEN_DEBUG_ALIGNED_STORE vst1q_u64(to,from); }
+
+template<> EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const Packet2f& from)
+{ EIGEN_DEBUG_UNALIGNED_STORE vst1_f32(to,from); } // same vst1 instructions as pstore; only the debug hook differs
+template<> EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const Packet4f& from)
+{ EIGEN_DEBUG_UNALIGNED_STORE vst1q_f32(to,from); }
+template<> EIGEN_STRONG_INLINE void pstoreu<int8_t>(int8_t* to, const Packet4c& from)
+{ memcpy(to, &from, sizeof(from)); } // byte copy, safe for any alignment of to
+template<> EIGEN_STRONG_INLINE void pstoreu<int8_t>(int8_t* to, const Packet8c& from)
+{ EIGEN_DEBUG_UNALIGNED_STORE vst1_s8(to,from); }
+template<> EIGEN_STRONG_INLINE void pstoreu<int8_t>(int8_t* to, const Packet16c& from)
+{ EIGEN_DEBUG_UNALIGNED_STORE vst1q_s8(to,from); }
+template<> EIGEN_STRONG_INLINE void pstoreu<uint8_t>(uint8_t* to, const Packet4uc& from)
+{ memcpy(to, &from, sizeof(from)); }
+template<> EIGEN_STRONG_INLINE void pstoreu<uint8_t>(uint8_t* to, const Packet8uc& from)
+{ EIGEN_DEBUG_UNALIGNED_STORE vst1_u8(to,from); }
+template<> EIGEN_STRONG_INLINE void pstoreu<uint8_t>(uint8_t* to, const Packet16uc& from)
+{ EIGEN_DEBUG_UNALIGNED_STORE vst1q_u8(to,from); }
+template<> EIGEN_STRONG_INLINE void pstoreu<int16_t>(int16_t* to, const Packet4s& from)
+{ EIGEN_DEBUG_UNALIGNED_STORE vst1_s16(to,from); }
+template<> EIGEN_STRONG_INLINE void pstoreu<int16_t>(int16_t* to, const Packet8s& from)
+{ EIGEN_DEBUG_UNALIGNED_STORE vst1q_s16(to,from); }
+template<> EIGEN_STRONG_INLINE void pstoreu<uint16_t>(uint16_t* to, const Packet4us& from)
+{ EIGEN_DEBUG_UNALIGNED_STORE vst1_u16(to,from); }
+template<> EIGEN_STRONG_INLINE void pstoreu<uint16_t>(uint16_t* to, const Packet8us& from)
+{ EIGEN_DEBUG_UNALIGNED_STORE vst1q_u16(to,from); }
+template<> EIGEN_STRONG_INLINE void pstoreu<int32_t>(int32_t* to, const Packet2i& from)
+{ EIGEN_DEBUG_UNALIGNED_STORE vst1_s32(to,from); }
+template<> EIGEN_STRONG_INLINE void pstoreu<int32_t>(int32_t* to, const Packet4i& from)
+{ EIGEN_DEBUG_UNALIGNED_STORE vst1q_s32(to,from); }
+template<> EIGEN_STRONG_INLINE void pstoreu<uint32_t>(uint32_t* to, const Packet2ui& from)
+{ EIGEN_DEBUG_UNALIGNED_STORE vst1_u32(to,from); }
+template<> EIGEN_STRONG_INLINE void pstoreu<uint32_t>(uint32_t* to, const Packet4ui& from)
+{ EIGEN_DEBUG_UNALIGNED_STORE vst1q_u32(to,from); }
+template<> EIGEN_STRONG_INLINE void pstoreu<int64_t>(int64_t* to, const Packet2l& from)
+{ EIGEN_DEBUG_UNALIGNED_STORE vst1q_s64(to,from); }
+template<> EIGEN_STRONG_INLINE void pstoreu<uint64_t>(uint64_t* to, const Packet2ul& from)
+{ EIGEN_DEBUG_UNALIGNED_STORE vst1q_u64(to,from); }
+
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2f pgather<float, Packet2f>(const float* from, Index stride)
+{
+ Packet2f res = vld1_dup_f32(from); // the dup load initialises every lane (lane 0 = from[0]); lanes 1.. are then overwritten with strided loads
+ res = vld1_lane_f32(from + 1*stride, res, 1);
return res;
}
-template<> EIGEN_DEVICE_FUNC inline Packet4i pgather<int32_t, Packet4i>(const int32_t* from, Index stride)
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4f pgather<float, Packet4f>(const float* from, Index stride)
{
- Packet4i res = pset1<Packet4i>(0);
- res = vsetq_lane_s32(from[0*stride], res, 0);
- res = vsetq_lane_s32(from[1*stride], res, 1);
- res = vsetq_lane_s32(from[2*stride], res, 2);
- res = vsetq_lane_s32(from[3*stride], res, 3);
+ Packet4f res = vld1q_dup_f32(from);
+ res = vld1q_lane_f32(from + 1*stride, res, 1);
+ res = vld1q_lane_f32(from + 2*stride, res, 2);
+ res = vld1q_lane_f32(from + 3*stride, res, 3);
+ return res;
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4c pgather<int8_t, Packet4c>(const int8_t* from, Index stride)
+{
+ Packet4c res;
+ for (int i = 0; i != 4; i++)
+ reinterpret_cast<int8_t*>(&res)[i] = *(from + i * stride); // fill the 4 packed bytes one at a time through a char-typed pointer
+ return res;
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet8c pgather<int8_t, Packet8c>(const int8_t* from, Index stride)
+{
+ Packet8c res = vld1_dup_s8(from);
+ res = vld1_lane_s8(from + 1*stride, res, 1);
+ res = vld1_lane_s8(from + 2*stride, res, 2);
+ res = vld1_lane_s8(from + 3*stride, res, 3);
+ res = vld1_lane_s8(from + 4*stride, res, 4);
+ res = vld1_lane_s8(from + 5*stride, res, 5);
+ res = vld1_lane_s8(from + 6*stride, res, 6);
+ res = vld1_lane_s8(from + 7*stride, res, 7);
+ return res;
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet16c pgather<int8_t, Packet16c>(const int8_t* from, Index stride)
+{
+ Packet16c res = vld1q_dup_s8(from);
+ res = vld1q_lane_s8(from + 1*stride, res, 1);
+ res = vld1q_lane_s8(from + 2*stride, res, 2);
+ res = vld1q_lane_s8(from + 3*stride, res, 3);
+ res = vld1q_lane_s8(from + 4*stride, res, 4);
+ res = vld1q_lane_s8(from + 5*stride, res, 5);
+ res = vld1q_lane_s8(from + 6*stride, res, 6);
+ res = vld1q_lane_s8(from + 7*stride, res, 7);
+ res = vld1q_lane_s8(from + 8*stride, res, 8);
+ res = vld1q_lane_s8(from + 9*stride, res, 9);
+ res = vld1q_lane_s8(from + 10*stride, res, 10);
+ res = vld1q_lane_s8(from + 11*stride, res, 11);
+ res = vld1q_lane_s8(from + 12*stride, res, 12);
+ res = vld1q_lane_s8(from + 13*stride, res, 13);
+ res = vld1q_lane_s8(from + 14*stride, res, 14);
+ res = vld1q_lane_s8(from + 15*stride, res, 15);
+ return res;
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4uc pgather<uint8_t, Packet4uc>(const uint8_t* from, Index stride)
+{
+ Packet4uc res;
+ for (int i = 0; i != 4; i++)
+ reinterpret_cast<uint8_t*>(&res)[i] = *(from + i * stride);
+ return res;
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet8uc pgather<uint8_t, Packet8uc>(const uint8_t* from, Index stride)
+{
+ Packet8uc res = vld1_dup_u8(from);
+ res = vld1_lane_u8(from + 1*stride, res, 1);
+ res = vld1_lane_u8(from + 2*stride, res, 2);
+ res = vld1_lane_u8(from + 3*stride, res, 3);
+ res = vld1_lane_u8(from + 4*stride, res, 4);
+ res = vld1_lane_u8(from + 5*stride, res, 5);
+ res = vld1_lane_u8(from + 6*stride, res, 6);
+ res = vld1_lane_u8(from + 7*stride, res, 7);
+ return res;
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet16uc pgather<uint8_t, Packet16uc>(const uint8_t* from, Index stride)
+{
+ Packet16uc res = vld1q_dup_u8(from);
+ res = vld1q_lane_u8(from + 1*stride, res, 1);
+ res = vld1q_lane_u8(from + 2*stride, res, 2);
+ res = vld1q_lane_u8(from + 3*stride, res, 3);
+ res = vld1q_lane_u8(from + 4*stride, res, 4);
+ res = vld1q_lane_u8(from + 5*stride, res, 5);
+ res = vld1q_lane_u8(from + 6*stride, res, 6);
+ res = vld1q_lane_u8(from + 7*stride, res, 7);
+ res = vld1q_lane_u8(from + 8*stride, res, 8);
+ res = vld1q_lane_u8(from + 9*stride, res, 9);
+ res = vld1q_lane_u8(from + 10*stride, res, 10);
+ res = vld1q_lane_u8(from + 11*stride, res, 11);
+ res = vld1q_lane_u8(from + 12*stride, res, 12);
+ res = vld1q_lane_u8(from + 13*stride, res, 13);
+ res = vld1q_lane_u8(from + 14*stride, res, 14);
+ res = vld1q_lane_u8(from + 15*stride, res, 15);
+ return res;
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4s pgather<int16_t, Packet4s>(const int16_t* from, Index stride)
+{
+ Packet4s res = vld1_dup_s16(from);
+ res = vld1_lane_s16(from + 1*stride, res, 1);
+ res = vld1_lane_s16(from + 2*stride, res, 2);
+ res = vld1_lane_s16(from + 3*stride, res, 3);
+ return res;
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet8s pgather<int16_t, Packet8s>(const int16_t* from, Index stride)
+{
+ Packet8s res = vld1q_dup_s16(from);
+ res = vld1q_lane_s16(from + 1*stride, res, 1);
+ res = vld1q_lane_s16(from + 2*stride, res, 2);
+ res = vld1q_lane_s16(from + 3*stride, res, 3);
+ res = vld1q_lane_s16(from + 4*stride, res, 4);
+ res = vld1q_lane_s16(from + 5*stride, res, 5);
+ res = vld1q_lane_s16(from + 6*stride, res, 6);
+ res = vld1q_lane_s16(from + 7*stride, res, 7);
+ return res;
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4us pgather<uint16_t, Packet4us>(const uint16_t* from, Index stride)
+{
+ Packet4us res = vld1_dup_u16(from);
+ res = vld1_lane_u16(from + 1*stride, res, 1);
+ res = vld1_lane_u16(from + 2*stride, res, 2);
+ res = vld1_lane_u16(from + 3*stride, res, 3);
+ return res;
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet8us pgather<uint16_t, Packet8us>(const uint16_t* from, Index stride)
+{
+ Packet8us res = vld1q_dup_u16(from);
+ res = vld1q_lane_u16(from + 1*stride, res, 1);
+ res = vld1q_lane_u16(from + 2*stride, res, 2);
+ res = vld1q_lane_u16(from + 3*stride, res, 3);
+ res = vld1q_lane_u16(from + 4*stride, res, 4);
+ res = vld1q_lane_u16(from + 5*stride, res, 5);
+ res = vld1q_lane_u16(from + 6*stride, res, 6);
+ res = vld1q_lane_u16(from + 7*stride, res, 7);
+ return res;
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2i pgather<int32_t, Packet2i>(const int32_t* from, Index stride)
+{
+ Packet2i res = vld1_dup_s32(from);
+ res = vld1_lane_s32(from + 1*stride, res, 1);
+ return res;
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4i pgather<int32_t, Packet4i>(const int32_t* from, Index stride)
+{
+ Packet4i res = vld1q_dup_s32(from);
+ res = vld1q_lane_s32(from + 1*stride, res, 1);
+ res = vld1q_lane_s32(from + 2*stride, res, 2);
+ res = vld1q_lane_s32(from + 3*stride, res, 3);
+ return res;
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2ui pgather<uint32_t, Packet2ui>(const uint32_t* from, Index stride)
+{
+ Packet2ui res = vld1_dup_u32(from);
+ res = vld1_lane_u32(from + 1*stride, res, 1);
+ return res;
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4ui pgather<uint32_t, Packet4ui>(const uint32_t* from, Index stride)
+{
+ Packet4ui res = vld1q_dup_u32(from);
+ res = vld1q_lane_u32(from + 1*stride, res, 1);
+ res = vld1q_lane_u32(from + 2*stride, res, 2);
+ res = vld1q_lane_u32(from + 3*stride, res, 3);
+ return res;
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2l pgather<int64_t, Packet2l>(const int64_t* from, Index stride)
+{
+ Packet2l res = vld1q_dup_s64(from);
+ res = vld1q_lane_s64(from + 1*stride, res, 1);
+ return res;
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2ul pgather<uint64_t, Packet2ul>(const uint64_t* from, Index stride)
+{
+ Packet2ul res = vld1q_dup_u64(from);
+ res = vld1q_lane_u64(from + 1*stride, res, 1);
return res;
}
-template<> EIGEN_DEVICE_FUNC inline void pscatter<float, Packet4f>(float* to, const Packet4f& from, Index stride)
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter<float, Packet2f>(float* to, const Packet2f& from, Index stride)
{
- to[stride*0] = vgetq_lane_f32(from, 0);
- to[stride*1] = vgetq_lane_f32(from, 1);
- to[stride*2] = vgetq_lane_f32(from, 2);
- to[stride*3] = vgetq_lane_f32(from, 3);
+ vst1_lane_f32(to + stride*0, from, 0);
+ vst1_lane_f32(to + stride*1, from, 1);
}
-template<> EIGEN_DEVICE_FUNC inline void pscatter<int32_t, Packet4i>(int32_t* to, const Packet4i& from, Index stride)
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter<float, Packet4f>(float* to, const Packet4f& from, Index stride)
{
- to[stride*0] = vgetq_lane_s32(from, 0);
- to[stride*1] = vgetq_lane_s32(from, 1);
- to[stride*2] = vgetq_lane_s32(from, 2);
- to[stride*3] = vgetq_lane_s32(from, 3);
+ vst1q_lane_f32(to + stride*0, from, 0);
+ vst1q_lane_f32(to + stride*1, from, 1);
+ vst1q_lane_f32(to + stride*2, from, 2);
+ vst1q_lane_f32(to + stride*3, from, 3);
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter<int8_t, Packet4c>(int8_t* to, const Packet4c& from, Index stride)
+{
+ for (int i = 0; i != 4; i++)
+ *(to + i * stride) = reinterpret_cast<const int8_t*>(&from)[i];
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter<int8_t, Packet8c>(int8_t* to, const Packet8c& from, Index stride)
+{
+ vst1_lane_s8(to + stride*0, from, 0);
+ vst1_lane_s8(to + stride*1, from, 1);
+ vst1_lane_s8(to + stride*2, from, 2);
+ vst1_lane_s8(to + stride*3, from, 3);
+ vst1_lane_s8(to + stride*4, from, 4);
+ vst1_lane_s8(to + stride*5, from, 5);
+ vst1_lane_s8(to + stride*6, from, 6);
+ vst1_lane_s8(to + stride*7, from, 7);
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter<int8_t, Packet16c>(int8_t* to, const Packet16c& from, Index stride)
+{
+ vst1q_lane_s8(to + stride*0, from, 0);
+ vst1q_lane_s8(to + stride*1, from, 1);
+ vst1q_lane_s8(to + stride*2, from, 2);
+ vst1q_lane_s8(to + stride*3, from, 3);
+ vst1q_lane_s8(to + stride*4, from, 4);
+ vst1q_lane_s8(to + stride*5, from, 5);
+ vst1q_lane_s8(to + stride*6, from, 6);
+ vst1q_lane_s8(to + stride*7, from, 7);
+ vst1q_lane_s8(to + stride*8, from, 8);
+ vst1q_lane_s8(to + stride*9, from, 9);
+ vst1q_lane_s8(to + stride*10, from, 10);
+ vst1q_lane_s8(to + stride*11, from, 11);
+ vst1q_lane_s8(to + stride*12, from, 12);
+ vst1q_lane_s8(to + stride*13, from, 13);
+ vst1q_lane_s8(to + stride*14, from, 14);
+ vst1q_lane_s8(to + stride*15, from, 15);
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter<uint8_t, Packet4uc>(uint8_t* to, const Packet4uc& from, Index stride)
+{
+ for (int i = 0; i != 4; i++)
+ *(to + i * stride) = reinterpret_cast<const uint8_t*>(&from)[i];
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter<uint8_t, Packet8uc>(uint8_t* to, const Packet8uc& from, Index stride)
+{
+ vst1_lane_u8(to + stride*0, from, 0);
+ vst1_lane_u8(to + stride*1, from, 1);
+ vst1_lane_u8(to + stride*2, from, 2);
+ vst1_lane_u8(to + stride*3, from, 3);
+ vst1_lane_u8(to + stride*4, from, 4);
+ vst1_lane_u8(to + stride*5, from, 5);
+ vst1_lane_u8(to + stride*6, from, 6);
+ vst1_lane_u8(to + stride*7, from, 7);
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter<uint8_t, Packet16uc>(uint8_t* to, const Packet16uc& from, Index stride)
+{
+ vst1q_lane_u8(to + stride*0, from, 0);
+ vst1q_lane_u8(to + stride*1, from, 1);
+ vst1q_lane_u8(to + stride*2, from, 2);
+ vst1q_lane_u8(to + stride*3, from, 3);
+ vst1q_lane_u8(to + stride*4, from, 4);
+ vst1q_lane_u8(to + stride*5, from, 5);
+ vst1q_lane_u8(to + stride*6, from, 6);
+ vst1q_lane_u8(to + stride*7, from, 7);
+ vst1q_lane_u8(to + stride*8, from, 8);
+ vst1q_lane_u8(to + stride*9, from, 9);
+ vst1q_lane_u8(to + stride*10, from, 10);
+ vst1q_lane_u8(to + stride*11, from, 11);
+ vst1q_lane_u8(to + stride*12, from, 12);
+ vst1q_lane_u8(to + stride*13, from, 13);
+ vst1q_lane_u8(to + stride*14, from, 14);
+ vst1q_lane_u8(to + stride*15, from, 15);
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter<int16_t, Packet4s>(int16_t* to, const Packet4s& from, Index stride)
+{
+ vst1_lane_s16(to + stride*0, from, 0);
+ vst1_lane_s16(to + stride*1, from, 1);
+ vst1_lane_s16(to + stride*2, from, 2);
+ vst1_lane_s16(to + stride*3, from, 3);
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter<int16_t, Packet8s>(int16_t* to, const Packet8s& from, Index stride)
+{
+ vst1q_lane_s16(to + stride*0, from, 0);
+ vst1q_lane_s16(to + stride*1, from, 1);
+ vst1q_lane_s16(to + stride*2, from, 2);
+ vst1q_lane_s16(to + stride*3, from, 3);
+ vst1q_lane_s16(to + stride*4, from, 4);
+ vst1q_lane_s16(to + stride*5, from, 5);
+ vst1q_lane_s16(to + stride*6, from, 6);
+ vst1q_lane_s16(to + stride*7, from, 7);
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter<uint16_t, Packet4us>(uint16_t* to, const Packet4us& from, Index stride)
+{
+ vst1_lane_u16(to + stride*0, from, 0);
+ vst1_lane_u16(to + stride*1, from, 1);
+ vst1_lane_u16(to + stride*2, from, 2);
+ vst1_lane_u16(to + stride*3, from, 3);
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter<uint16_t, Packet8us>(uint16_t* to, const Packet8us& from, Index stride)
+{
+ vst1q_lane_u16(to + stride*0, from, 0);
+ vst1q_lane_u16(to + stride*1, from, 1);
+ vst1q_lane_u16(to + stride*2, from, 2);
+ vst1q_lane_u16(to + stride*3, from, 3);
+ vst1q_lane_u16(to + stride*4, from, 4);
+ vst1q_lane_u16(to + stride*5, from, 5);
+ vst1q_lane_u16(to + stride*6, from, 6);
+ vst1q_lane_u16(to + stride*7, from, 7);
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter<int32_t, Packet2i>(int32_t* to, const Packet2i& from, Index stride)
+{
+ vst1_lane_s32(to + stride*0, from, 0);
+ vst1_lane_s32(to + stride*1, from, 1);
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter<int32_t, Packet4i>(int32_t* to, const Packet4i& from, Index stride)
+{
+ vst1q_lane_s32(to + stride*0, from, 0);
+ vst1q_lane_s32(to + stride*1, from, 1);
+ vst1q_lane_s32(to + stride*2, from, 2);
+ vst1q_lane_s32(to + stride*3, from, 3);
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter<uint32_t, Packet2ui>(uint32_t* to, const Packet2ui& from, Index stride)
+{
+ vst1_lane_u32(to + stride*0, from, 0);
+ vst1_lane_u32(to + stride*1, from, 1);
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter<uint32_t, Packet4ui>(uint32_t* to, const Packet4ui& from, Index stride)
+{
+ vst1q_lane_u32(to + stride*0, from, 0);
+ vst1q_lane_u32(to + stride*1, from, 1);
+ vst1q_lane_u32(to + stride*2, from, 2);
+ vst1q_lane_u32(to + stride*3, from, 3);
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter<int64_t, Packet2l>(int64_t* to, const Packet2l& from, Index stride)
+{
+ vst1q_lane_s64(to + stride*0, from, 0);
+ vst1q_lane_s64(to + stride*1, from, 1);
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter<uint64_t, Packet2ul>(uint64_t* to, const Packet2ul& from, Index stride)
+{
+ vst1q_lane_u64(to + stride*0, from, 0);
+ vst1q_lane_u64(to + stride*1, from, 1);
}
-template<> EIGEN_STRONG_INLINE void prefetch<float> (const float* addr) { EIGEN_ARM_PREFETCH(addr); }
-template<> EIGEN_STRONG_INLINE void prefetch<int32_t>(const int32_t* addr) { EIGEN_ARM_PREFETCH(addr); }
+template<> EIGEN_STRONG_INLINE void prefetch<float>(const float* addr) { EIGEN_ARM_PREFETCH(addr); }
+template<> EIGEN_STRONG_INLINE void prefetch<int8_t>(const int8_t* addr) { EIGEN_ARM_PREFETCH(addr); }
+template<> EIGEN_STRONG_INLINE void prefetch<uint8_t>(const uint8_t* addr) { EIGEN_ARM_PREFETCH(addr); }
+template<> EIGEN_STRONG_INLINE void prefetch<int16_t>(const int16_t* addr) { EIGEN_ARM_PREFETCH(addr); }
+template<> EIGEN_STRONG_INLINE void prefetch<uint16_t>(const uint16_t* addr) { EIGEN_ARM_PREFETCH(addr); }
+template<> EIGEN_STRONG_INLINE void prefetch<int32_t>(const int32_t* addr) { EIGEN_ARM_PREFETCH(addr); }
+template<> EIGEN_STRONG_INLINE void prefetch<uint32_t>(const uint32_t* addr) { EIGEN_ARM_PREFETCH(addr); }
+template<> EIGEN_STRONG_INLINE void prefetch<int64_t>(const int64_t* addr) { EIGEN_ARM_PREFETCH(addr); }
+template<> EIGEN_STRONG_INLINE void prefetch<uint64_t>(const uint64_t* addr) { EIGEN_ARM_PREFETCH(addr); }
-// FIXME only store the 2 first elements ?
-template<> EIGEN_STRONG_INLINE float pfirst<Packet4f>(const Packet4f& a) { float EIGEN_ALIGN16 x[4]; vst1q_f32(x, a); return x[0]; }
-template<> EIGEN_STRONG_INLINE int32_t pfirst<Packet4i>(const Packet4i& a) { int32_t EIGEN_ALIGN16 x[4]; vst1q_s32(x, a); return x[0]; }
+template<> EIGEN_STRONG_INLINE float pfirst<Packet2f>(const Packet2f& a) { return vget_lane_f32(a,0); }
+template<> EIGEN_STRONG_INLINE float pfirst<Packet4f>(const Packet4f& a) { return vgetq_lane_f32(a,0); }
+template<> EIGEN_STRONG_INLINE int8_t pfirst<Packet4c>(const Packet4c& a) { return static_cast<int8_t>(a & 0xff); }
+template<> EIGEN_STRONG_INLINE int8_t pfirst<Packet8c>(const Packet8c& a) { return vget_lane_s8(a,0); }
+template<> EIGEN_STRONG_INLINE int8_t pfirst<Packet16c>(const Packet16c& a) { return vgetq_lane_s8(a,0); }
+template<> EIGEN_STRONG_INLINE uint8_t pfirst<Packet4uc>(const Packet4uc& a) { return static_cast<uint8_t>(a & 0xff); }
+template<> EIGEN_STRONG_INLINE uint8_t pfirst<Packet8uc>(const Packet8uc& a) { return vget_lane_u8(a,0); }
+template<> EIGEN_STRONG_INLINE uint8_t pfirst<Packet16uc>(const Packet16uc& a) { return vgetq_lane_u8(a,0); }
+template<> EIGEN_STRONG_INLINE int16_t pfirst<Packet4s>(const Packet4s& a) { return vget_lane_s16(a,0); }
+template<> EIGEN_STRONG_INLINE int16_t pfirst<Packet8s>(const Packet8s& a) { return vgetq_lane_s16(a,0); }
+template<> EIGEN_STRONG_INLINE uint16_t pfirst<Packet4us>(const Packet4us& a) { return vget_lane_u16(a,0); }
+template<> EIGEN_STRONG_INLINE uint16_t pfirst<Packet8us>(const Packet8us& a) { return vgetq_lane_u16(a,0); }
+template<> EIGEN_STRONG_INLINE int32_t pfirst<Packet2i>(const Packet2i& a) { return vget_lane_s32(a,0); }
+template<> EIGEN_STRONG_INLINE int32_t pfirst<Packet4i>(const Packet4i& a) { return vgetq_lane_s32(a,0); }
+template<> EIGEN_STRONG_INLINE uint32_t pfirst<Packet2ui>(const Packet2ui& a) { return vget_lane_u32(a,0); }
+template<> EIGEN_STRONG_INLINE uint32_t pfirst<Packet4ui>(const Packet4ui& a) { return vgetq_lane_u32(a,0); }
+template<> EIGEN_STRONG_INLINE int64_t pfirst<Packet2l>(const Packet2l& a) { return vgetq_lane_s64(a,0); }
+template<> EIGEN_STRONG_INLINE uint64_t pfirst<Packet2ul>(const Packet2ul& a) { return vgetq_lane_u64(a,0); }
-template<> EIGEN_STRONG_INLINE Packet4f preverse(const Packet4f& a) {
- float32x2_t a_lo, a_hi;
- Packet4f a_r64;
-
- a_r64 = vrev64q_f32(a);
- a_lo = vget_low_f32(a_r64);
- a_hi = vget_high_f32(a_r64);
- return vcombine_f32(a_hi, a_lo);
+template<> EIGEN_STRONG_INLINE Packet2f preverse(const Packet2f& a) { return vrev64_f32(a); }
+template<> EIGEN_STRONG_INLINE Packet4f preverse(const Packet4f& a)
+{
+ const float32x4_t a_r64 = vrev64q_f32(a);
+ return vcombine_f32(vget_high_f32(a_r64), vget_low_f32(a_r64));
}
-template<> EIGEN_STRONG_INLINE Packet4i preverse(const Packet4i& a) {
- int32x2_t a_lo, a_hi;
- Packet4i a_r64;
-
- a_r64 = vrev64q_s32(a);
- a_lo = vget_low_s32(a_r64);
- a_hi = vget_high_s32(a_r64);
- return vcombine_s32(a_hi, a_lo);
+template<> EIGEN_STRONG_INLINE Packet4c preverse(const Packet4c& a)
+{ return vget_lane_s32(vreinterpret_s32_s8(vrev64_s8(vreinterpret_s8_s32(vdup_n_s32(a)))), 0); }
+template<> EIGEN_STRONG_INLINE Packet8c preverse(const Packet8c& a) { return vrev64_s8(a); }
+template<> EIGEN_STRONG_INLINE Packet16c preverse(const Packet16c& a)
+{
+ const int8x16_t a_r64 = vrev64q_s8(a);
+ return vcombine_s8(vget_high_s8(a_r64), vget_low_s8(a_r64));
}
+template<> EIGEN_STRONG_INLINE Packet4uc preverse(const Packet4uc& a)
+{ return vget_lane_u32(vreinterpret_u32_u8(vrev64_u8(vreinterpret_u8_u32(vdup_n_u32(a)))), 0); }
+template<> EIGEN_STRONG_INLINE Packet8uc preverse(const Packet8uc& a) { return vrev64_u8(a); }
+template<> EIGEN_STRONG_INLINE Packet16uc preverse(const Packet16uc& a)
+{
+ const uint8x16_t a_r64 = vrev64q_u8(a);
+ return vcombine_u8(vget_high_u8(a_r64), vget_low_u8(a_r64));
+}
+template<> EIGEN_STRONG_INLINE Packet4s preverse(const Packet4s& a) { return vrev64_s16(a); }
+template<> EIGEN_STRONG_INLINE Packet8s preverse(const Packet8s& a)
+{
+ const int16x8_t a_r64 = vrev64q_s16(a);
+ return vcombine_s16(vget_high_s16(a_r64), vget_low_s16(a_r64));
+}
+template<> EIGEN_STRONG_INLINE Packet4us preverse(const Packet4us& a) { return vrev64_u16(a); }
+template<> EIGEN_STRONG_INLINE Packet8us preverse(const Packet8us& a)
+{
+ const uint16x8_t a_r64 = vrev64q_u16(a);
+ return vcombine_u16(vget_high_u16(a_r64), vget_low_u16(a_r64));
+}
+template<> EIGEN_STRONG_INLINE Packet2i preverse(const Packet2i& a) { return vrev64_s32(a); }
+template<> EIGEN_STRONG_INLINE Packet4i preverse(const Packet4i& a)
+{
+ const int32x4_t a_r64 = vrev64q_s32(a);
+ return vcombine_s32(vget_high_s32(a_r64), vget_low_s32(a_r64));
+}
+template<> EIGEN_STRONG_INLINE Packet2ui preverse(const Packet2ui& a) { return vrev64_u32(a); }
+template<> EIGEN_STRONG_INLINE Packet4ui preverse(const Packet4ui& a)
+{
+ const uint32x4_t a_r64 = vrev64q_u32(a);
+ return vcombine_u32(vget_high_u32(a_r64), vget_low_u32(a_r64));
+}
+template<> EIGEN_STRONG_INLINE Packet2l preverse(const Packet2l& a)
+{ return vcombine_s64(vget_high_s64(a), vget_low_s64(a)); }
+template<> EIGEN_STRONG_INLINE Packet2ul preverse(const Packet2ul& a)
+{ return vcombine_u64(vget_high_u64(a), vget_low_u64(a)); }
+template<> EIGEN_STRONG_INLINE Packet2f pabs(const Packet2f& a) { return vabs_f32(a); }
template<> EIGEN_STRONG_INLINE Packet4f pabs(const Packet4f& a) { return vabsq_f32(a); }
+template<> EIGEN_STRONG_INLINE Packet4c pabs<Packet4c>(const Packet4c& a)
+{ return vget_lane_s32(vreinterpret_s32_s8(vabs_s8(vreinterpret_s8_s32(vdup_n_s32(a)))), 0); }
+template<> EIGEN_STRONG_INLINE Packet8c pabs(const Packet8c& a) { return vabs_s8(a); }
+template<> EIGEN_STRONG_INLINE Packet16c pabs(const Packet16c& a) { return vabsq_s8(a); }
+template<> EIGEN_STRONG_INLINE Packet4uc pabs(const Packet4uc& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet8uc pabs(const Packet8uc& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet16uc pabs(const Packet16uc& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet4s pabs(const Packet4s& a) { return vabs_s16(a); }
+template<> EIGEN_STRONG_INLINE Packet8s pabs(const Packet8s& a) { return vabsq_s16(a); }
+template<> EIGEN_STRONG_INLINE Packet4us pabs(const Packet4us& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet8us pabs(const Packet8us& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet2i pabs(const Packet2i& a) { return vabs_s32(a); }
template<> EIGEN_STRONG_INLINE Packet4i pabs(const Packet4i& a) { return vabsq_s32(a); }
+template<> EIGEN_STRONG_INLINE Packet2ui pabs(const Packet2ui& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet4ui pabs(const Packet4ui& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet2l pabs(const Packet2l& a) {
+#if EIGEN_ARCH_ARM64
+ return vabsq_s64(a);
+#else
+ return vcombine_s64(
+ vdup_n_s64((std::abs)(vgetq_lane_s64(a, 0))),
+ vdup_n_s64((std::abs)(vgetq_lane_s64(a, 1))));
+#endif
+}
+template<> EIGEN_STRONG_INLINE Packet2ul pabs(const Packet2ul& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet2f pfrexp<Packet2f>(const Packet2f& a, Packet2f& exponent)
+{ return pfrexp_generic(a,exponent); }
+template<> EIGEN_STRONG_INLINE Packet4f pfrexp<Packet4f>(const Packet4f& a, Packet4f& exponent)
+{ return pfrexp_generic(a,exponent); }
+
+template<> EIGEN_STRONG_INLINE Packet2f pldexp<Packet2f>(const Packet2f& a, const Packet2f& exponent)
+{ return pldexp_generic(a,exponent); }
+template<> EIGEN_STRONG_INLINE Packet4f pldexp<Packet4f>(const Packet4f& a, const Packet4f& exponent)
+{ return pldexp_generic(a,exponent); }
+
+template<> EIGEN_STRONG_INLINE float predux<Packet2f>(const Packet2f& a) { return vget_lane_f32(vpadd_f32(a,a), 0); }
template<> EIGEN_STRONG_INLINE float predux<Packet4f>(const Packet4f& a)
{
- float32x2_t a_lo, a_hi, sum;
-
- a_lo = vget_low_f32(a);
- a_hi = vget_high_f32(a);
- sum = vpadd_f32(a_lo, a_hi);
- sum = vpadd_f32(sum, sum);
- return vget_lane_f32(sum, 0);
+ const float32x2_t sum = vadd_f32(vget_low_f32(a), vget_high_f32(a));
+ return vget_lane_f32(vpadd_f32(sum, sum), 0);
}
-
-template<> EIGEN_STRONG_INLINE Packet4f preduxp<Packet4f>(const Packet4f* vecs)
+template<> EIGEN_STRONG_INLINE int8_t predux<Packet4c>(const Packet4c& a)
{
- float32x4x2_t vtrn1, vtrn2, res1, res2;
- Packet4f sum1, sum2, sum;
-
- // NEON zip performs interleaving of the supplied vectors.
- // We perform two interleaves in a row to acquire the transposed vector
- vtrn1 = vzipq_f32(vecs[0], vecs[2]);
- vtrn2 = vzipq_f32(vecs[1], vecs[3]);
- res1 = vzipq_f32(vtrn1.val[0], vtrn2.val[0]);
- res2 = vzipq_f32(vtrn1.val[1], vtrn2.val[1]);
-
- // Do the addition of the resulting vectors
- sum1 = vaddq_f32(res1.val[0], res1.val[1]);
- sum2 = vaddq_f32(res2.val[0], res2.val[1]);
- sum = vaddq_f32(sum1, sum2);
-
- return sum;
+ const int8x8_t a_dup = vreinterpret_s8_s32(vdup_n_s32(a));
+ int8x8_t sum = vpadd_s8(a_dup, a_dup);
+ sum = vpadd_s8(sum, sum);
+ return vget_lane_s8(sum, 0);
}
-
+template<> EIGEN_STRONG_INLINE int8_t predux<Packet8c>(const Packet8c& a)
+{
+ int8x8_t sum = vpadd_s8(a,a);
+ sum = vpadd_s8(sum, sum);
+ sum = vpadd_s8(sum, sum);
+ return vget_lane_s8(sum, 0);
+}
+template<> EIGEN_STRONG_INLINE int8_t predux<Packet16c>(const Packet16c& a)
+{
+ int8x8_t sum = vadd_s8(vget_low_s8(a), vget_high_s8(a));
+ sum = vpadd_s8(sum, sum);
+ sum = vpadd_s8(sum, sum);
+ sum = vpadd_s8(sum, sum);
+ return vget_lane_s8(sum, 0);
+}
+template<> EIGEN_STRONG_INLINE uint8_t predux<Packet4uc>(const Packet4uc& a)
+{
+ const uint8x8_t a_dup = vreinterpret_u8_u32(vdup_n_u32(a));
+ uint8x8_t sum = vpadd_u8(a_dup, a_dup);
+ sum = vpadd_u8(sum, sum);
+ return vget_lane_u8(sum, 0);
+}
+template<> EIGEN_STRONG_INLINE uint8_t predux<Packet8uc>(const Packet8uc& a)
+{
+ uint8x8_t sum = vpadd_u8(a,a);
+ sum = vpadd_u8(sum, sum);
+ sum = vpadd_u8(sum, sum);
+ return vget_lane_u8(sum, 0);
+}
+template<> EIGEN_STRONG_INLINE uint8_t predux<Packet16uc>(const Packet16uc& a)
+{
+ uint8x8_t sum = vadd_u8(vget_low_u8(a), vget_high_u8(a));
+ sum = vpadd_u8(sum, sum);
+ sum = vpadd_u8(sum, sum);
+ sum = vpadd_u8(sum, sum);
+ return vget_lane_u8(sum, 0);
+}
+template<> EIGEN_STRONG_INLINE int16_t predux<Packet4s>(const Packet4s& a)
+{
+ const int16x4_t sum = vpadd_s16(a,a);
+ return vget_lane_s16(vpadd_s16(sum, sum), 0);
+}
+template<> EIGEN_STRONG_INLINE int16_t predux<Packet8s>(const Packet8s& a)
+{
+ int16x4_t sum = vadd_s16(vget_low_s16(a), vget_high_s16(a));
+ sum = vpadd_s16(sum, sum);
+ sum = vpadd_s16(sum, sum);
+ return vget_lane_s16(sum, 0);
+}
+template<> EIGEN_STRONG_INLINE uint16_t predux<Packet4us>(const Packet4us& a)
+{
+ const uint16x4_t sum = vpadd_u16(a,a);
+ return vget_lane_u16(vpadd_u16(sum, sum), 0);
+}
+template<> EIGEN_STRONG_INLINE uint16_t predux<Packet8us>(const Packet8us& a)
+{
+ uint16x4_t sum = vadd_u16(vget_low_u16(a), vget_high_u16(a));
+ sum = vpadd_u16(sum, sum);
+ sum = vpadd_u16(sum, sum);
+ return vget_lane_u16(sum, 0);
+}
+template<> EIGEN_STRONG_INLINE int32_t predux<Packet2i>(const Packet2i& a) { return vget_lane_s32(vpadd_s32(a,a), 0); }
template<> EIGEN_STRONG_INLINE int32_t predux<Packet4i>(const Packet4i& a)
{
- int32x2_t a_lo, a_hi, sum;
-
- a_lo = vget_low_s32(a);
- a_hi = vget_high_s32(a);
- sum = vpadd_s32(a_lo, a_hi);
- sum = vpadd_s32(sum, sum);
- return vget_lane_s32(sum, 0);
+ const int32x2_t sum = vadd_s32(vget_low_s32(a), vget_high_s32(a));
+ return vget_lane_s32(vpadd_s32(sum, sum), 0);
}
-
-template<> EIGEN_STRONG_INLINE Packet4i preduxp<Packet4i>(const Packet4i* vecs)
+template<> EIGEN_STRONG_INLINE uint32_t predux<Packet2ui>(const Packet2ui& a) { return vget_lane_u32(vpadd_u32(a,a), 0); }
+template<> EIGEN_STRONG_INLINE uint32_t predux<Packet4ui>(const Packet4ui& a)
{
- int32x4x2_t vtrn1, vtrn2, res1, res2;
- Packet4i sum1, sum2, sum;
-
- // NEON zip performs interleaving of the supplied vectors.
- // We perform two interleaves in a row to acquire the transposed vector
- vtrn1 = vzipq_s32(vecs[0], vecs[2]);
- vtrn2 = vzipq_s32(vecs[1], vecs[3]);
- res1 = vzipq_s32(vtrn1.val[0], vtrn2.val[0]);
- res2 = vzipq_s32(vtrn1.val[1], vtrn2.val[1]);
-
- // Do the addition of the resulting vectors
- sum1 = vaddq_s32(res1.val[0], res1.val[1]);
- sum2 = vaddq_s32(res2.val[0], res2.val[1]);
- sum = vaddq_s32(sum1, sum2);
-
- return sum;
+ const uint32x2_t sum = vadd_u32(vget_low_u32(a), vget_high_u32(a));
+ return vget_lane_u32(vpadd_u32(sum, sum), 0);
}
+template<> EIGEN_STRONG_INLINE int64_t predux<Packet2l>(const Packet2l& a)
+{ return vgetq_lane_s64(a, 0) + vgetq_lane_s64(a, 1); }
+template<> EIGEN_STRONG_INLINE uint64_t predux<Packet2ul>(const Packet2ul& a)
+{ return vgetq_lane_u64(a, 0) + vgetq_lane_u64(a, 1); }
+
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4c predux_half_dowto4(const Packet8c& a)
+{
+ return vget_lane_s32(vreinterpret_s32_s8(vadd_s8(a,
+ vreinterpret_s8_s32(vrev64_s32(vreinterpret_s32_s8(a))))), 0);
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet8c predux_half_dowto4(const Packet16c& a)
+{ return vadd_s8(vget_high_s8(a), vget_low_s8(a)); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4uc predux_half_dowto4(const Packet8uc& a)
+{
+ return vget_lane_u32(vreinterpret_u32_u8(vadd_u8(a,
+ vreinterpret_u8_u32(vrev64_u32(vreinterpret_u32_u8(a))))), 0);
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet8uc predux_half_dowto4(const Packet16uc& a)
+{ return vadd_u8(vget_high_u8(a), vget_low_u8(a)); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4s predux_half_dowto4(const Packet8s& a)
+{ return vadd_s16(vget_high_s16(a), vget_low_s16(a)); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4us predux_half_dowto4(const Packet8us& a)
+{ return vadd_u16(vget_high_u16(a), vget_low_u16(a)); }
// Other reduction functions:
// mul
+template<> EIGEN_STRONG_INLINE float predux_mul<Packet2f>(const Packet2f& a)
+{ return vget_lane_f32(a, 0) * vget_lane_f32(a, 1); }
template<> EIGEN_STRONG_INLINE float predux_mul<Packet4f>(const Packet4f& a)
+{ return predux_mul(vmul_f32(vget_low_f32(a), vget_high_f32(a))); }
+template<> EIGEN_STRONG_INLINE int8_t predux_mul<Packet4c>(const Packet4c& a)
{
- float32x2_t a_lo, a_hi, prod;
-
- // Get a_lo = |a1|a2| and a_hi = |a3|a4|
- a_lo = vget_low_f32(a);
- a_hi = vget_high_f32(a);
- // Get the product of a_lo * a_hi -> |a1*a3|a2*a4|
- prod = vmul_f32(a_lo, a_hi);
- // Multiply prod with its swapped value |a2*a4|a1*a3|
- prod = vmul_f32(prod, vrev64_f32(prod));
-
- return vget_lane_f32(prod, 0);
+ int8x8_t prod = vreinterpret_s8_s32(vdup_n_s32(a));
+ prod = vmul_s8(prod, vrev16_s8(prod));
+ return vget_lane_s8(prod, 0) * vget_lane_s8(prod, 2);
}
+template<> EIGEN_STRONG_INLINE int8_t predux_mul<Packet8c>(const Packet8c& a)
+{
+ int8x8_t prod = vmul_s8(a, vrev16_s8(a));
+ prod = vmul_s8(prod, vrev32_s8(prod));
+ return vget_lane_s8(prod, 0) * vget_lane_s8(prod, 4);
+}
+template<> EIGEN_STRONG_INLINE int8_t predux_mul<Packet16c>(const Packet16c& a)
+{ return predux_mul(vmul_s8(vget_low_s8(a), vget_high_s8(a))); }
+template<> EIGEN_STRONG_INLINE uint8_t predux_mul<Packet4uc>(const Packet4uc& a)
+{
+ uint8x8_t prod = vreinterpret_u8_u32(vdup_n_u32(a));
+ prod = vmul_u8(prod, vrev16_u8(prod));
+ return vget_lane_u8(prod, 0) * vget_lane_u8(prod, 2);
+}
+template<> EIGEN_STRONG_INLINE uint8_t predux_mul<Packet8uc>(const Packet8uc& a)
+{
+ uint8x8_t prod = vmul_u8(a, vrev16_u8(a));
+ prod = vmul_u8(prod, vrev32_u8(prod));
+ return vget_lane_u8(prod, 0) * vget_lane_u8(prod, 4);
+}
+template<> EIGEN_STRONG_INLINE uint8_t predux_mul<Packet16uc>(const Packet16uc& a)
+{ return predux_mul(vmul_u8(vget_low_u8(a), vget_high_u8(a))); }
+template<> EIGEN_STRONG_INLINE int16_t predux_mul<Packet4s>(const Packet4s& a)
+{
+ const int16x4_t prod = vmul_s16(a, vrev32_s16(a));
+ return vget_lane_s16(prod, 0) * vget_lane_s16(prod, 2);
+}
+template<> EIGEN_STRONG_INLINE int16_t predux_mul<Packet8s>(const Packet8s& a)
+{
+ int16x4_t prod;
+
+ // Get the product of a_lo * a_hi -> |a1*a5|a2*a6|a3*a7|a4*a8|
+ prod = vmul_s16(vget_low_s16(a), vget_high_s16(a));
+ // Swap and multiply |a1*a5*a2*a6|a3*a7*a4*a8|
+ prod = vmul_s16(prod, vrev32_s16(prod));
+ // Multiply |a1*a5*a2*a6*a3*a7*a4*a8|
+ return vget_lane_s16(prod, 0) * vget_lane_s16(prod, 2);
+}
+template<> EIGEN_STRONG_INLINE uint16_t predux_mul<Packet4us>(const Packet4us& a)
+{
+ const uint16x4_t prod = vmul_u16(a, vrev32_u16(a));
+ return vget_lane_u16(prod, 0) * vget_lane_u16(prod, 2);
+}
+template<> EIGEN_STRONG_INLINE uint16_t predux_mul<Packet8us>(const Packet8us& a)
+{
+ uint16x4_t prod;
+
+ // Get the product of a_lo * a_hi -> |a1*a5|a2*a6|a3*a7|a4*a8|
+ prod = vmul_u16(vget_low_u16(a), vget_high_u16(a));
+ // Swap and multiply |a1*a5*a2*a6|a3*a7*a4*a8|
+ prod = vmul_u16(prod, vrev32_u16(prod));
+ // Multiply |a1*a5*a2*a6*a3*a7*a4*a8|
+ return vget_lane_u16(prod, 0) * vget_lane_u16(prod, 2);
+}
+template<> EIGEN_STRONG_INLINE int32_t predux_mul<Packet2i>(const Packet2i& a)
+{ return vget_lane_s32(a, 0) * vget_lane_s32(a, 1); }
template<> EIGEN_STRONG_INLINE int32_t predux_mul<Packet4i>(const Packet4i& a)
-{
- int32x2_t a_lo, a_hi, prod;
-
- // Get a_lo = |a1|a2| and a_hi = |a3|a4|
- a_lo = vget_low_s32(a);
- a_hi = vget_high_s32(a);
- // Get the product of a_lo * a_hi -> |a1*a3|a2*a4|
- prod = vmul_s32(a_lo, a_hi);
- // Multiply prod with its swapped value |a2*a4|a1*a3|
- prod = vmul_s32(prod, vrev64_s32(prod));
-
- return vget_lane_s32(prod, 0);
-}
+{ return predux_mul(vmul_s32(vget_low_s32(a), vget_high_s32(a))); }
+template<> EIGEN_STRONG_INLINE uint32_t predux_mul<Packet2ui>(const Packet2ui& a)
+{ return vget_lane_u32(a, 0) * vget_lane_u32(a, 1); }
+template<> EIGEN_STRONG_INLINE uint32_t predux_mul<Packet4ui>(const Packet4ui& a)
+{ return predux_mul(vmul_u32(vget_low_u32(a), vget_high_u32(a))); }
+template<> EIGEN_STRONG_INLINE int64_t predux_mul<Packet2l>(const Packet2l& a)
+{ return vgetq_lane_s64(a, 0) * vgetq_lane_s64(a, 1); }
+template<> EIGEN_STRONG_INLINE uint64_t predux_mul<Packet2ul>(const Packet2ul& a)
+{ return vgetq_lane_u64(a, 0) * vgetq_lane_u64(a, 1); }
// min
+template<> EIGEN_STRONG_INLINE float predux_min<Packet2f>(const Packet2f& a)
+{ return vget_lane_f32(vpmin_f32(a,a), 0); }
template<> EIGEN_STRONG_INLINE float predux_min<Packet4f>(const Packet4f& a)
{
- float32x2_t a_lo, a_hi, min;
-
- a_lo = vget_low_f32(a);
- a_hi = vget_high_f32(a);
- min = vpmin_f32(a_lo, a_hi);
- min = vpmin_f32(min, min);
-
- return vget_lane_f32(min, 0);
+ const float32x2_t min = vmin_f32(vget_low_f32(a), vget_high_f32(a));
+ return vget_lane_f32(vpmin_f32(min, min), 0);
}
-
+template<> EIGEN_STRONG_INLINE int8_t predux_min<Packet4c>(const Packet4c& a)
+{
+ const int8x8_t a_dup = vreinterpret_s8_s32(vdup_n_s32(a));
+ int8x8_t min = vpmin_s8(a_dup, a_dup);
+ min = vpmin_s8(min, min);
+ return vget_lane_s8(min, 0);
+}
+template<> EIGEN_STRONG_INLINE int8_t predux_min<Packet8c>(const Packet8c& a)
+{
+ int8x8_t min = vpmin_s8(a,a);
+ min = vpmin_s8(min, min);
+ min = vpmin_s8(min, min);
+ return vget_lane_s8(min, 0);
+}
+template<> EIGEN_STRONG_INLINE int8_t predux_min<Packet16c>(const Packet16c& a)
+{
+ int8x8_t min = vmin_s8(vget_low_s8(a), vget_high_s8(a));
+ min = vpmin_s8(min, min);
+ min = vpmin_s8(min, min);
+ min = vpmin_s8(min, min);
+ return vget_lane_s8(min, 0);
+}
+template<> EIGEN_STRONG_INLINE uint8_t predux_min<Packet4uc>(const Packet4uc& a)
+{
+ const uint8x8_t a_dup = vreinterpret_u8_u32(vdup_n_u32(a));
+ uint8x8_t min = vpmin_u8(a_dup, a_dup);
+ min = vpmin_u8(min, min);
+ return vget_lane_u8(min, 0);
+}
+template<> EIGEN_STRONG_INLINE uint8_t predux_min<Packet8uc>(const Packet8uc& a)
+{
+ uint8x8_t min = vpmin_u8(a,a);
+ min = vpmin_u8(min, min);
+ min = vpmin_u8(min, min);
+ return vget_lane_u8(min, 0);
+}
+template<> EIGEN_STRONG_INLINE uint8_t predux_min<Packet16uc>(const Packet16uc& a)
+{
+ uint8x8_t min = vmin_u8(vget_low_u8(a), vget_high_u8(a));
+ min = vpmin_u8(min, min);
+ min = vpmin_u8(min, min);
+ min = vpmin_u8(min, min);
+ return vget_lane_u8(min, 0);
+}
+template<> EIGEN_STRONG_INLINE int16_t predux_min<Packet4s>(const Packet4s& a)
+{
+ const int16x4_t min = vpmin_s16(a,a);
+ return vget_lane_s16(vpmin_s16(min, min), 0);
+}
+template<> EIGEN_STRONG_INLINE int16_t predux_min<Packet8s>(const Packet8s& a)
+{
+ int16x4_t min = vmin_s16(vget_low_s16(a), vget_high_s16(a));
+ min = vpmin_s16(min, min);
+ min = vpmin_s16(min, min);
+ return vget_lane_s16(min, 0);
+}
+template<> EIGEN_STRONG_INLINE uint16_t predux_min<Packet4us>(const Packet4us& a)
+{
+ const uint16x4_t min = vpmin_u16(a,a);
+ return vget_lane_u16(vpmin_u16(min, min), 0);
+}
+template<> EIGEN_STRONG_INLINE uint16_t predux_min<Packet8us>(const Packet8us& a)
+{
+ uint16x4_t min = vmin_u16(vget_low_u16(a), vget_high_u16(a));
+ min = vpmin_u16(min, min);
+ min = vpmin_u16(min, min);
+ return vget_lane_u16(min, 0);
+}
+template<> EIGEN_STRONG_INLINE int32_t predux_min<Packet2i>(const Packet2i& a)
+{ return vget_lane_s32(vpmin_s32(a,a), 0); }
template<> EIGEN_STRONG_INLINE int32_t predux_min<Packet4i>(const Packet4i& a)
{
- int32x2_t a_lo, a_hi, min;
-
- a_lo = vget_low_s32(a);
- a_hi = vget_high_s32(a);
- min = vpmin_s32(a_lo, a_hi);
- min = vpmin_s32(min, min);
-
- return vget_lane_s32(min, 0);
+ const int32x2_t min = vmin_s32(vget_low_s32(a), vget_high_s32(a));
+ return vget_lane_s32(vpmin_s32(min, min), 0);
}
+template<> EIGEN_STRONG_INLINE uint32_t predux_min<Packet2ui>(const Packet2ui& a)
+{ return vget_lane_u32(vpmin_u32(a,a), 0); }
+template<> EIGEN_STRONG_INLINE uint32_t predux_min<Packet4ui>(const Packet4ui& a)
+{
+ const uint32x2_t min = vmin_u32(vget_low_u32(a), vget_high_u32(a));
+ return vget_lane_u32(vpmin_u32(min, min), 0);
+}
+template<> EIGEN_STRONG_INLINE int64_t predux_min<Packet2l>(const Packet2l& a)
+{ return (std::min)(vgetq_lane_s64(a, 0), vgetq_lane_s64(a, 1)); }
+template<> EIGEN_STRONG_INLINE uint64_t predux_min<Packet2ul>(const Packet2ul& a)
+{ return (std::min)(vgetq_lane_u64(a, 0), vgetq_lane_u64(a, 1)); }
 // max
+template<> EIGEN_STRONG_INLINE float predux_max<Packet2f>(const Packet2f& a) // one pairwise max covers both lanes
+{ return vget_lane_f32(vpmax_f32(a,a), 0); }
 template<> EIGEN_STRONG_INLINE float predux_max<Packet4f>(const Packet4f& a)
 {
- float32x2_t a_lo, a_hi, max;
-
- a_lo = vget_low_f32(a);
- a_hi = vget_high_f32(a);
- max = vpmax_f32(a_lo, a_hi);
- max = vpmax_f32(max, max);
-
- return vget_lane_f32(max, 0);
+ const float32x2_t max = vmax_f32(vget_low_f32(a), vget_high_f32(a)); // fold the high half onto the low half
+ return vget_lane_f32(vpmax_f32(max, max), 0);
 }
-
+template<> EIGEN_STRONG_INLINE int8_t predux_max<Packet4c>(const Packet4c& a) // a is a 32-bit scalar holding the lanes; broadcast it into a vector first
+{
+ const int8x8_t a_dup = vreinterpret_s8_s32(vdup_n_s32(a));
+ int8x8_t max = vpmax_s8(a_dup, a_dup);
+ max = vpmax_s8(max, max);
+ return vget_lane_s8(max, 0);
+}
+template<> EIGEN_STRONG_INLINE int8_t predux_max<Packet8c>(const Packet8c& a) // max of 8 lanes via three pairwise-max passes
+{
+ int8x8_t max = vpmax_s8(a,a);
+ max = vpmax_s8(max, max);
+ max = vpmax_s8(max, max);
+ return vget_lane_s8(max, 0);
+}
+template<> EIGEN_STRONG_INLINE int8_t predux_max<Packet16c>(const Packet16c& a) // vmax the halves, then three pairwise-max passes
+{
+ int8x8_t max = vmax_s8(vget_low_s8(a), vget_high_s8(a));
+ max = vpmax_s8(max, max);
+ max = vpmax_s8(max, max);
+ max = vpmax_s8(max, max);
+ return vget_lane_s8(max, 0);
+}
+template<> EIGEN_STRONG_INLINE uint8_t predux_max<Packet4uc>(const Packet4uc& a) // a is a 32-bit scalar holding the lanes; broadcast it into a vector first
+{
+ const uint8x8_t a_dup = vreinterpret_u8_u32(vdup_n_u32(a));
+ uint8x8_t max = vpmax_u8(a_dup, a_dup);
+ max = vpmax_u8(max, max);
+ return vget_lane_u8(max, 0);
+}
+template<> EIGEN_STRONG_INLINE uint8_t predux_max<Packet8uc>(const Packet8uc& a) // max of 8 lanes via three pairwise-max passes
+{
+ uint8x8_t max = vpmax_u8(a,a);
+ max = vpmax_u8(max, max);
+ max = vpmax_u8(max, max);
+ return vget_lane_u8(max, 0);
+}
+template<> EIGEN_STRONG_INLINE uint8_t predux_max<Packet16uc>(const Packet16uc& a) // vmax the halves, then three pairwise-max passes
+{
+ uint8x8_t max = vmax_u8(vget_low_u8(a), vget_high_u8(a));
+ max = vpmax_u8(max, max);
+ max = vpmax_u8(max, max);
+ max = vpmax_u8(max, max);
+ return vget_lane_u8(max, 0);
+}
+template<> EIGEN_STRONG_INLINE int16_t predux_max<Packet4s>(const Packet4s& a) // max of 4 lanes via two pairwise-max passes
+{
+ const int16x4_t max = vpmax_s16(a,a);
+ return vget_lane_s16(vpmax_s16(max, max), 0);
+}
+template<> EIGEN_STRONG_INLINE int16_t predux_max<Packet8s>(const Packet8s& a) // vmax the halves, then two pairwise-max passes
+{
+ int16x4_t max = vmax_s16(vget_low_s16(a), vget_high_s16(a));
+ max = vpmax_s16(max, max);
+ max = vpmax_s16(max, max);
+ return vget_lane_s16(max, 0);
+}
+template<> EIGEN_STRONG_INLINE uint16_t predux_max<Packet4us>(const Packet4us& a) // max of 4 lanes via two pairwise-max passes
+{
+ const uint16x4_t max = vpmax_u16(a,a);
+ return vget_lane_u16(vpmax_u16(max, max), 0);
+}
+template<> EIGEN_STRONG_INLINE uint16_t predux_max<Packet8us>(const Packet8us& a) // vmax the halves, then two pairwise-max passes
+{
+ uint16x4_t max = vmax_u16(vget_low_u16(a), vget_high_u16(a));
+ max = vpmax_u16(max, max);
+ max = vpmax_u16(max, max);
+ return vget_lane_u16(max, 0);
+}
+template<> EIGEN_STRONG_INLINE int32_t predux_max<Packet2i>(const Packet2i& a) // one pairwise max covers both lanes
+{ return vget_lane_s32(vpmax_s32(a,a), 0); }
 template<> EIGEN_STRONG_INLINE int32_t predux_max<Packet4i>(const Packet4i& a)
 {
- int32x2_t a_lo, a_hi, max;
+ const int32x2_t max = vmax_s32(vget_low_s32(a), vget_high_s32(a)); // fold the high half onto the low half
+ return vget_lane_s32(vpmax_s32(max, max), 0);
+}
+template<> EIGEN_STRONG_INLINE uint32_t predux_max<Packet2ui>(const Packet2ui& a) // one pairwise max covers both lanes
+{ return vget_lane_u32(vpmax_u32(a,a), 0); }
+template<> EIGEN_STRONG_INLINE uint32_t predux_max<Packet4ui>(const Packet4ui& a) // vmax the halves, then one pairwise max
+{
+ const uint32x2_t max = vmax_u32(vget_low_u32(a), vget_high_u32(a));
+ return vget_lane_u32(vpmax_u32(max, max), 0);
+}
+template<> EIGEN_STRONG_INLINE int64_t predux_max<Packet2l>(const Packet2l& a) // 64-bit lanes: scalar max of the two extracted lanes
+{ return (std::max)(vgetq_lane_s64(a, 0), vgetq_lane_s64(a, 1)); }
+template<> EIGEN_STRONG_INLINE uint64_t predux_max<Packet2ul>(const Packet2ul& a) // 64-bit lanes: scalar max of the two extracted lanes
+{ return (std::max)(vgetq_lane_u64(a, 0), vgetq_lane_u64(a, 1)); }
- a_lo = vget_low_s32(a);
- a_hi = vget_high_s32(a);
- max = vpmax_s32(a_lo, a_hi);
- max = vpmax_s32(max, max);
-
- return vget_lane_s32(max, 0);
+template<> EIGEN_STRONG_INLINE bool predux_any(const Packet4f& x) // true iff any bit of any lane is set: OR the halves, then pairwise max
+{
+ uint32x2_t tmp = vorr_u32(vget_low_u32( vreinterpretq_u32_f32(x)),
+ vget_high_u32(vreinterpretq_u32_f32(x)));
+ return vget_lane_u32(vpmax_u32(tmp, tmp), 0);
 }
-// this PALIGN_NEON business is to work around a bug in LLVM Clang 3.0 causing incorrect compilation errors,
-// see bug 347 and this LLVM bug: http://llvm.org/bugs/show_bug.cgi?id=11074
-#define PALIGN_NEON(Offset,Type,Command) \
-template<>\
-struct palign_impl<Offset,Type>\
-{\
- EIGEN_STRONG_INLINE static void run(Type& first, const Type& second)\
- {\
- if (Offset!=0)\
- first = Command(first, second, Offset);\
- }\
-};\
+// Helpers for ptranspose: zip_in_place interleaves two packets; ptranspose_impl applies log2(N) rounds of zips to an N-packet block.
+namespace detail {
+
+template<typename Packet>
+void zip_in_place(Packet& p1, Packet& p2); // interleave lanes: p1 <- zip of the low halves, p2 <- zip of the high halves
-PALIGN_NEON(0,Packet4f,vextq_f32)
-PALIGN_NEON(1,Packet4f,vextq_f32)
-PALIGN_NEON(2,Packet4f,vextq_f32)
-PALIGN_NEON(3,Packet4f,vextq_f32)
-PALIGN_NEON(0,Packet4i,vextq_s32)
-PALIGN_NEON(1,Packet4i,vextq_s32)
-PALIGN_NEON(2,Packet4i,vextq_s32)
-PALIGN_NEON(3,Packet4i,vextq_s32)
-
-#undef PALIGN_NEON
-
-EIGEN_DEVICE_FUNC inline void
-ptranspose(PacketBlock<Packet4f,4>& kernel) {
- float32x4x2_t tmp1 = vzipq_f32(kernel.packet[0], kernel.packet[1]);
- float32x4x2_t tmp2 = vzipq_f32(kernel.packet[2], kernel.packet[3]);
-
- kernel.packet[0] = vcombine_f32(vget_low_f32(tmp1.val[0]), vget_low_f32(tmp2.val[0]));
- kernel.packet[1] = vcombine_f32(vget_high_f32(tmp1.val[0]), vget_high_f32(tmp2.val[0]));
- kernel.packet[2] = vcombine_f32(vget_low_f32(tmp1.val[1]), vget_low_f32(tmp2.val[1]));
- kernel.packet[3] = vcombine_f32(vget_high_f32(tmp1.val[1]), vget_high_f32(tmp2.val[1]));
+template<>
+EIGEN_ALWAYS_INLINE void zip_in_place<Packet2f>(Packet2f& p1, Packet2f& p2) {
+ const float32x2x2_t tmp = vzip_f32(p1, p2);
+ p1 = tmp.val[0];
+ p2 = tmp.val[1];
 }
-EIGEN_DEVICE_FUNC inline void
-ptranspose(PacketBlock<Packet4i,4>& kernel) {
- int32x4x2_t tmp1 = vzipq_s32(kernel.packet[0], kernel.packet[1]);
- int32x4x2_t tmp2 = vzipq_s32(kernel.packet[2], kernel.packet[3]);
- kernel.packet[0] = vcombine_s32(vget_low_s32(tmp1.val[0]), vget_low_s32(tmp2.val[0]));
- kernel.packet[1] = vcombine_s32(vget_high_s32(tmp1.val[0]), vget_high_s32(tmp2.val[0]));
- kernel.packet[2] = vcombine_s32(vget_low_s32(tmp1.val[1]), vget_low_s32(tmp2.val[1]));
- kernel.packet[3] = vcombine_s32(vget_high_s32(tmp1.val[1]), vget_high_s32(tmp2.val[1]));
+template<>
+EIGEN_ALWAYS_INLINE void zip_in_place<Packet4f>(Packet4f& p1, Packet4f& p2) {
+ const float32x4x2_t tmp = vzipq_f32(p1, p2);
+ p1 = tmp.val[0];
+ p2 = tmp.val[1];
+}
+
+template<>
+EIGEN_ALWAYS_INLINE void zip_in_place<Packet8c>(Packet8c& p1, Packet8c& p2) {
+ const int8x8x2_t tmp = vzip_s8(p1, p2);
+ p1 = tmp.val[0];
+ p2 = tmp.val[1];
+}
+
+template<>
+EIGEN_ALWAYS_INLINE void zip_in_place<Packet16c>(Packet16c& p1, Packet16c& p2) {
+ const int8x16x2_t tmp = vzipq_s8(p1, p2);
+ p1 = tmp.val[0];
+ p2 = tmp.val[1];
+}
+
+template<>
+EIGEN_ALWAYS_INLINE void zip_in_place<Packet8uc>(Packet8uc& p1, Packet8uc& p2) {
+ const uint8x8x2_t tmp = vzip_u8(p1, p2);
+ p1 = tmp.val[0];
+ p2 = tmp.val[1];
+}
+
+template<>
+EIGEN_ALWAYS_INLINE void zip_in_place<Packet16uc>(Packet16uc& p1, Packet16uc& p2) {
+ const uint8x16x2_t tmp = vzipq_u8(p1, p2);
+ p1 = tmp.val[0];
+ p2 = tmp.val[1];
+}
+
+template<>
+EIGEN_ALWAYS_INLINE void zip_in_place<Packet2i>(Packet2i& p1, Packet2i& p2) {
+ const int32x2x2_t tmp = vzip_s32(p1, p2);
+ p1 = tmp.val[0];
+ p2 = tmp.val[1];
+}
+
+template<>
+EIGEN_ALWAYS_INLINE void zip_in_place<Packet4i>(Packet4i& p1, Packet4i& p2) {
+ const int32x4x2_t tmp = vzipq_s32(p1, p2);
+ p1 = tmp.val[0];
+ p2 = tmp.val[1];
+}
+
+template<>
+EIGEN_ALWAYS_INLINE void zip_in_place<Packet2ui>(Packet2ui& p1, Packet2ui& p2) {
+ const uint32x2x2_t tmp = vzip_u32(p1, p2);
+ p1 = tmp.val[0];
+ p2 = tmp.val[1];
+}
+
+template<>
+EIGEN_ALWAYS_INLINE void zip_in_place<Packet4ui>(Packet4ui& p1, Packet4ui& p2) {
+ const uint32x4x2_t tmp = vzipq_u32(p1, p2);
+ p1 = tmp.val[0];
+ p2 = tmp.val[1];
+}
+
+template<>
+EIGEN_ALWAYS_INLINE void zip_in_place<Packet4s>(Packet4s& p1, Packet4s& p2) {
+ const int16x4x2_t tmp = vzip_s16(p1, p2);
+ p1 = tmp.val[0];
+ p2 = tmp.val[1];
+}
+
+template<>
+EIGEN_ALWAYS_INLINE void zip_in_place<Packet8s>(Packet8s& p1, Packet8s& p2) {
+ const int16x8x2_t tmp = vzipq_s16(p1, p2);
+ p1 = tmp.val[0];
+ p2 = tmp.val[1];
+}
+
+template<>
+EIGEN_ALWAYS_INLINE void zip_in_place<Packet4us>(Packet4us& p1, Packet4us& p2) {
+ const uint16x4x2_t tmp = vzip_u16(p1, p2);
+ p1 = tmp.val[0];
+ p2 = tmp.val[1];
+}
+
+template<>
+EIGEN_ALWAYS_INLINE void zip_in_place<Packet8us>(Packet8us& p1, Packet8us& p2) {
+ const uint16x8x2_t tmp = vzipq_u16(p1, p2);
+ p1 = tmp.val[0];
+ p2 = tmp.val[1];
+}
+
+template<typename Packet>
+EIGEN_ALWAYS_INLINE void ptranspose_impl(PacketBlock<Packet, 2>& kernel) { // one zip round
+ zip_in_place(kernel.packet[0], kernel.packet[1]);
+}
+
+template<typename Packet>
+EIGEN_ALWAYS_INLINE void ptranspose_impl(PacketBlock<Packet, 4>& kernel) { // two zip rounds: stride 2, then stride 1
+ zip_in_place(kernel.packet[0], kernel.packet[2]);
+ zip_in_place(kernel.packet[1], kernel.packet[3]);
+ zip_in_place(kernel.packet[0], kernel.packet[1]);
+ zip_in_place(kernel.packet[2], kernel.packet[3]);
+}
+
+template<typename Packet>
+EIGEN_ALWAYS_INLINE void ptranspose_impl(PacketBlock<Packet, 8>& kernel) { // three zip rounds: stride 4, 2, 1
+ zip_in_place(kernel.packet[0], kernel.packet[4]);
+ zip_in_place(kernel.packet[1], kernel.packet[5]);
+ zip_in_place(kernel.packet[2], kernel.packet[6]);
+ zip_in_place(kernel.packet[3], kernel.packet[7]);
+
+ zip_in_place(kernel.packet[0], kernel.packet[2]);
+ zip_in_place(kernel.packet[1], kernel.packet[3]);
+ zip_in_place(kernel.packet[4], kernel.packet[6]);
+ zip_in_place(kernel.packet[5], kernel.packet[7]);
+
+ zip_in_place(kernel.packet[0], kernel.packet[1]);
+ zip_in_place(kernel.packet[2], kernel.packet[3]);
+ zip_in_place(kernel.packet[4], kernel.packet[5]);
+ zip_in_place(kernel.packet[6], kernel.packet[7]);
+}
+
+template<typename Packet>
+EIGEN_ALWAYS_INLINE void ptranspose_impl(PacketBlock<Packet, 16>& kernel) { // four zip rounds: stride 8, 4, 2, 1
+ EIGEN_UNROLL_LOOP
+ for (int i=0; i<4; ++i) {
+ const int m = (1 << i);
+ EIGEN_UNROLL_LOOP
+ for (int j=0; j<m; ++j) {
+ const int n = (1 << (3-i));
+ EIGEN_UNROLL_LOOP
+ for (int k=0; k<n; ++k) {
+ const int idx = 2*j*n+k;
+ zip_in_place(kernel.packet[idx], kernel.packet[idx + n]);
+ }
+ }
+ }
+}
+
+} // namespace detail
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet2f, 2>& kernel) { // overloads here delegate to detail::ptranspose_impl except Packet4c/Packet4uc
+ detail::ptranspose_impl(kernel);
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet4f, 4>& kernel) {
+ detail::ptranspose_impl(kernel);
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet4c, 4>& kernel) // rows are 32-bit scalars: pack {p0,p2} and {p1,p3} into vectors, then zip 8- and 16-bit lanes
+{
+ const int8x8_t a = vreinterpret_s8_s32(vset_lane_s32(kernel.packet[2], vdup_n_s32(kernel.packet[0]), 1));
+ const int8x8_t b = vreinterpret_s8_s32(vset_lane_s32(kernel.packet[3], vdup_n_s32(kernel.packet[1]), 1));
+
+ const int8x8x2_t zip8 = vzip_s8(a,b);
+ const int16x4x2_t zip16 = vzip_s16(vreinterpret_s16_s8(zip8.val[0]), vreinterpret_s16_s8(zip8.val[1]));
+
+ kernel.packet[0] = vget_lane_s32(vreinterpret_s32_s16(zip16.val[0]), 0);
+ kernel.packet[1] = vget_lane_s32(vreinterpret_s32_s16(zip16.val[0]), 1);
+ kernel.packet[2] = vget_lane_s32(vreinterpret_s32_s16(zip16.val[1]), 0);
+ kernel.packet[3] = vget_lane_s32(vreinterpret_s32_s16(zip16.val[1]), 1);
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet8c, 8>& kernel) {
+ detail::ptranspose_impl(kernel);
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet8c, 4>& kernel) {
+ detail::ptranspose_impl(kernel);
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16c, 16>& kernel) {
+ detail::ptranspose_impl(kernel);
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16c, 8>& kernel) {
+ detail::ptranspose_impl(kernel);
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16c, 4>& kernel) {
+ detail::ptranspose_impl(kernel);
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet4uc, 4>& kernel) // rows are 32-bit scalars: pack {p0,p2} and {p1,p3} into vectors, then zip 8- and 16-bit lanes
+{
+ const uint8x8_t a = vreinterpret_u8_u32(vset_lane_u32(kernel.packet[2], vdup_n_u32(kernel.packet[0]), 1));
+ const uint8x8_t b = vreinterpret_u8_u32(vset_lane_u32(kernel.packet[3], vdup_n_u32(kernel.packet[1]), 1));
+
+ const uint8x8x2_t zip8 = vzip_u8(a,b);
+ const uint16x4x2_t zip16 = vzip_u16(vreinterpret_u16_u8(zip8.val[0]), vreinterpret_u16_u8(zip8.val[1]));
+
+ kernel.packet[0] = vget_lane_u32(vreinterpret_u32_u16(zip16.val[0]), 0);
+ kernel.packet[1] = vget_lane_u32(vreinterpret_u32_u16(zip16.val[0]), 1);
+ kernel.packet[2] = vget_lane_u32(vreinterpret_u32_u16(zip16.val[1]), 0);
+ kernel.packet[3] = vget_lane_u32(vreinterpret_u32_u16(zip16.val[1]), 1);
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet8uc, 8>& kernel) {
+ detail::ptranspose_impl(kernel);
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet8uc, 4>& kernel) {
+ detail::ptranspose_impl(kernel);
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16uc, 16>& kernel) {
+ detail::ptranspose_impl(kernel);
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16uc, 8>& kernel) {
+ detail::ptranspose_impl(kernel);
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16uc, 4>& kernel) {
+ detail::ptranspose_impl(kernel);
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet4s, 4>& kernel) { // every overload here delegates to detail::ptranspose_impl
+ detail::ptranspose_impl(kernel);
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet8s, 8>& kernel) {
+ detail::ptranspose_impl(kernel);
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet8s, 4>& kernel) {
+ detail::ptranspose_impl(kernel);
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet4us, 4>& kernel) {
+ detail::ptranspose_impl(kernel);
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet8us, 8>& kernel) {
+ detail::ptranspose_impl(kernel);
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet8us, 4>& kernel) {
+ detail::ptranspose_impl(kernel);
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet2i, 2>& kernel) {
+ detail::ptranspose_impl(kernel);
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet4i, 4>& kernel) {
+ detail::ptranspose_impl(kernel);
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet2ui, 2>& kernel) {
+ detail::ptranspose_impl(kernel); // same single-zip path as the Packet2f/Packet2i overloads
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet4ui, 4>& kernel) {
+ detail::ptranspose_impl(kernel);
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void
+ptranspose(PacketBlock<Packet2l, 2>& kernel) // 2x2 transpose of 64-bit lanes: swap the off-diagonal elements
+{
+#if EIGEN_ARCH_ARM64
+ const int64x2_t tmp1 = vzip1q_s64(kernel.packet[0], kernel.packet[1]); // {p0[0], p1[0]}
+ kernel.packet[1] = vzip2q_s64(kernel.packet[0], kernel.packet[1]); // {p0[1], p1[1]}
+ kernel.packet[0] = tmp1;
+#else
+ const int64x1_t tmp[2][2] = { // split each row into its 64-bit halves, then recombine column-wise
+ { vget_low_s64(kernel.packet[0]), vget_high_s64(kernel.packet[0]) },
+ { vget_low_s64(kernel.packet[1]), vget_high_s64(kernel.packet[1]) }
+ };
+
+ kernel.packet[0] = vcombine_s64(tmp[0][0], tmp[1][0]);
+ kernel.packet[1] = vcombine_s64(tmp[0][1], tmp[1][1]);
+#endif
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void
+ptranspose(PacketBlock<Packet2ul, 2>& kernel) // 2x2 transpose of 64-bit lanes: swap the off-diagonal elements
+{
+#if EIGEN_ARCH_ARM64
+ const uint64x2_t tmp1 = vzip1q_u64(kernel.packet[0], kernel.packet[1]); // {p0[0], p1[0]}
+ kernel.packet[1] = vzip2q_u64(kernel.packet[0], kernel.packet[1]); // {p0[1], p1[1]}
+ kernel.packet[0] = tmp1;
+#else
+ const uint64x1_t tmp[2][2] = { // split each row into its 64-bit halves, then recombine column-wise
+ { vget_low_u64(kernel.packet[0]), vget_high_u64(kernel.packet[0]) },
+ { vget_low_u64(kernel.packet[1]), vget_high_u64(kernel.packet[1]) }
+ };
+
+ kernel.packet[0] = vcombine_u64(tmp[0][0], tmp[1][0]);
+ kernel.packet[1] = vcombine_u64(tmp[0][1], tmp[1][1]);
+#endif
+}
+
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2f pselect( const Packet2f& mask, const Packet2f& a, const Packet2f& b) // bitwise select: bits of a where mask bits are set, else bits of b
+{ return vbsl_f32(vreinterpret_u32_f32(mask), a, b); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4f pselect(const Packet4f& mask, const Packet4f& a, const Packet4f& b)
+{ return vbslq_f32(vreinterpretq_u32_f32(mask), a, b); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet8c pselect(const Packet8c& mask, const Packet8c& a, const Packet8c& b) // signed/float masks are reinterpreted as unsigned for vbsl
+{ return vbsl_s8(vreinterpret_u8_s8(mask), a, b); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet16c pselect(const Packet16c& mask, const Packet16c& a, const Packet16c& b)
+{ return vbslq_s8(vreinterpretq_u8_s8(mask), a, b); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet8uc pselect(const Packet8uc& mask, const Packet8uc& a, const Packet8uc& b) // unsigned masks feed vbsl directly
+{ return vbsl_u8(mask, a, b); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet16uc pselect(const Packet16uc& mask, const Packet16uc& a, const Packet16uc& b)
+{ return vbslq_u8(mask, a, b); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4s pselect(const Packet4s& mask, const Packet4s& a, const Packet4s& b)
+{ return vbsl_s16(vreinterpret_u16_s16(mask), a, b); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet8s pselect(const Packet8s& mask, const Packet8s& a, const Packet8s& b)
+{ return vbslq_s16(vreinterpretq_u16_s16(mask), a, b); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4us pselect(const Packet4us& mask, const Packet4us& a, const Packet4us& b)
+{ return vbsl_u16(mask, a, b); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet8us pselect(const Packet8us& mask, const Packet8us& a, const Packet8us& b)
+{ return vbslq_u16(mask, a, b); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2i pselect(const Packet2i& mask, const Packet2i& a, const Packet2i& b)
+{ return vbsl_s32(vreinterpret_u32_s32(mask), a, b); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4i pselect(const Packet4i& mask, const Packet4i& a, const Packet4i& b)
+{ return vbslq_s32(vreinterpretq_u32_s32(mask), a, b); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2ui pselect(const Packet2ui& mask, const Packet2ui& a, const Packet2ui& b)
+{ return vbsl_u32(mask, a, b); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4ui pselect(const Packet4ui& mask, const Packet4ui& a, const Packet4ui& b)
+{ return vbslq_u32(mask, a, b); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2l pselect(const Packet2l& mask, const Packet2l& a, const Packet2l& b)
+{ return vbslq_s64(vreinterpretq_u64_s64(mask), a, b); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2ul pselect(const Packet2ul& mask, const Packet2ul& a, const Packet2ul& b)
+{ return vbslq_u64(mask, a, b); }
+
+// Use armv8 rounding intrinsics if available.
+#if EIGEN_ARCH_ARMV8
+template<> EIGEN_STRONG_INLINE Packet2f print<Packet2f>(const Packet2f& a)
+{ return vrndn_f32(a); } // round to nearest, ties to even
+
+template<> EIGEN_STRONG_INLINE Packet4f print<Packet4f>(const Packet4f& a)
+{ return vrndnq_f32(a); } // round to nearest, ties to even
+
+template<> EIGEN_STRONG_INLINE Packet2f pfloor<Packet2f>(const Packet2f& a)
+{ return vrndm_f32(a); }
+
+template<> EIGEN_STRONG_INLINE Packet4f pfloor<Packet4f>(const Packet4f& a)
+{ return vrndmq_f32(a); }
+
+template<> EIGEN_STRONG_INLINE Packet2f pceil<Packet2f>(const Packet2f& a)
+{ return vrndp_f32(a); }
+
+template<> EIGEN_STRONG_INLINE Packet4f pceil<Packet4f>(const Packet4f& a)
+{ return vrndpq_f32(a); }
+
+#else
+
+template<> EIGEN_STRONG_INLINE Packet4f print(const Packet4f& a) {
+ // Add and subtract 2^23 on |a| so the FPU rounds away the fraction, then restore the sign below.
+ const Packet4f limit = pset1<Packet4f>(static_cast<float>(1<<23));
+ const Packet4f abs_a = pabs(a);
+ Packet4f r = padd(abs_a, limit);
+ // Don't compile-away addition and subtraction.
+ EIGEN_OPTIMIZATION_BARRIER(r);
+ r = psub(r, limit);
+ // If |a| >= 2^23 (already integral) or NaN, return a unchanged. Otherwise, account for sign.
+ r = pselect(pcmp_lt(abs_a, limit),
+ pselect(pcmp_lt(a, pzero(a)), pnegate(r), r), a);
+ return r;
+}
+
+template<> EIGEN_STRONG_INLINE Packet2f print(const Packet2f& a) {
+ // Add and subtract 2^23 on |a| so the FPU rounds away the fraction, then restore the sign below.
+ const Packet2f limit = pset1<Packet2f>(static_cast<float>(1<<23));
+ const Packet2f abs_a = pabs(a);
+ Packet2f r = padd(abs_a, limit);
+ // Don't compile-away addition and subtraction.
+ EIGEN_OPTIMIZATION_BARRIER(r);
+ r = psub(r, limit);
+ // If |a| >= 2^23 (already integral) or NaN, return a unchanged. Otherwise, account for sign.
+ r = pselect(pcmp_lt(abs_a, limit),
+ pselect(pcmp_lt(a, pzero(a)), pnegate(r), r), a);
+ return r;
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f pfloor<Packet4f>(const Packet4f& a)
+{
+ const Packet4f cst_1 = pset1<Packet4f>(1.0f);
+ Packet4f tmp = print<Packet4f>(a);
+ // If rounding went above a, subtract one.
+ Packet4f mask = pcmp_lt(a, tmp);
+ mask = pand(mask, cst_1);
+ return psub(tmp, mask);
+}
+
+template<> EIGEN_STRONG_INLINE Packet2f pfloor<Packet2f>(const Packet2f& a)
+{
+ const Packet2f cst_1 = pset1<Packet2f>(1.0f);
+ Packet2f tmp = print<Packet2f>(a);
+ // If rounding went above a, subtract one.
+ Packet2f mask = pcmp_lt(a, tmp);
+ mask = pand(mask, cst_1);
+ return psub(tmp, mask);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f pceil<Packet4f>(const Packet4f& a)
+{
+ const Packet4f cst_1 = pset1<Packet4f>(1.0f);
+ Packet4f tmp = print<Packet4f>(a);
+ // If rounding went below a, add one.
+ Packet4f mask = pcmp_lt(tmp, a);
+ mask = pand(mask, cst_1);
+ return padd(tmp, mask);
+}
+
+template<> EIGEN_STRONG_INLINE Packet2f pceil<Packet2f>(const Packet2f& a)
+{
+ const Packet2f cst_1 = pset1<Packet2f>(1.0f);
+ Packet2f tmp = print<Packet2f>(a);
+ // If rounding went below a, add one.
+ Packet2f mask = pcmp_lt(tmp, a);
+ mask = pand(mask, cst_1);
+ return padd(tmp, mask);
+}
+
+#endif
+
+/**
+ * Computes the integer square root, i.e. floor(sqrt(a)), of each lane.
+ * @remarks The calculation is performed using an algorithm which iterates through each binary digit of the result
+ * and tests whether setting that digit to 1 would cause the square of the value to be greater than the argument
+ * value. The algorithm is described in detail here: http://ww1.microchip.com/downloads/en/AppNotes/91040a.pdf .
+ */
+template<> EIGEN_STRONG_INLINE Packet4uc psqrt(const Packet4uc& a) { // a is a 32-bit scalar holding the lanes; work on a broadcast vector
+ uint8x8_t x = vreinterpret_u8_u32(vdup_n_u32(a));
+ uint8x8_t res = vdup_n_u8(0);
+ uint8x8_t add = vdup_n_u8(0x8); // top result bit: sqrt(255) < 16
+ for (int i = 0; i < 4; i++)
+ {
+ const uint8x8_t temp = vorr_u8(res, add);
+ res = vbsl_u8(vcge_u8(x, vmul_u8(temp, temp)), temp, res); // keep the bit if temp*temp <= x (temp <= 15, so no overflow)
+ add = vshr_n_u8(add, 1);
+ }
+ return vget_lane_u32(vreinterpret_u32_u8(res), 0);
+}
+/// @copydoc Eigen::internal::psqrt(const Packet4uc& a)
+template<> EIGEN_STRONG_INLINE Packet8uc psqrt(const Packet8uc& a) {
+ uint8x8_t res = vdup_n_u8(0);
+ uint8x8_t add = vdup_n_u8(0x8); // top result bit: sqrt(255) < 16
+ for (int i = 0; i < 4; i++)
+ {
+ const uint8x8_t temp = vorr_u8(res, add);
+ res = vbsl_u8(vcge_u8(a, vmul_u8(temp, temp)), temp, res);
+ add = vshr_n_u8(add, 1);
+ }
+ return res;
+}
+/// @copydoc Eigen::internal::psqrt(const Packet4uc& a)
+template<> EIGEN_STRONG_INLINE Packet16uc psqrt(const Packet16uc& a) {
+ uint8x16_t res = vdupq_n_u8(0);
+ uint8x16_t add = vdupq_n_u8(0x8); // top result bit: sqrt(255) < 16
+ for (int i = 0; i < 4; i++)
+ {
+ const uint8x16_t temp = vorrq_u8(res, add);
+ res = vbslq_u8(vcgeq_u8(a, vmulq_u8(temp, temp)), temp, res);
+ add = vshrq_n_u8(add, 1);
+ }
+ return res;
+}
+/// @copydoc Eigen::internal::psqrt(const Packet4uc& a)
+template<> EIGEN_STRONG_INLINE Packet4us psqrt(const Packet4us& a) {
+ uint16x4_t res = vdup_n_u16(0);
+ uint16x4_t add = vdup_n_u16(0x80); // top result bit: sqrt(65535) < 256
+ for (int i = 0; i < 8; i++)
+ {
+ const uint16x4_t temp = vorr_u16(res, add);
+ res = vbsl_u16(vcge_u16(a, vmul_u16(temp, temp)), temp, res);
+ add = vshr_n_u16(add, 1);
+ }
+ return res;
+}
+/// @copydoc Eigen::internal::psqrt(const Packet4uc& a)
+template<> EIGEN_STRONG_INLINE Packet8us psqrt(const Packet8us& a) {
+ uint16x8_t res = vdupq_n_u16(0);
+ uint16x8_t add = vdupq_n_u16(0x80); // top result bit: sqrt(65535) < 256
+ for (int i = 0; i < 8; i++)
+ {
+ const uint16x8_t temp = vorrq_u16(res, add);
+ res = vbslq_u16(vcgeq_u16(a, vmulq_u16(temp, temp)), temp, res);
+ add = vshrq_n_u16(add, 1);
+ }
+ return res;
+}
+/// @copydoc Eigen::internal::psqrt(const Packet4uc& a)
+template<> EIGEN_STRONG_INLINE Packet2ui psqrt(const Packet2ui& a) {
+ uint32x2_t res = vdup_n_u32(0);
+ uint32x2_t add = vdup_n_u32(0x8000); // top result bit: sqrt(2^32 - 1) < 65536
+ for (int i = 0; i < 16; i++)
+ {
+ const uint32x2_t temp = vorr_u32(res, add);
+ res = vbsl_u32(vcge_u32(a, vmul_u32(temp, temp)), temp, res);
+ add = vshr_n_u32(add, 1);
+ }
+ return res;
+}
+/// @copydoc Eigen::internal::psqrt(const Packet4uc& a)
+template<> EIGEN_STRONG_INLINE Packet4ui psqrt(const Packet4ui& a) {
+ uint32x4_t res = vdupq_n_u32(0);
+ uint32x4_t add = vdupq_n_u32(0x8000); // top result bit: sqrt(2^32 - 1) < 65536
+ for (int i = 0; i < 16; i++)
+ {
+ const uint32x4_t temp = vorrq_u32(res, add);
+ res = vbslq_u32(vcgeq_u32(a, vmulq_u32(temp, temp)), temp, res);
+ add = vshrq_n_u32(add, 1);
+ }
+ return res;
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f prsqrt(const Packet4f& a) { // 1/sqrt(a); maps 0 -> +inf and +inf -> 0
+ // Compute approximate reciprocal sqrt.
+ Packet4f x = vrsqrteq_f32(a);
+ // Refine the estimate of 1/sqrt(a) with two Newton-Raphson iterations.
+ x = vmulq_f32(vrsqrtsq_f32(vmulq_f32(a, x), x), x);
+ x = vmulq_f32(vrsqrtsq_f32(vmulq_f32(a, x), x), x);
+ const Packet4f infinity = pset1<Packet4f>(NumTraits<float>::infinity()); // the Newton steps compute 0*inf = NaN for a == 0 and for a == +inf
+ return pselect(pcmp_eq(a, pzero(a)), infinity, pselect(pcmp_eq(a, infinity), pzero(a), x));
+}
+
+template<> EIGEN_STRONG_INLINE Packet2f prsqrt(const Packet2f& a) { // 1/sqrt(a); maps 0 -> +inf and +inf -> 0
+ // Compute approximate reciprocal sqrt.
+ Packet2f x = vrsqrte_f32(a);
+ // Refine the estimate of 1/sqrt(a) with two Newton-Raphson iterations.
+ x = vmul_f32(vrsqrts_f32(vmul_f32(a, x), x), x);
+ x = vmul_f32(vrsqrts_f32(vmul_f32(a, x), x), x);
+ const Packet2f infinity = pset1<Packet2f>(NumTraits<float>::infinity()); // the Newton steps compute 0*inf = NaN for a == 0 and for a == +inf
+ return pselect(pcmp_eq(a, pzero(a)), infinity, pselect(pcmp_eq(a, infinity), pzero(a), x));
+}
+
+// Unfortunately vsqrt_f32 is only available for A64.
+#if EIGEN_ARCH_ARM64
+template<> EIGEN_STRONG_INLINE Packet4f psqrt(const Packet4f& _x){return vsqrtq_f32(_x);}
+template<> EIGEN_STRONG_INLINE Packet2f psqrt(const Packet2f& _x){return vsqrt_f32(_x); }
+#else
+template<> EIGEN_STRONG_INLINE Packet4f psqrt(const Packet4f& a) { // sqrt(a) = a * rsqrt(a), except at 0 and +inf
+ const Packet4f infinity = pset1<Packet4f>(NumTraits<float>::infinity());
+ const Packet4f is_zero_or_inf = por(pcmp_eq(a, pzero(a)), pcmp_eq(a, infinity)); // a*rsqrt(a) is 0*inf = NaN there, so pass a through
+ return pselect(is_zero_or_inf, a, pmul(a, prsqrt(a)));
+}
+template<> EIGEN_STRONG_INLINE Packet2f psqrt(const Packet2f& a) { // sqrt(a) = a * rsqrt(a), except at 0 and +inf
+ const Packet2f infinity = pset1<Packet2f>(NumTraits<float>::infinity());
+ const Packet2f is_zero_or_inf = por(pcmp_eq(a, pzero(a)), pcmp_eq(a, infinity)); // a*rsqrt(a) is 0*inf = NaN there, so pass a through
+ return pselect(is_zero_or_inf, a, pmul(a, prsqrt(a)));
+}
+#endif
+
+//---------- bfloat16 ----------
+// TODO: Add support for native armv8.6-a bfloat16_t
+
+// TODO: Guard if we have native bfloat16 support
+typedef eigen_packet_wrapper<uint16x4_t, 19> Packet4bf; // four bfloat16 values stored as raw 16-bit patterns in a uint16x4_t
+
+template<> struct is_arithmetic<Packet4bf> { enum { value = true }; };
+
+template<> struct packet_traits<bfloat16> : default_packet_traits
+{
+ typedef Packet4bf type;
+ typedef Packet4bf half;
+ enum
+ {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = 4,
+ HasHalfPacket = 0,
+
+ HasCmp = 1,
+ HasAdd = 1,
+ HasSub = 1,
+ HasShift = 1,
+ HasMul = 1,
+ HasNegate = 1,
+ HasAbs = 1,
+ HasArg = 0,
+ HasAbs2 = 1,
+ HasAbsDiff = 1,
+ HasMin = 1,
+ HasMax = 1,
+ HasConj = 1,
+ HasSetLinear = 0,
+ HasBlend = 0,
+ HasDiv = 1,
+ HasFloor = 1,
+ HasCeil = 1,
+ HasRint = 1,
+
+ HasSin = EIGEN_FAST_MATH,
+ HasCos = EIGEN_FAST_MATH,
+ HasLog = 1,
+ HasExp = 1,
+ HasSqrt = 0,
+ HasTanh = EIGEN_FAST_MATH,
+ HasErf = EIGEN_FAST_MATH,
+ HasBessel = 0, // Issues with accuracy.
+ HasNdtri = 0
+ };
+};
+
+template<> struct unpacket_traits<Packet4bf>
+{
+ typedef bfloat16 type;
+ typedef Packet4bf half;
+ enum
+ {
+ size = 4,
+ alignment = Aligned16,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false
+ };
+};
+
+namespace detail {
+template<>
+EIGEN_ALWAYS_INLINE void zip_in_place<Packet4bf>(Packet4bf& p1, Packet4bf& p2) { // zips the raw 16-bit lanes, same as Packet4us
+ const uint16x4x2_t tmp = vzip_u16(p1, p2);
+ p1 = tmp.val[0];
+ p2 = tmp.val[1];
+}
+} // namespace detail
+
+EIGEN_STRONG_INLINE Packet4bf F32ToBf16(const Packet4f& p) // float -> bfloat16 with round-to-nearest-even; NaNs become 0x7fc0
+{
+ // See the scalar implementation in BFloat16.h for a comprehensible explanation
+ // of this fast rounding algorithm
+ Packet4ui input = reinterpret_cast<Packet4ui>(p);
+
+ // lsb = (input >> 16) & 1
+ Packet4ui lsb = vandq_u32(vshrq_n_u32(input, 16), vdupq_n_u32(1));
+
+ // rounding_bias = 0x7fff + lsb
+ Packet4ui rounding_bias = vaddq_u32(lsb, vdupq_n_u32(0x7fff));
+
+ // input += rounding_bias
+ input = vaddq_u32(input, rounding_bias);
+
+ // input = input >> 16
+ input = vshrq_n_u32(input, 16);
+
+ // Replace float-nans by bfloat16-nans, that is 0x7fc0
+ const Packet4ui bf16_nan = vdupq_n_u32(0x7fc0);
+ const Packet4ui mask = vceqq_f32(p, p); // all-ones for non-NaN lanes (NaN != NaN)
+ input = vbslq_u32(mask, input, bf16_nan);
+
+ // output = static_cast<uint16_t>(input)
+ return vmovn_u32(input);
+}
+
+EIGEN_STRONG_INLINE Packet4f Bf16ToF32(const Packet4bf& p) // exact: bf16 bits become the upper 16 bits of each float
+{
+ return reinterpret_cast<Packet4f>(vshlq_n_u32(vmovl_u16(p), 16));
+}
+
+EIGEN_STRONG_INLINE Packet4bf F32MaskToBf16Mask(const Packet4f& p) { // narrow each 32-bit mask lane to 16 bits (all-ones stays all-ones)
+ return vmovn_u32(vreinterpretq_u32_f32(p));
+}
+
+template<> EIGEN_STRONG_INLINE Packet4bf pset1<Packet4bf>(const bfloat16& from) { // broadcast the raw 16-bit pattern
+ return pset1<Packet4us>(from.value);
+}
+
+template<> EIGEN_STRONG_INLINE bfloat16 pfirst<Packet4bf>(const Packet4bf& from) {
+ return bfloat16_impl::raw_uint16_to_bfloat16(static_cast<uint16_t>(pfirst<Packet4us>(from)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet4bf pload<Packet4bf>(const bfloat16* from) // loads/stores move raw bits through the Packet4us path
+{
+ return pload<Packet4us>(reinterpret_cast<const uint16_t*>(from));
+}
+
+template<> EIGEN_STRONG_INLINE Packet4bf ploadu<Packet4bf>(const bfloat16* from)
+{
+ return ploadu<Packet4us>(reinterpret_cast<const uint16_t*>(from));
+}
+
+template<> EIGEN_STRONG_INLINE void pstore<bfloat16>(bfloat16* to, const Packet4bf& from)
+{
+ EIGEN_DEBUG_ALIGNED_STORE vst1_u16(reinterpret_cast<uint16_t*>(to), from);
+}
+
+template<> EIGEN_STRONG_INLINE void pstoreu<bfloat16>(bfloat16* to, const Packet4bf& from)
+{
+ EIGEN_DEBUG_UNALIGNED_STORE vst1_u16(reinterpret_cast<uint16_t*>(to), from);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4bf ploaddup<Packet4bf>(const bfloat16* from)
+{
+ return ploaddup<Packet4us>(reinterpret_cast<const uint16_t*>(from));
+}
+
+template <> EIGEN_STRONG_INLINE Packet4bf pabs(const Packet4bf& a) { // arithmetic here widens to float, computes, and rounds back via F32ToBf16
+ return F32ToBf16(pabs<Packet4f>(Bf16ToF32(a)));
+}
+
+template <> EIGEN_STRONG_INLINE Packet4bf pmin<PropagateNumbers, Packet4bf>(const Packet4bf &a,
+ const Packet4bf &b)
+{
+ return F32ToBf16(pmin<PropagateNumbers, Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+template <> EIGEN_STRONG_INLINE Packet4bf pmin<PropagateNaN, Packet4bf>(const Packet4bf &a,
+ const Packet4bf &b)
+{
+ return F32ToBf16(pmin<PropagateNaN, Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template <> EIGEN_STRONG_INLINE Packet4bf pmin<Packet4bf>(const Packet4bf &a,
+ const Packet4bf &b)
+{
+ return F32ToBf16(pmin<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template <> EIGEN_STRONG_INLINE Packet4bf pmax<PropagateNumbers, Packet4bf>(const Packet4bf &a,
+ const Packet4bf &b)
+{
+ return F32ToBf16(pmax<PropagateNumbers, Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+template <> EIGEN_STRONG_INLINE Packet4bf pmax<PropagateNaN, Packet4bf>(const Packet4bf &a,
+ const Packet4bf &b)
+{
+ return F32ToBf16(pmax<PropagateNaN, Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template <> EIGEN_STRONG_INLINE Packet4bf pmax<Packet4bf>(const Packet4bf &a,
+ const Packet4bf &b)
+{
+ return F32ToBf16(pmax<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet4bf plset<Packet4bf>(const bfloat16& a)
+{
+ return F32ToBf16(plset<Packet4f>(static_cast<float>(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet4bf por(const Packet4bf& a,const Packet4bf& b) { // bitwise ops act directly on the raw 16-bit patterns
+ return por<Packet4us>(a, b);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4bf pxor(const Packet4bf& a,const Packet4bf& b) {
+ return pxor<Packet4us>(a, b);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4bf pand(const Packet4bf& a,const Packet4bf& b) {
+ return pand<Packet4us>(a, b);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4bf pandnot(const Packet4bf& a,const Packet4bf& b) {
+ return pandnot<Packet4us>(a, b);
+}
+
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4bf pselect(const Packet4bf& mask, const Packet4bf& a,
+ const Packet4bf& b)
+{
+ return pselect<Packet4us>(mask, a, b);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4bf print<Packet4bf>(const Packet4bf& a)
+{
+ return F32ToBf16(print<Packet4f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet4bf pfloor<Packet4bf>(const Packet4bf& a)
+{
+ return F32ToBf16(pfloor<Packet4f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet4bf pceil<Packet4bf>(const Packet4bf& a)
+{
+ return F32ToBf16(pceil<Packet4f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet4bf pconj(const Packet4bf& a) { return a; }
+
+template<> EIGEN_STRONG_INLINE Packet4bf padd<Packet4bf>(const Packet4bf& a, const Packet4bf& b) {
+ return F32ToBf16(padd<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet4bf psub<Packet4bf>(const Packet4bf& a, const Packet4bf& b) {
+ return F32ToBf16(psub<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet4bf pmul<Packet4bf>(const Packet4bf& a, const Packet4bf& b) {
+ return F32ToBf16(pmul<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet4bf pdiv<Packet4bf>(const Packet4bf& a, const Packet4bf& b) {
+ return F32ToBf16(pdiv<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template<>
+EIGEN_STRONG_INLINE Packet4bf pgather<bfloat16, Packet4bf>(const bfloat16* from, Index stride)
+{
+ return pgather<uint16_t, Packet4us>(reinterpret_cast<const uint16_t*>(from), stride);
+}
+
+template<>
+EIGEN_STRONG_INLINE void pscatter<bfloat16, Packet4bf>(bfloat16* to, const Packet4bf& from, Index stride)
+{
+ pscatter<uint16_t, Packet4us>(reinterpret_cast<uint16_t*>(to), from, stride);
+}
+
+template<> EIGEN_STRONG_INLINE bfloat16 predux<Packet4bf>(const Packet4bf& a) // reductions run in float and convert the scalar result back
+{
+ return static_cast<bfloat16>(predux<Packet4f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE bfloat16 predux_max<Packet4bf>(const Packet4bf& a)
+{
+ return static_cast<bfloat16>(predux_max<Packet4f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE bfloat16 predux_min<Packet4bf>(const Packet4bf& a)
+{
+ return static_cast<bfloat16>(predux_min<Packet4f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE bfloat16 predux_mul<Packet4bf>(const Packet4bf& a)
+{
+ return static_cast<bfloat16>(predux_mul<Packet4f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet4bf preverse<Packet4bf>(const Packet4bf& a)
+{
+ return preverse<Packet4us>(a);
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet4bf, 4>& kernel)
+{
+ detail::ptranspose_impl(kernel);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4bf pabsdiff<Packet4bf>(const Packet4bf& a, const Packet4bf& b)
+{
+ return F32ToBf16(pabsdiff<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet4bf pcmp_eq<Packet4bf>(const Packet4bf& a, const Packet4bf& b) // comparisons run in float; 32-bit masks are narrowed to 16 bits
+{
+ return F32MaskToBf16Mask(pcmp_eq<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet4bf pcmp_lt<Packet4bf>(const Packet4bf& a, const Packet4bf& b)
+{
+ return F32MaskToBf16Mask(pcmp_lt<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet4bf pcmp_lt_or_nan<Packet4bf>(const Packet4bf& a, const Packet4bf& b)
+{
+ return F32MaskToBf16Mask(pcmp_lt_or_nan<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet4bf pcmp_le<Packet4bf>(const Packet4bf& a, const Packet4bf& b)
+{
+ return F32MaskToBf16Mask(pcmp_le<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet4bf pnegate<Packet4bf>(const Packet4bf& a) // flip the sign bit (0x8000) of each lane
+{
+ return pxor<Packet4us>(a, pset1<Packet4us>(static_cast<uint16_t>(0x8000)));
 }
//---------- double ----------
@@ -571,55 +3642,115 @@
// Defining these functions as templates ensures that if these intrinsics are
// already defined in arm_neon.h, then our workaround doesn't cause a conflict
// and has lower priority in overload resolution.
-template <typename T>
-uint64x2_t vreinterpretq_u64_f64(T a)
-{
- return (uint64x2_t) a;
-}
+template <typename T> uint64x2_t vreinterpretq_u64_f64(T a) { return (uint64x2_t) a; }
-template <typename T>
-float64x2_t vreinterpretq_f64_u64(T a)
-{
- return (float64x2_t) a;
-}
+template <typename T> float64x2_t vreinterpretq_f64_u64(T a) { return (float64x2_t) a; }
typedef float64x2_t Packet2d;
typedef float64x1_t Packet1d;
+// Functionally equivalent to _mm_shuffle_pd in SSE (i.e. shuffle(m, n, mask) equals _mm_shuffle_pd(m, n, mask)).
+// Currently used in LU/arch/InverseSize4.h to enable a shared implementation
+// for fast inversion of matrices of size 4.
+EIGEN_STRONG_INLINE Packet2d shuffle(const Packet2d& m, const Packet2d& n, int mask)
+{
+ const double* a = reinterpret_cast<const double*>(&m);
+ const double* b = reinterpret_cast<const double*>(&n);
+ Packet2d res = {*(a + (mask & 1)), *(b + ((mask >> 1) & 1))};
+ return res;
+}
+
+EIGEN_STRONG_INLINE Packet2d vec2d_swizzle2(const Packet2d& a, const Packet2d& b, int mask)
+{
+ return shuffle(a, b, mask);
+}
+EIGEN_STRONG_INLINE Packet2d vec2d_unpacklo(const Packet2d& a,const Packet2d& b)
+{
+ return shuffle(a, b, 0);
+}
+EIGEN_STRONG_INLINE Packet2d vec2d_unpackhi(const Packet2d& a,const Packet2d& b)
+{
+ return shuffle(a, b, 3);
+}
+#define vec2d_duplane(a, p) \
+ vdupq_laneq_f64(a, p)
+
template<> struct packet_traits<double> : default_packet_traits
{
typedef Packet2d type;
typedef Packet2d half;
- enum {
+ enum
+ {
Vectorizable = 1,
AlignedOnScalar = 1,
size = 2,
- HasHalfPacket=0,
-
- HasDiv = 1,
- // FIXME check the Has*
+ HasHalfPacket = 0,
+
+ HasCmp = 1,
+ HasAdd = 1,
+ HasSub = 1,
+ HasShift = 1,
+ HasMul = 1,
+ HasNegate = 1,
+ HasAbs = 1,
+ HasArg = 0,
+ HasAbs2 = 1,
+ HasAbsDiff = 1,
+ HasMin = 1,
+ HasMax = 1,
+ HasConj = 1,
+ HasSetLinear = 0,
+ HasBlend = 0,
+
+ HasDiv = 1,
+ HasFloor = 1,
+ HasCeil = 1,
+ HasRint = 1,
+
HasSin = 0,
HasCos = 0,
- HasLog = 0,
- HasExp = 0,
- HasSqrt = 0
+ HasLog = 1,
+ HasExp = 1,
+ HasSqrt = 1,
+ HasRsqrt = 1,
+ HasTanh = 0,
+ HasErf = 0
};
};
-template<> struct unpacket_traits<Packet2d> { typedef double type; enum {size=2, alignment=Aligned16}; typedef Packet2d half; };
+template<> struct unpacket_traits<Packet2d>
+{
+ typedef double type;
+ typedef Packet2d half;
+ typedef Packet2l integer_packet;
+ enum
+ {
+ size = 2,
+ alignment = Aligned16,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false
+ };
+};
template<> EIGEN_STRONG_INLINE Packet2d pset1<Packet2d>(const double& from) { return vdupq_n_f64(from); }
template<> EIGEN_STRONG_INLINE Packet2d plset<Packet2d>(const double& a)
{
- const double countdown_raw[] = {0.0,1.0};
- const Packet2d countdown = vld1q_f64(countdown_raw);
- return vaddq_f64(pset1<Packet2d>(a), countdown);
+ const double c[] = {0.0,1.0};
+ return vaddq_f64(pset1<Packet2d>(a), vld1q_f64(c));
}
+
template<> EIGEN_STRONG_INLINE Packet2d padd<Packet2d>(const Packet2d& a, const Packet2d& b) { return vaddq_f64(a,b); }
template<> EIGEN_STRONG_INLINE Packet2d psub<Packet2d>(const Packet2d& a, const Packet2d& b) { return vsubq_f64(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2d pxor<Packet2d>(const Packet2d& , const Packet2d& );
+template<> EIGEN_STRONG_INLINE Packet2d paddsub<Packet2d>(const Packet2d& a, const Packet2d& b){
+ const Packet2d mask = {numext::bit_cast<double>(0x8000000000000000ull),0.0};
+ return padd(a, pxor(mask, b));
+}
+
template<> EIGEN_STRONG_INLINE Packet2d pnegate(const Packet2d& a) { return vnegq_f64(a); }
template<> EIGEN_STRONG_INLINE Packet2d pconj(const Packet2d& a) { return a; }
@@ -630,128 +3761,824 @@
#ifdef __ARM_FEATURE_FMA
// See bug 936. See above comment about FMA for float.
-template<> EIGEN_STRONG_INLINE Packet2d pmadd(const Packet2d& a, const Packet2d& b, const Packet2d& c) { return vfmaq_f64(c,a,b); }
+template<> EIGEN_STRONG_INLINE Packet2d pmadd(const Packet2d& a, const Packet2d& b, const Packet2d& c)
+{ return vfmaq_f64(c,a,b); }
#else
-template<> EIGEN_STRONG_INLINE Packet2d pmadd(const Packet2d& a, const Packet2d& b, const Packet2d& c) { return vmlaq_f64(c,a,b); }
+template<> EIGEN_STRONG_INLINE Packet2d pmadd(const Packet2d& a, const Packet2d& b, const Packet2d& c)
+{ return vmlaq_f64(c,a,b); }
#endif
template<> EIGEN_STRONG_INLINE Packet2d pmin<Packet2d>(const Packet2d& a, const Packet2d& b) { return vminq_f64(a,b); }
+#ifdef __ARM_FEATURE_NUMERIC_MAXMIN
+// Numeric max and min are only available if __ARM_FEATURE_NUMERIC_MAXMIN is defined (which can only be the case for Armv8 systems).
+template<> EIGEN_STRONG_INLINE Packet2d pmin<PropagateNumbers, Packet2d>(const Packet2d& a, const Packet2d& b) { return vminnmq_f64(a, b); }
+template<> EIGEN_STRONG_INLINE Packet2d pmax<PropagateNumbers, Packet2d>(const Packet2d& a, const Packet2d& b) { return vmaxnmq_f64(a, b); }
+
+#endif
+
+template<> EIGEN_STRONG_INLINE Packet2d pmin<PropagateNaN, Packet2d>(const Packet2d& a, const Packet2d& b) { return pmin<Packet2d>(a, b); }
+
template<> EIGEN_STRONG_INLINE Packet2d pmax<Packet2d>(const Packet2d& a, const Packet2d& b) { return vmaxq_f64(a,b); }
+
+template<> EIGEN_STRONG_INLINE Packet2d pmax<PropagateNaN, Packet2d>(const Packet2d& a, const Packet2d& b) { return pmax<Packet2d>(a, b); }
+
// Logical Operations are not supported for float, so we have to reinterpret casts using NEON intrinsics
template<> EIGEN_STRONG_INLINE Packet2d pand<Packet2d>(const Packet2d& a, const Packet2d& b)
-{
- return vreinterpretq_f64_u64(vandq_u64(vreinterpretq_u64_f64(a),vreinterpretq_u64_f64(b)));
-}
+{ return vreinterpretq_f64_u64(vandq_u64(vreinterpretq_u64_f64(a),vreinterpretq_u64_f64(b))); }
template<> EIGEN_STRONG_INLINE Packet2d por<Packet2d>(const Packet2d& a, const Packet2d& b)
-{
- return vreinterpretq_f64_u64(vorrq_u64(vreinterpretq_u64_f64(a),vreinterpretq_u64_f64(b)));
-}
+{ return vreinterpretq_f64_u64(vorrq_u64(vreinterpretq_u64_f64(a),vreinterpretq_u64_f64(b))); }
template<> EIGEN_STRONG_INLINE Packet2d pxor<Packet2d>(const Packet2d& a, const Packet2d& b)
-{
- return vreinterpretq_f64_u64(veorq_u64(vreinterpretq_u64_f64(a),vreinterpretq_u64_f64(b)));
-}
+{ return vreinterpretq_f64_u64(veorq_u64(vreinterpretq_u64_f64(a),vreinterpretq_u64_f64(b))); }
template<> EIGEN_STRONG_INLINE Packet2d pandnot<Packet2d>(const Packet2d& a, const Packet2d& b)
-{
- return vreinterpretq_f64_u64(vbicq_u64(vreinterpretq_u64_f64(a),vreinterpretq_u64_f64(b)));
-}
+{ return vreinterpretq_f64_u64(vbicq_u64(vreinterpretq_u64_f64(a),vreinterpretq_u64_f64(b))); }
-template<> EIGEN_STRONG_INLINE Packet2d pload<Packet2d>(const double* from) { EIGEN_DEBUG_ALIGNED_LOAD return vld1q_f64(from); }
+template<> EIGEN_STRONG_INLINE Packet2d pcmp_le(const Packet2d& a, const Packet2d& b)
+{ return vreinterpretq_f64_u64(vcleq_f64(a,b)); }
-template<> EIGEN_STRONG_INLINE Packet2d ploadu<Packet2d>(const double* from) { EIGEN_DEBUG_UNALIGNED_LOAD return vld1q_f64(from); }
+template<> EIGEN_STRONG_INLINE Packet2d pcmp_lt(const Packet2d& a, const Packet2d& b)
+{ return vreinterpretq_f64_u64(vcltq_f64(a,b)); }
-template<> EIGEN_STRONG_INLINE Packet2d ploaddup<Packet2d>(const double* from)
-{
- return vld1q_dup_f64(from);
-}
-template<> EIGEN_STRONG_INLINE void pstore<double>(double* to, const Packet2d& from) { EIGEN_DEBUG_ALIGNED_STORE vst1q_f64(to, from); }
+template<> EIGEN_STRONG_INLINE Packet2d pcmp_lt_or_nan(const Packet2d& a, const Packet2d& b)
+{ return vreinterpretq_f64_u32(vmvnq_u32(vreinterpretq_u32_u64(vcgeq_f64(a,b)))); }
-template<> EIGEN_STRONG_INLINE void pstoreu<double>(double* to, const Packet2d& from) { EIGEN_DEBUG_UNALIGNED_STORE vst1q_f64(to, from); }
+template<> EIGEN_STRONG_INLINE Packet2d pcmp_eq(const Packet2d& a, const Packet2d& b)
+{ return vreinterpretq_f64_u64(vceqq_f64(a,b)); }
-template<> EIGEN_DEVICE_FUNC inline Packet2d pgather<double, Packet2d>(const double* from, Index stride)
+template<> EIGEN_STRONG_INLINE Packet2d pload<Packet2d>(const double* from)
+{ EIGEN_DEBUG_ALIGNED_LOAD return vld1q_f64(from); }
+
+template<> EIGEN_STRONG_INLINE Packet2d ploadu<Packet2d>(const double* from)
+{ EIGEN_DEBUG_UNALIGNED_LOAD return vld1q_f64(from); }
+
+template<> EIGEN_STRONG_INLINE Packet2d ploaddup<Packet2d>(const double* from) { return vld1q_dup_f64(from); }
+template<> EIGEN_STRONG_INLINE void pstore<double>(double* to, const Packet2d& from)
+{ EIGEN_DEBUG_ALIGNED_STORE vst1q_f64(to,from); }
+
+template<> EIGEN_STRONG_INLINE void pstoreu<double>(double* to, const Packet2d& from)
+{ EIGEN_DEBUG_UNALIGNED_STORE vst1q_f64(to,from); }
+
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2d pgather<double, Packet2d>(const double* from, Index stride)
{
Packet2d res = pset1<Packet2d>(0.0);
- res = vsetq_lane_f64(from[0*stride], res, 0);
- res = vsetq_lane_f64(from[1*stride], res, 1);
+ res = vld1q_lane_f64(from + 0*stride, res, 0);
+ res = vld1q_lane_f64(from + 1*stride, res, 1);
return res;
}
-template<> EIGEN_DEVICE_FUNC inline void pscatter<double, Packet2d>(double* to, const Packet2d& from, Index stride)
+
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter<double, Packet2d>(double* to, const Packet2d& from, Index stride)
{
- to[stride*0] = vgetq_lane_f64(from, 0);
- to[stride*1] = vgetq_lane_f64(from, 1);
+ vst1q_lane_f64(to + stride*0, from, 0);
+ vst1q_lane_f64(to + stride*1, from, 1);
}
+
template<> EIGEN_STRONG_INLINE void prefetch<double>(const double* addr) { EIGEN_ARM_PREFETCH(addr); }
// FIXME only store the 2 first elements ?
-template<> EIGEN_STRONG_INLINE double pfirst<Packet2d>(const Packet2d& a) { return vgetq_lane_f64(a, 0); }
+template<> EIGEN_STRONG_INLINE double pfirst<Packet2d>(const Packet2d& a) { return vgetq_lane_f64(a,0); }
-template<> EIGEN_STRONG_INLINE Packet2d preverse(const Packet2d& a) { return vcombine_f64(vget_high_f64(a), vget_low_f64(a)); }
+template<> EIGEN_STRONG_INLINE Packet2d preverse(const Packet2d& a)
+{ return vcombine_f64(vget_high_f64(a), vget_low_f64(a)); }
template<> EIGEN_STRONG_INLINE Packet2d pabs(const Packet2d& a) { return vabsq_f64(a); }
#if EIGEN_COMP_CLANG && defined(__apple_build_version__)
// workaround ICE, see bug 907
-template<> EIGEN_STRONG_INLINE double predux<Packet2d>(const Packet2d& a) { return (vget_low_f64(a) + vget_high_f64(a))[0]; }
+template<> EIGEN_STRONG_INLINE double predux<Packet2d>(const Packet2d& a)
+{ return (vget_low_f64(a) + vget_high_f64(a))[0]; }
#else
-template<> EIGEN_STRONG_INLINE double predux<Packet2d>(const Packet2d& a) { return vget_lane_f64(vget_low_f64(a) + vget_high_f64(a), 0); }
+template<> EIGEN_STRONG_INLINE double predux<Packet2d>(const Packet2d& a)
+{ return vget_lane_f64(vget_low_f64(a) + vget_high_f64(a), 0); }
#endif
-template<> EIGEN_STRONG_INLINE Packet2d preduxp<Packet2d>(const Packet2d* vecs)
-{
- float64x2_t trn1, trn2;
-
- // NEON zip performs interleaving of the supplied vectors.
- // We perform two interleaves in a row to acquire the transposed vector
- trn1 = vzip1q_f64(vecs[0], vecs[1]);
- trn2 = vzip2q_f64(vecs[0], vecs[1]);
-
- // Do the addition of the resulting vectors
- return vaddq_f64(trn1, trn2);
-}
// Other reduction functions:
// mul
#if EIGEN_COMP_CLANG && defined(__apple_build_version__)
-template<> EIGEN_STRONG_INLINE double predux_mul<Packet2d>(const Packet2d& a) { return (vget_low_f64(a) * vget_high_f64(a))[0]; }
+template<> EIGEN_STRONG_INLINE double predux_mul<Packet2d>(const Packet2d& a)
+{ return (vget_low_f64(a) * vget_high_f64(a))[0]; }
#else
-template<> EIGEN_STRONG_INLINE double predux_mul<Packet2d>(const Packet2d& a) { return vget_lane_f64(vget_low_f64(a) * vget_high_f64(a), 0); }
+template<> EIGEN_STRONG_INLINE double predux_mul<Packet2d>(const Packet2d& a)
+{ return vget_lane_f64(vget_low_f64(a) * vget_high_f64(a), 0); }
#endif
// min
-template<> EIGEN_STRONG_INLINE double predux_min<Packet2d>(const Packet2d& a) { return vgetq_lane_f64(vpminq_f64(a, a), 0); }
+template<> EIGEN_STRONG_INLINE double predux_min<Packet2d>(const Packet2d& a)
+{ return vgetq_lane_f64(vpminq_f64(a,a), 0); }
// max
-template<> EIGEN_STRONG_INLINE double predux_max<Packet2d>(const Packet2d& a) { return vgetq_lane_f64(vpmaxq_f64(a, a), 0); }
+template<> EIGEN_STRONG_INLINE double predux_max<Packet2d>(const Packet2d& a)
+{ return vgetq_lane_f64(vpmaxq_f64(a,a), 0); }
-// this PALIGN_NEON business is to work around a bug in LLVM Clang 3.0 causing incorrect compilation errors,
-// see bug 347 and this LLVM bug: http://llvm.org/bugs/show_bug.cgi?id=11074
-#define PALIGN_NEON(Offset,Type,Command) \
-template<>\
-struct palign_impl<Offset,Type>\
-{\
- EIGEN_STRONG_INLINE static void run(Type& first, const Type& second)\
- {\
- if (Offset!=0)\
- first = Command(first, second, Offset);\
- }\
-};\
-PALIGN_NEON(0,Packet2d,vextq_f64)
-PALIGN_NEON(1,Packet2d,vextq_f64)
-#undef PALIGN_NEON
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void
+ptranspose(PacketBlock<Packet2d, 2>& kernel)
+{
+ const float64x2_t tmp1 = vzip1q_f64(kernel.packet[0], kernel.packet[1]);
+ const float64x2_t tmp2 = vzip2q_f64(kernel.packet[0], kernel.packet[1]);
-EIGEN_DEVICE_FUNC inline void
-ptranspose(PacketBlock<Packet2d,2>& kernel) {
- float64x2_t trn1 = vzip1q_f64(kernel.packet[0], kernel.packet[1]);
- float64x2_t trn2 = vzip2q_f64(kernel.packet[0], kernel.packet[1]);
-
- kernel.packet[0] = trn1;
- kernel.packet[1] = trn2;
+ kernel.packet[0] = tmp1;
+ kernel.packet[1] = tmp2;
}
-#endif // EIGEN_ARCH_ARM64
+
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2d pselect( const Packet2d& mask, const Packet2d& a, const Packet2d& b)
+{ return vbslq_f64(vreinterpretq_u64_f64(mask), a, b); }
+
+template<> EIGEN_STRONG_INLINE Packet2d print<Packet2d>(const Packet2d& a)
+{ return vrndnq_f64(a); }
+
+template<> EIGEN_STRONG_INLINE Packet2d pfloor<Packet2d>(const Packet2d& a)
+{ return vrndmq_f64(a); }
+
+template<> EIGEN_STRONG_INLINE Packet2d pceil<Packet2d>(const Packet2d& a)
+{ return vrndpq_f64(a); }
+
+template<> EIGEN_STRONG_INLINE Packet2d pldexp<Packet2d>(const Packet2d& a, const Packet2d& exponent)
+{ return pldexp_generic(a, exponent); }
+
+template<> EIGEN_STRONG_INLINE Packet2d pfrexp<Packet2d>(const Packet2d& a, Packet2d& exponent)
+{ return pfrexp_generic(a,exponent); }
+
+template<> EIGEN_STRONG_INLINE Packet2d pset1frombits<Packet2d>(uint64_t from)
+{ return vreinterpretq_f64_u64(vdupq_n_u64(from)); }
+
+template<> EIGEN_STRONG_INLINE Packet2d prsqrt(const Packet2d& a) {
+ // Compute approximate reciprocal sqrt.
+ Packet2d x = vrsqrteq_f64(a);
+ // Do Newton iterations for 1/sqrt(x).
+ x = vmulq_f64(vrsqrtsq_f64(vmulq_f64(a, x), x), x);
+ x = vmulq_f64(vrsqrtsq_f64(vmulq_f64(a, x), x), x);
+ x = vmulq_f64(vrsqrtsq_f64(vmulq_f64(a, x), x), x);
+ const Packet2d infinity = pset1<Packet2d>(NumTraits<double>::infinity());
+ return pselect(pcmp_eq(a, pzero(a)), infinity, x);
+}
+
+template<> EIGEN_STRONG_INLINE Packet2d psqrt(const Packet2d& _x){ return vsqrtq_f64(_x); }
+
+#endif // EIGEN_ARCH_ARM64
+
+// Do we have fp16 types and supporting NEON intrinsics?
+#if EIGEN_HAS_ARM64_FP16_VECTOR_ARITHMETIC
+typedef float16x4_t Packet4hf;
+typedef float16x8_t Packet8hf;
+
+template <>
+struct packet_traits<Eigen::half> : default_packet_traits {
+ typedef Packet8hf type;
+ typedef Packet4hf half;
+ enum {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = 8,
+ HasHalfPacket = 1,
+
+ HasCmp = 1,
+ HasCast = 1,
+ HasAdd = 1,
+ HasSub = 1,
+ HasShift = 1,
+ HasMul = 1,
+ HasNegate = 1,
+ HasAbs = 1,
+ HasArg = 0,
+ HasAbs2 = 1,
+ HasAbsDiff = 0,
+ HasMin = 1,
+ HasMax = 1,
+ HasConj = 1,
+ HasSetLinear = 0,
+ HasBlend = 0,
+ HasInsert = 1,
+ HasReduxp = 1,
+ HasDiv = 1,
+ HasFloor = 1,
+ HasCeil = 1,
+ HasRint = 1,
+ HasSin = 0,
+ HasCos = 0,
+ HasLog = 0,
+ HasExp = 0,
+ HasSqrt = 1,
+ HasRsqrt = 1,
+ HasErf = EIGEN_FAST_MATH,
+ HasBessel = 0, // Issues with accuracy.
+ HasNdtri = 0
+ };
+};
+
+template <>
+struct unpacket_traits<Packet4hf> {
+ typedef Eigen::half type;
+ typedef Packet4hf half;
+ enum {
+ size = 4,
+ alignment = Aligned16,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false
+ };
+};
+
+template <>
+struct unpacket_traits<Packet8hf> {
+ typedef Eigen::half type;
+ typedef Packet4hf half;
+ enum {
+ size = 8,
+ alignment = Aligned16,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false
+ };
+};
+
+template<>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4hf predux_half_dowto4<Packet8hf>(const Packet8hf& a) {
+ return vadd_f16(vget_low_f16(a), vget_high_f16(a));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf pset1<Packet8hf>(const Eigen::half& from) {
+ return vdupq_n_f16(from.x);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf pset1<Packet4hf>(const Eigen::half& from) {
+ return vdup_n_f16(from.x);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf plset<Packet8hf>(const Eigen::half& a) {
+ const float16_t f[] = {0, 1, 2, 3, 4, 5, 6, 7};
+ Packet8hf countdown = vld1q_f16(f);
+ return vaddq_f16(pset1<Packet8hf>(a), countdown);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf plset<Packet4hf>(const Eigen::half& a) {
+ const float16_t f[] = {0, 1, 2, 3};
+ Packet4hf countdown = vld1_f16(f);
+ return vadd_f16(pset1<Packet4hf>(a), countdown);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf padd<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
+ return vaddq_f16(a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf padd<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
+ return vadd_f16(a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf psub<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
+ return vsubq_f16(a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf psub<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
+ return vsub_f16(a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf pnegate(const Packet8hf& a) {
+ return vnegq_f16(a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf pnegate(const Packet4hf& a) {
+ return vneg_f16(a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf pconj(const Packet8hf& a) {
+ return a;
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf pconj(const Packet4hf& a) {
+ return a;
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf pmul<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
+ return vmulq_f16(a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf pmul<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
+ return vmul_f16(a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf pdiv<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
+ return vdivq_f16(a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf pdiv<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
+ return vdiv_f16(a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf pmadd(const Packet8hf& a, const Packet8hf& b, const Packet8hf& c) {
+ return vfmaq_f16(c, a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf pmadd(const Packet4hf& a, const Packet4hf& b, const Packet4hf& c) {
+ return vfma_f16(c, a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf pmin<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
+ return vminq_f16(a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf pmin<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
+ return vmin_f16(a, b);
+}
+
+#ifdef __ARM_FEATURE_NUMERIC_MAXMIN
+// Numeric max and min are only available if __ARM_FEATURE_NUMERIC_MAXMIN is defined (which can only be the case for Armv8 systems).
+template<> EIGEN_STRONG_INLINE Packet4hf pmin<PropagateNumbers, Packet4hf>(const Packet4hf& a, const Packet4hf& b) { return vminnm_f16(a, b); }
+template<> EIGEN_STRONG_INLINE Packet8hf pmin<PropagateNumbers, Packet8hf>(const Packet8hf& a, const Packet8hf& b) { return vminnmq_f16(a, b); }
+#endif
+
+template<> EIGEN_STRONG_INLINE Packet4hf pmin<PropagateNaN, Packet4hf>(const Packet4hf& a, const Packet4hf& b) { return pmin<Packet4hf>(a, b); }
+
+template<> EIGEN_STRONG_INLINE Packet8hf pmin<PropagateNaN, Packet8hf>(const Packet8hf& a, const Packet8hf& b) { return pmin<Packet8hf>(a, b); }
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf pmax<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
+ return vmaxq_f16(a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf pmax<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
+ return vmax_f16(a, b);
+}
+
+#ifdef __ARM_FEATURE_NUMERIC_MAXMIN
+// Numeric max and min are only available if __ARM_FEATURE_NUMERIC_MAXMIN is defined (which can only be the case for Armv8 systems).
+template<> EIGEN_STRONG_INLINE Packet4hf pmax<PropagateNumbers, Packet4hf>(const Packet4hf& a, const Packet4hf& b) { return vmaxnm_f16(a, b); }
+template<> EIGEN_STRONG_INLINE Packet8hf pmax<PropagateNumbers, Packet8hf>(const Packet8hf& a, const Packet8hf& b) { return vmaxnmq_f16(a, b); }
+#endif
+
+template<> EIGEN_STRONG_INLINE Packet4hf pmax<PropagateNaN, Packet4hf>(const Packet4hf& a, const Packet4hf& b) { return pmax<Packet4hf>(a, b); }
+
+template<> EIGEN_STRONG_INLINE Packet8hf pmax<PropagateNaN, Packet8hf>(const Packet8hf& a, const Packet8hf& b) { return pmax<Packet8hf>(a, b); }
+
+#define EIGEN_MAKE_ARM_FP16_CMP_8(name) \
+ template <> \
+ EIGEN_STRONG_INLINE Packet8hf pcmp_##name(const Packet8hf& a, const Packet8hf& b) { \
+ return vreinterpretq_f16_u16(vc##name##q_f16(a, b)); \
+ }
+
+#define EIGEN_MAKE_ARM_FP16_CMP_4(name) \
+ template <> \
+ EIGEN_STRONG_INLINE Packet4hf pcmp_##name(const Packet4hf& a, const Packet4hf& b) { \
+ return vreinterpret_f16_u16(vc##name##_f16(a, b)); \
+ }
+
+EIGEN_MAKE_ARM_FP16_CMP_8(eq)
+EIGEN_MAKE_ARM_FP16_CMP_8(lt)
+EIGEN_MAKE_ARM_FP16_CMP_8(le)
+
+EIGEN_MAKE_ARM_FP16_CMP_4(eq)
+EIGEN_MAKE_ARM_FP16_CMP_4(lt)
+EIGEN_MAKE_ARM_FP16_CMP_4(le)
+
+#undef EIGEN_MAKE_ARM_FP16_CMP_8
+#undef EIGEN_MAKE_ARM_FP16_CMP_4
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf pcmp_lt_or_nan<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
+ return vreinterpretq_f16_u16(vmvnq_u16(vcgeq_f16(a, b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf pcmp_lt_or_nan<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
+ return vreinterpret_f16_u16(vmvn_u16(vcge_f16(a, b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf print<Packet8hf>(const Packet8hf& a)
+{ return vrndnq_f16(a); }
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf print<Packet4hf>(const Packet4hf& a)
+{ return vrndn_f16(a); }
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf pfloor<Packet8hf>(const Packet8hf& a)
+{ return vrndmq_f16(a); }
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf pfloor<Packet4hf>(const Packet4hf& a)
+{ return vrndm_f16(a); }
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf pceil<Packet8hf>(const Packet8hf& a)
+{ return vrndpq_f16(a); }
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf pceil<Packet4hf>(const Packet4hf& a)
+{ return vrndp_f16(a); }
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf psqrt<Packet8hf>(const Packet8hf& a) {
+ return vsqrtq_f16(a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf psqrt<Packet4hf>(const Packet4hf& a) {
+ return vsqrt_f16(a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf pand<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
+ return vreinterpretq_f16_u16(vandq_u16(vreinterpretq_u16_f16(a), vreinterpretq_u16_f16(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf pand<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
+ return vreinterpret_f16_u16(vand_u16(vreinterpret_u16_f16(a), vreinterpret_u16_f16(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf por<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
+ return vreinterpretq_f16_u16(vorrq_u16(vreinterpretq_u16_f16(a), vreinterpretq_u16_f16(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf por<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
+ return vreinterpret_f16_u16(vorr_u16(vreinterpret_u16_f16(a), vreinterpret_u16_f16(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf pxor<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
+ return vreinterpretq_f16_u16(veorq_u16(vreinterpretq_u16_f16(a), vreinterpretq_u16_f16(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf pxor<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
+ return vreinterpret_f16_u16(veor_u16(vreinterpret_u16_f16(a), vreinterpret_u16_f16(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf pandnot<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
+ return vreinterpretq_f16_u16(vbicq_u16(vreinterpretq_u16_f16(a), vreinterpretq_u16_f16(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf pandnot<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
+ return vreinterpret_f16_u16(vbic_u16(vreinterpret_u16_f16(a), vreinterpret_u16_f16(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf pload<Packet8hf>(const Eigen::half* from) {
+ EIGEN_DEBUG_ALIGNED_LOAD return vld1q_f16(reinterpret_cast<const float16_t*>(from));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf pload<Packet4hf>(const Eigen::half* from) {
+ EIGEN_DEBUG_ALIGNED_LOAD return vld1_f16(reinterpret_cast<const float16_t*>(from));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf ploadu<Packet8hf>(const Eigen::half* from) {
+ EIGEN_DEBUG_UNALIGNED_LOAD return vld1q_f16(reinterpret_cast<const float16_t*>(from));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf ploadu<Packet4hf>(const Eigen::half* from) {
+ EIGEN_DEBUG_UNALIGNED_LOAD return vld1_f16(reinterpret_cast<const float16_t*>(from));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf ploaddup<Packet8hf>(const Eigen::half* from) {
+ Packet8hf packet;
+ packet[0] = from[0].x;
+ packet[1] = from[0].x;
+ packet[2] = from[1].x;
+ packet[3] = from[1].x;
+ packet[4] = from[2].x;
+ packet[5] = from[2].x;
+ packet[6] = from[3].x;
+ packet[7] = from[3].x;
+ return packet;
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf ploaddup<Packet4hf>(const Eigen::half* from) {
+ float16x4_t packet;
+ float16_t* tmp;
+ tmp = (float16_t*)&packet;
+ tmp[0] = from[0].x;
+ tmp[1] = from[0].x;
+ tmp[2] = from[1].x;
+ tmp[3] = from[1].x;
+ return packet;
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf ploadquad<Packet8hf>(const Eigen::half* from) {
+ Packet4hf lo, hi;
+ lo = vld1_dup_f16(reinterpret_cast<const float16_t*>(from));
+ hi = vld1_dup_f16(reinterpret_cast<const float16_t*>(from+1));
+ return vcombine_f16(lo, hi);
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet8hf pinsertfirst(const Packet8hf& a, Eigen::half b) { return vsetq_lane_f16(b.x, a, 0); }
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4hf pinsertfirst(const Packet4hf& a, Eigen::half b) { return vset_lane_f16(b.x, a, 0); }
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet8hf pselect(const Packet8hf& mask, const Packet8hf& a, const Packet8hf& b) {
+ return vbslq_f16(vreinterpretq_u16_f16(mask), a, b);
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4hf pselect(const Packet4hf& mask, const Packet4hf& a, const Packet4hf& b) {
+ return vbsl_f16(vreinterpret_u16_f16(mask), a, b);
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet8hf pinsertlast(const Packet8hf& a, Eigen::half b) { return vsetq_lane_f16(b.x, a, 7); }
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4hf pinsertlast(const Packet4hf& a, Eigen::half b) { return vset_lane_f16(b.x, a, 3); }
+
+template <>
+EIGEN_STRONG_INLINE void pstore<Eigen::half>(Eigen::half* to, const Packet8hf& from) {
+ EIGEN_DEBUG_ALIGNED_STORE vst1q_f16(reinterpret_cast<float16_t*>(to), from);
+}
+
+template <>
+EIGEN_STRONG_INLINE void pstore<Eigen::half>(Eigen::half* to, const Packet4hf& from) {
+ EIGEN_DEBUG_ALIGNED_STORE vst1_f16(reinterpret_cast<float16_t*>(to), from);
+}
+
+template <>
+EIGEN_STRONG_INLINE void pstoreu<Eigen::half>(Eigen::half* to, const Packet8hf& from) {
+ EIGEN_DEBUG_UNALIGNED_STORE vst1q_f16(reinterpret_cast<float16_t*>(to), from);
+}
+
+template <>
+EIGEN_STRONG_INLINE void pstoreu<Eigen::half>(Eigen::half* to, const Packet4hf& from) {
+ EIGEN_DEBUG_UNALIGNED_STORE vst1_f16(reinterpret_cast<float16_t*>(to), from);
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet8hf pgather<Eigen::half, Packet8hf>(const Eigen::half* from, Index stride) {
+ Packet8hf res = pset1<Packet8hf>(Eigen::half(0.f));
+ res = vsetq_lane_f16(from[0 * stride].x, res, 0);
+ res = vsetq_lane_f16(from[1 * stride].x, res, 1);
+ res = vsetq_lane_f16(from[2 * stride].x, res, 2);
+ res = vsetq_lane_f16(from[3 * stride].x, res, 3);
+ res = vsetq_lane_f16(from[4 * stride].x, res, 4);
+ res = vsetq_lane_f16(from[5 * stride].x, res, 5);
+ res = vsetq_lane_f16(from[6 * stride].x, res, 6);
+ res = vsetq_lane_f16(from[7 * stride].x, res, 7);
+ return res;
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4hf pgather<Eigen::half, Packet4hf>(const Eigen::half* from, Index stride) {
+ Packet4hf res = pset1<Packet4hf>(Eigen::half(0.f));
+ res = vset_lane_f16(from[0 * stride].x, res, 0);
+ res = vset_lane_f16(from[1 * stride].x, res, 1);
+ res = vset_lane_f16(from[2 * stride].x, res, 2);
+ res = vset_lane_f16(from[3 * stride].x, res, 3);
+ return res;
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter<Eigen::half, Packet8hf>(Eigen::half* to, const Packet8hf& from, Index stride) {
+ to[stride * 0].x = vgetq_lane_f16(from, 0);
+ to[stride * 1].x = vgetq_lane_f16(from, 1);
+ to[stride * 2].x = vgetq_lane_f16(from, 2);
+ to[stride * 3].x = vgetq_lane_f16(from, 3);
+ to[stride * 4].x = vgetq_lane_f16(from, 4);
+ to[stride * 5].x = vgetq_lane_f16(from, 5);
+ to[stride * 6].x = vgetq_lane_f16(from, 6);
+ to[stride * 7].x = vgetq_lane_f16(from, 7);
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter<Eigen::half, Packet4hf>(Eigen::half* to, const Packet4hf& from, Index stride) {
+ to[stride * 0].x = vget_lane_f16(from, 0);
+ to[stride * 1].x = vget_lane_f16(from, 1);
+ to[stride * 2].x = vget_lane_f16(from, 2);
+ to[stride * 3].x = vget_lane_f16(from, 3);
+}
+
+template <>
+EIGEN_STRONG_INLINE void prefetch<Eigen::half>(const Eigen::half* addr) {
+ EIGEN_ARM_PREFETCH(addr);
+}
+
+template <>
+EIGEN_STRONG_INLINE Eigen::half pfirst<Packet8hf>(const Packet8hf& a) {
+ float16_t x[8];
+ vst1q_f16(x, a);
+ Eigen::half h;
+ h.x = x[0];
+ return h;
+}
+
+template <>
+EIGEN_STRONG_INLINE Eigen::half pfirst<Packet4hf>(const Packet4hf& a) {
+ float16_t x[4];
+ vst1_f16(x, a);
+ Eigen::half h;
+ h.x = x[0];
+ return h;
+}
+
+template<> EIGEN_STRONG_INLINE Packet8hf preverse(const Packet8hf& a) {
+ float16x4_t a_lo, a_hi;
+ Packet8hf a_r64;
+
+ a_r64 = vrev64q_f16(a);
+ a_lo = vget_low_f16(a_r64);
+ a_hi = vget_high_f16(a_r64);
+ return vcombine_f16(a_hi, a_lo);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf preverse<Packet4hf>(const Packet4hf& a) {
+ return vrev64_f16(a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf pabs<Packet8hf>(const Packet8hf& a) {
+ return vabsq_f16(a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf pabs<Packet4hf>(const Packet4hf& a) {
+ return vabs_f16(a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Eigen::half predux<Packet8hf>(const Packet8hf& a) {
+ float16x4_t a_lo, a_hi, sum;
+
+ a_lo = vget_low_f16(a);
+ a_hi = vget_high_f16(a);
+ sum = vpadd_f16(a_lo, a_hi);
+ sum = vpadd_f16(sum, sum);
+ sum = vpadd_f16(sum, sum);
+
+ Eigen::half h;
+ h.x = vget_lane_f16(sum, 0);
+ return h;
+}
+
+template <>
+EIGEN_STRONG_INLINE Eigen::half predux<Packet4hf>(const Packet4hf& a) {
+ float16x4_t sum;
+
+ sum = vpadd_f16(a, a);
+ sum = vpadd_f16(sum, sum);
+ Eigen::half h;
+ h.x = vget_lane_f16(sum, 0);
+ return h;
+}
+
+template <>
+EIGEN_STRONG_INLINE Eigen::half predux_mul<Packet8hf>(const Packet8hf& a) {
+ float16x4_t a_lo, a_hi, prod;
+
+ a_lo = vget_low_f16(a);
+ a_hi = vget_high_f16(a);
+ prod = vmul_f16(a_lo, a_hi);
+ prod = vmul_f16(prod, vrev64_f16(prod));
+
+ Eigen::half h;
+ h.x = vmulh_f16(vget_lane_f16(prod, 0), vget_lane_f16(prod, 1));
+ return h;
+}
+
+template <>
+EIGEN_STRONG_INLINE Eigen::half predux_mul<Packet4hf>(const Packet4hf& a) {
+ float16x4_t prod;
+ prod = vmul_f16(a, vrev64_f16(a));
+ Eigen::half h;
+ h.x = vmulh_f16(vget_lane_f16(prod, 0), vget_lane_f16(prod, 1));
+ return h;
+}
+
+template <>
+EIGEN_STRONG_INLINE Eigen::half predux_min<Packet8hf>(const Packet8hf& a) {
+ float16x4_t a_lo, a_hi, min;
+
+ a_lo = vget_low_f16(a);
+ a_hi = vget_high_f16(a);
+ min = vpmin_f16(a_lo, a_hi);
+ min = vpmin_f16(min, min);
+ min = vpmin_f16(min, min);
+
+ Eigen::half h;
+ h.x = vget_lane_f16(min, 0);
+ return h;
+}
+
+template <>
+EIGEN_STRONG_INLINE Eigen::half predux_min<Packet4hf>(const Packet4hf& a) {
+ Packet4hf tmp;
+ tmp = vpmin_f16(a, a);
+ tmp = vpmin_f16(tmp, tmp);
+ Eigen::half h;
+ h.x = vget_lane_f16(tmp, 0);
+ return h;
+}
+
+template <>
+EIGEN_STRONG_INLINE Eigen::half predux_max<Packet8hf>(const Packet8hf& a) {
+ float16x4_t a_lo, a_hi, max;
+
+ a_lo = vget_low_f16(a);
+ a_hi = vget_high_f16(a);
+ max = vpmax_f16(a_lo, a_hi);
+ max = vpmax_f16(max, max);
+ max = vpmax_f16(max, max);
+
+ Eigen::half h;
+ h.x = vget_lane_f16(max, 0);
+ return h;
+}
+
+template <>
+EIGEN_STRONG_INLINE Eigen::half predux_max<Packet4hf>(const Packet4hf& a) {
+ Packet4hf tmp;
+ tmp = vpmax_f16(a, a);
+ tmp = vpmax_f16(tmp, tmp);
+ Eigen::half h;
+ h.x = vget_lane_f16(tmp, 0);
+ return h;
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet8hf, 4>& kernel)
+{
+ const float16x8x2_t zip16_1 = vzipq_f16(kernel.packet[0], kernel.packet[1]);
+ const float16x8x2_t zip16_2 = vzipq_f16(kernel.packet[2], kernel.packet[3]);
+
+ const float32x4x2_t zip32_1 = vzipq_f32(vreinterpretq_f32_f16(zip16_1.val[0]), vreinterpretq_f32_f16(zip16_2.val[0]));
+ const float32x4x2_t zip32_2 = vzipq_f32(vreinterpretq_f32_f16(zip16_1.val[1]), vreinterpretq_f32_f16(zip16_2.val[1]));
+
+ kernel.packet[0] = vreinterpretq_f16_f32(zip32_1.val[0]);
+ kernel.packet[1] = vreinterpretq_f16_f32(zip32_1.val[1]);
+ kernel.packet[2] = vreinterpretq_f16_f32(zip32_2.val[0]);
+ kernel.packet[3] = vreinterpretq_f16_f32(zip32_2.val[1]);
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet4hf, 4>& kernel) {
+ EIGEN_ALIGN16 float16x4x4_t tmp_x4;
+ float16_t* tmp = (float16_t*)&kernel;
+ tmp_x4 = vld4_f16(tmp);
+
+ kernel.packet[0] = tmp_x4.val[0];
+ kernel.packet[1] = tmp_x4.val[1];
+ kernel.packet[2] = tmp_x4.val[2];
+ kernel.packet[3] = tmp_x4.val[3];
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet8hf, 8>& kernel) {
+ float16x8x2_t T_1[4];
+
+ T_1[0] = vuzpq_f16(kernel.packet[0], kernel.packet[1]);
+ T_1[1] = vuzpq_f16(kernel.packet[2], kernel.packet[3]);
+ T_1[2] = vuzpq_f16(kernel.packet[4], kernel.packet[5]);
+ T_1[3] = vuzpq_f16(kernel.packet[6], kernel.packet[7]);
+
+ float16x8x2_t T_2[4];
+ T_2[0] = vuzpq_f16(T_1[0].val[0], T_1[1].val[0]);
+ T_2[1] = vuzpq_f16(T_1[0].val[1], T_1[1].val[1]);
+ T_2[2] = vuzpq_f16(T_1[2].val[0], T_1[3].val[0]);
+ T_2[3] = vuzpq_f16(T_1[2].val[1], T_1[3].val[1]);
+
+ float16x8x2_t T_3[4];
+ T_3[0] = vuzpq_f16(T_2[0].val[0], T_2[2].val[0]);
+ T_3[1] = vuzpq_f16(T_2[0].val[1], T_2[2].val[1]);
+ T_3[2] = vuzpq_f16(T_2[1].val[0], T_2[3].val[0]);
+ T_3[3] = vuzpq_f16(T_2[1].val[1], T_2[3].val[1]);
+
+ kernel.packet[0] = T_3[0].val[0];
+ kernel.packet[1] = T_3[2].val[0];
+ kernel.packet[2] = T_3[1].val[0];
+ kernel.packet[3] = T_3[3].val[0];
+ kernel.packet[4] = T_3[0].val[1];
+ kernel.packet[5] = T_3[2].val[1];
+ kernel.packet[6] = T_3[1].val[1];
+ kernel.packet[7] = T_3[3].val[1];
+}
+#endif // end EIGEN_HAS_ARM64_FP16_VECTOR_ARITHMETIC
} // end namespace internal
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/NEON/TypeCasting.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/NEON/TypeCasting.h
new file mode 100644
index 0000000..54f9733
--- /dev/null
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/NEON/TypeCasting.h
@@ -0,0 +1,1419 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2018 Rasmus Munk Larsen <rmlarsen@google.com>
+// Copyright (C) 2020 Antonio Sanchez <cantonios@google.com>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_TYPE_CASTING_NEON_H
+#define EIGEN_TYPE_CASTING_NEON_H
+
+namespace Eigen {
+
+namespace internal {
+
+//==============================================================================
+// pcast, SrcType = float
+//==============================================================================
+// float -> float: nothing to convert.
+template <>
+struct type_casting_traits<float, float> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4f pcast<Packet4f, Packet4f>(const Packet4f& a) {
+  return a;
+}
+template <>
+EIGEN_STRONG_INLINE Packet2f pcast<Packet2f, Packet2f>(const Packet2f& a) {
+  return a;
+}
+
+// float -> int64 / uint64: the target packet has half as many lanes, so only the
+// low two floats of the input are converted.
+template <>
+struct type_casting_traits<float, numext::int64_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 };
+};
+template <>
+struct type_casting_traits<float, numext::uint64_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 };
+};
+#if EIGEN_ARCH_ARM64
+// AArch64 has float64 vectors: widen float -> double (an exact step) and convert
+// to 64-bit integers from there, so values beyond the 32-bit range survive.
+template <>
+EIGEN_STRONG_INLINE Packet2l pcast<Packet4f, Packet2l>(const Packet4f& a) {
+  const float64x2_t low_pair = vcvt_f64_f32(vget_low_f32(a));
+  return vcvtq_s64_f64(low_pair);
+}
+template <>
+EIGEN_STRONG_INLINE Packet2ul pcast<Packet4f, Packet2ul>(const Packet4f& a) {
+  const float64x2_t low_pair = vcvt_f64_f32(vget_low_f32(a));
+  return vcvtq_u64_f64(low_pair);
+}
+#else
+// No float64 vectors: convert to 32-bit integers, then widen the low two lanes.
+template <>
+EIGEN_STRONG_INLINE Packet2l pcast<Packet4f, Packet2l>(const Packet4f& a) {
+  const int32x2_t low_pair = vget_low_s32(vcvtq_s32_f32(a));
+  return vmovl_s32(low_pair);
+}
+template <>
+EIGEN_STRONG_INLINE Packet2ul pcast<Packet4f, Packet2ul>(const Packet4f& a) {
+  const uint32x2_t low_pair = vget_low_u32(vcvtq_u32_f32(a));
+  return vmovl_u32(low_pair);
+}
+#endif  // EIGEN_ARCH_ARM64
+
+// float -> int32: lane-for-lane conversion.
+template <>
+struct type_casting_traits<float, numext::int32_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4i pcast<Packet4f, Packet4i>(const Packet4f& a) {
+  const Packet4i converted = vcvtq_s32_f32(a);
+  return converted;
+}
+template <>
+EIGEN_STRONG_INLINE Packet2i pcast<Packet2f, Packet2i>(const Packet2f& a) {
+  const Packet2i converted = vcvt_s32_f32(a);
+  return converted;
+}
+
+// float -> uint32: lane-for-lane conversion.
+template <>
+struct type_casting_traits<float, numext::uint32_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4ui pcast<Packet4f, Packet4ui>(const Packet4f& a) {
+  const Packet4ui converted = vcvtq_u32_f32(a);
+  return converted;
+}
+template <>
+EIGEN_STRONG_INLINE Packet2ui pcast<Packet2f, Packet2ui>(const Packet2f& a) {
+  const Packet2ui converted = vcvt_u32_f32(a);
+  return converted;
+}
+
+// float -> int16: two float packets fill one 16-bit packet. Each value is
+// converted to 32 bits and then narrowed (vmovn keeps the low 16 bits).
+template <>
+struct type_casting_traits<float, numext::int16_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet8s pcast<Packet4f, Packet8s>(const Packet4f& a, const Packet4f& b) {
+  const int16x4_t lo = vmovn_s32(vcvtq_s32_f32(a));
+  const int16x4_t hi = vmovn_s32(vcvtq_s32_f32(b));
+  return vcombine_s16(lo, hi);
+}
+template <>
+EIGEN_STRONG_INLINE Packet4s pcast<Packet2f, Packet4s>(const Packet2f& a, const Packet2f& b) {
+  const int32x4_t both = vcombine_s32(vcvt_s32_f32(a), vcvt_s32_f32(b));
+  return vmovn_s32(both);
+}
+
+// float -> uint16: same scheme as int16, using unsigned conversions.
+template <>
+struct type_casting_traits<float, numext::uint16_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet8us pcast<Packet4f, Packet8us>(const Packet4f& a, const Packet4f& b) {
+  const uint16x4_t lo = vmovn_u32(vcvtq_u32_f32(a));
+  const uint16x4_t hi = vmovn_u32(vcvtq_u32_f32(b));
+  return vcombine_u16(lo, hi);
+}
+template <>
+EIGEN_STRONG_INLINE Packet4us pcast<Packet2f, Packet4us>(const Packet2f& a, const Packet2f& b) {
+  const uint32x4_t both = vcombine_u32(vcvt_u32_f32(a), vcvt_u32_f32(b));
+  return vmovn_u32(both);
+}
+
+// float -> int8: four float packets fill one 8-bit packet, narrowing via int16.
+template <>
+struct type_casting_traits<float, numext::int8_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 4, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet16c pcast<Packet4f, Packet16c>(const Packet4f& a, const Packet4f& b, const Packet4f& c,
+                                                         const Packet4f& d) {
+  const Packet8s first_half = pcast<Packet4f, Packet8s>(a, b);
+  const Packet8s second_half = pcast<Packet4f, Packet8s>(c, d);
+  return vcombine_s8(vmovn_s16(first_half), vmovn_s16(second_half));
+}
+template <>
+EIGEN_STRONG_INLINE Packet8c pcast<Packet2f, Packet8c>(const Packet2f& a, const Packet2f& b, const Packet2f& c,
+                                                       const Packet2f& d) {
+  const Packet4s first_half = pcast<Packet2f, Packet4s>(a, b);
+  const Packet4s second_half = pcast<Packet2f, Packet4s>(c, d);
+  return vmovn_s16(vcombine_s16(first_half, second_half));
+}
+
+// float -> uint8: four float packets fill one 8-bit packet, narrowing via uint16.
+template <>
+struct type_casting_traits<float, numext::uint8_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 4, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet16uc pcast<Packet4f, Packet16uc>(const Packet4f& a, const Packet4f& b, const Packet4f& c,
+                                                           const Packet4f& d) {
+  const Packet8us first_half = pcast<Packet4f, Packet8us>(a, b);
+  const Packet8us second_half = pcast<Packet4f, Packet8us>(c, d);
+  return vcombine_u8(vmovn_u16(first_half), vmovn_u16(second_half));
+}
+template <>
+EIGEN_STRONG_INLINE Packet8uc pcast<Packet2f, Packet8uc>(const Packet2f& a, const Packet2f& b, const Packet2f& c,
+                                                         const Packet2f& d) {
+  const Packet4us first_half = pcast<Packet2f, Packet4us>(a, b);
+  const Packet4us second_half = pcast<Packet2f, Packet4us>(c, d);
+  return vmovn_u16(vcombine_u16(first_half, second_half));
+}
+
+//==============================================================================
+// pcast, SrcType = int8_t
+//==============================================================================
+// int8 -> float: sign-extend 8 -> 16 -> 32 bits, then convert. Only as many
+// leading bytes as the target packet has lanes are used.
+template <>
+struct type_casting_traits<numext::int8_t, float> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 4 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4f pcast<Packet16c, Packet4f>(const Packet16c& a) {
+  const int16x8_t as_s16 = vmovl_s8(vget_low_s8(a));
+  const int32x4_t as_s32 = vmovl_s16(vget_low_s16(as_s16));  // bytes 0..3
+  return vcvtq_f32_s32(as_s32);
+}
+template <>
+EIGEN_STRONG_INLINE Packet2f pcast<Packet8c, Packet2f>(const Packet8c& a) {
+  const int16x8_t as_s16 = vmovl_s8(a);
+  const int32x4_t as_s32 = vmovl_s16(vget_low_s16(as_s16));
+  return vcvt_f32_s32(vget_low_s32(as_s32));  // bytes 0..1
+}
+
+// int8 -> int64: sign-extend the first two bytes all the way to 64 bits.
+template <>
+struct type_casting_traits<numext::int8_t, numext::int64_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 8 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2l pcast<Packet16c, Packet2l>(const Packet16c& a) {
+  const int16x8_t as_s16 = vmovl_s8(vget_low_s8(a));
+  const int32x4_t as_s32 = vmovl_s16(vget_low_s16(as_s16));
+  return vmovl_s32(vget_low_s32(as_s32));
+}
+
+// int8 -> uint64: sign-extend as signed, then reinterpret the bits as unsigned.
+template <>
+struct type_casting_traits<numext::int8_t, numext::uint64_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 8 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2ul pcast<Packet16c, Packet2ul>(const Packet16c& a) {
+  const Packet2l widened = pcast<Packet16c, Packet2l>(a);
+  return vreinterpretq_u64_s64(widened);
+}
+
+// int8 -> int32: sign-extend the leading bytes (4 of 16, or 2 of 8).
+template <>
+struct type_casting_traits<numext::int8_t, numext::int32_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 4 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4i pcast<Packet16c, Packet4i>(const Packet16c& a) {
+  const int16x8_t as_s16 = vmovl_s8(vget_low_s8(a));
+  return vmovl_s16(vget_low_s16(as_s16));
+}
+template <>
+EIGEN_STRONG_INLINE Packet2i pcast<Packet8c, Packet2i>(const Packet8c& a) {
+  const int32x4_t as_s32 = vmovl_s16(vget_low_s16(vmovl_s8(a)));
+  return vget_low_s32(as_s32);
+}
+
+// int8 -> uint32: signed widening followed by a bit reinterpretation.
+template <>
+struct type_casting_traits<numext::int8_t, numext::uint32_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 4 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4ui pcast<Packet16c, Packet4ui>(const Packet16c& a) {
+  const Packet4i widened = pcast<Packet16c, Packet4i>(a);
+  return vreinterpretq_u32_s32(widened);
+}
+template <>
+EIGEN_STRONG_INLINE Packet2ui pcast<Packet8c, Packet2ui>(const Packet8c& a) {
+  const Packet2i widened = pcast<Packet8c, Packet2i>(a);
+  return vreinterpret_u32_s32(widened);
+}
+
+// int8 -> int16: sign-extend the low half of the packet.
+template <>
+struct type_casting_traits<numext::int8_t, numext::int16_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet8s pcast<Packet16c, Packet8s>(const Packet16c& a) {
+  const int8x8_t low_half = vget_low_s8(a);
+  return vmovl_s8(low_half);
+}
+template <>
+EIGEN_STRONG_INLINE Packet4s pcast<Packet8c, Packet4s>(const Packet8c& a) {
+  const int16x8_t as_s16 = vmovl_s8(a);
+  return vget_low_s16(as_s16);
+}
+
+// int8 -> uint16: signed widening followed by a bit reinterpretation.
+template <>
+struct type_casting_traits<numext::int8_t, numext::uint16_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet8us pcast<Packet16c, Packet8us>(const Packet16c& a) {
+  const Packet8s widened = pcast<Packet16c, Packet8s>(a);
+  return vreinterpretq_u16_s16(widened);
+}
+template <>
+EIGEN_STRONG_INLINE Packet4us pcast<Packet8c, Packet4us>(const Packet8c& a) {
+  const Packet4s widened = pcast<Packet8c, Packet4s>(a);
+  return vreinterpret_u16_s16(widened);
+}
+
+// int8 -> int8: identity for every packet width.
+template <>
+struct type_casting_traits<numext::int8_t, numext::int8_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet16c pcast<Packet16c, Packet16c>(const Packet16c& a) {
+  return a;
+}
+template <>
+EIGEN_STRONG_INLINE Packet8c pcast<Packet8c, Packet8c>(const Packet8c& a) {
+  return a;
+}
+template <>
+EIGEN_STRONG_INLINE Packet4c pcast<Packet4c, Packet4c>(const Packet4c& a) {
+  return a;
+}
+
+// int8 -> uint8: same bits, reinterpreted as unsigned.
+template <>
+struct type_casting_traits<numext::int8_t, numext::uint8_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet16uc pcast<Packet16c, Packet16uc>(const Packet16c& a) {
+  const Packet16uc as_unsigned = vreinterpretq_u8_s8(a);
+  return as_unsigned;
+}
+template <>
+EIGEN_STRONG_INLINE Packet8uc pcast<Packet8c, Packet8uc>(const Packet8c& a) {
+  const Packet8uc as_unsigned = vreinterpret_u8_s8(a);
+  return as_unsigned;
+}
+template <>
+EIGEN_STRONG_INLINE Packet4uc pcast<Packet4c, Packet4uc>(const Packet4c& a) {
+  // Packet4c / Packet4uc are scalar wrappers, not NEON vectors.
+  return static_cast<Packet4uc>(a);
+}
+
+//==============================================================================
+// pcast, SrcType = uint8_t
+//==============================================================================
+// uint8 -> float: zero-extend 8 -> 16 -> 32 bits, then convert. Only as many
+// leading bytes as the target packet has lanes are used.
+template <>
+struct type_casting_traits<numext::uint8_t, float> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 4 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4f pcast<Packet16uc, Packet4f>(const Packet16uc& a) {
+  const uint16x8_t as_u16 = vmovl_u8(vget_low_u8(a));
+  const uint32x4_t as_u32 = vmovl_u16(vget_low_u16(as_u16));  // bytes 0..3
+  return vcvtq_f32_u32(as_u32);
+}
+template <>
+EIGEN_STRONG_INLINE Packet2f pcast<Packet8uc, Packet2f>(const Packet8uc& a) {
+  const uint16x8_t as_u16 = vmovl_u8(a);
+  const uint32x4_t as_u32 = vmovl_u16(vget_low_u16(as_u16));
+  return vcvt_f32_u32(vget_low_u32(as_u32));  // bytes 0..1
+}
+
+// uint8 -> uint64: zero-extend the first two bytes all the way to 64 bits.
+template <>
+struct type_casting_traits<numext::uint8_t, numext::uint64_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 8 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2ul pcast<Packet16uc, Packet2ul>(const Packet16uc& a) {
+  const uint16x8_t as_u16 = vmovl_u8(vget_low_u8(a));
+  const uint32x4_t as_u32 = vmovl_u16(vget_low_u16(as_u16));
+  return vmovl_u32(vget_low_u32(as_u32));
+}
+
+// uint8 -> int64: zero-extended values are non-negative, so reinterpreting as
+// signed preserves them.
+template <>
+struct type_casting_traits<numext::uint8_t, numext::int64_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 8 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2l pcast<Packet16uc, Packet2l>(const Packet16uc& a) {
+  const Packet2ul widened = pcast<Packet16uc, Packet2ul>(a);
+  return vreinterpretq_s64_u64(widened);
+}
+
+// uint8 -> uint32: zero-extend the leading bytes (4 of 16, or 2 of 8).
+template <>
+struct type_casting_traits<numext::uint8_t, numext::uint32_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 4 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4ui pcast<Packet16uc, Packet4ui>(const Packet16uc& a) {
+  const uint16x8_t as_u16 = vmovl_u8(vget_low_u8(a));
+  return vmovl_u16(vget_low_u16(as_u16));
+}
+template <>
+EIGEN_STRONG_INLINE Packet2ui pcast<Packet8uc, Packet2ui>(const Packet8uc& a) {
+  const uint32x4_t as_u32 = vmovl_u16(vget_low_u16(vmovl_u8(a)));
+  return vget_low_u32(as_u32);
+}
+
+// uint8 -> int32: unsigned widening followed by a bit reinterpretation.
+template <>
+struct type_casting_traits<numext::uint8_t, numext::int32_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 4 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4i pcast<Packet16uc, Packet4i>(const Packet16uc& a) {
+  const Packet4ui widened = pcast<Packet16uc, Packet4ui>(a);
+  return vreinterpretq_s32_u32(widened);
+}
+template <>
+EIGEN_STRONG_INLINE Packet2i pcast<Packet8uc, Packet2i>(const Packet8uc& a) {
+  const Packet2ui widened = pcast<Packet8uc, Packet2ui>(a);
+  return vreinterpret_s32_u32(widened);
+}
+
+// uint8 -> uint16: zero-extend the low half of the packet.
+template <>
+struct type_casting_traits<numext::uint8_t, numext::uint16_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet8us pcast<Packet16uc, Packet8us>(const Packet16uc& a) {
+  const uint8x8_t low_half = vget_low_u8(a);
+  return vmovl_u8(low_half);
+}
+template <>
+EIGEN_STRONG_INLINE Packet4us pcast<Packet8uc, Packet4us>(const Packet8uc& a) {
+  const uint16x8_t as_u16 = vmovl_u8(a);
+  return vget_low_u16(as_u16);
+}
+
+// uint8 -> int16: unsigned widening followed by a bit reinterpretation.
+template <>
+struct type_casting_traits<numext::uint8_t, numext::int16_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet8s pcast<Packet16uc, Packet8s>(const Packet16uc& a) {
+  const Packet8us widened = pcast<Packet16uc, Packet8us>(a);
+  return vreinterpretq_s16_u16(widened);
+}
+template <>
+EIGEN_STRONG_INLINE Packet4s pcast<Packet8uc, Packet4s>(const Packet8uc& a) {
+  const Packet4us widened = pcast<Packet8uc, Packet4us>(a);
+  return vreinterpret_s16_u16(widened);
+}
+
+// uint8 -> uint8: identity for every packet width.
+template <>
+struct type_casting_traits<numext::uint8_t, numext::uint8_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet16uc pcast<Packet16uc, Packet16uc>(const Packet16uc& a) {
+  return a;
+}
+template <>
+EIGEN_STRONG_INLINE Packet8uc pcast<Packet8uc, Packet8uc>(const Packet8uc& a) {
+  return a;
+}
+template <>
+EIGEN_STRONG_INLINE Packet4uc pcast<Packet4uc, Packet4uc>(const Packet4uc& a) {
+  return a;
+}
+
+// uint8 -> int8: same bits, reinterpreted as signed.
+template <>
+struct type_casting_traits<numext::uint8_t, numext::int8_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet16c pcast<Packet16uc, Packet16c>(const Packet16uc& a) {
+  const Packet16c as_signed = vreinterpretq_s8_u8(a);
+  return as_signed;
+}
+template <>
+EIGEN_STRONG_INLINE Packet8c pcast<Packet8uc, Packet8c>(const Packet8uc& a) {
+  const Packet8c as_signed = vreinterpret_s8_u8(a);
+  return as_signed;
+}
+template <>
+EIGEN_STRONG_INLINE Packet4c pcast<Packet4uc, Packet4c>(const Packet4uc& a) {
+  // Packet4c / Packet4uc are scalar wrappers, not NEON vectors.
+  return static_cast<Packet4c>(a);
+}
+
+//==============================================================================
+// pcast, SrcType = int16_t
+//==============================================================================
+// int16 -> float: sign-extend the low half to 32 bits, then convert.
+template <>
+struct type_casting_traits<numext::int16_t, float> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4f pcast<Packet8s, Packet4f>(const Packet8s& a) {
+  const int32x4_t as_s32 = vmovl_s16(vget_low_s16(a));
+  return vcvtq_f32_s32(as_s32);
+}
+template <>
+EIGEN_STRONG_INLINE Packet2f pcast<Packet4s, Packet2f>(const Packet4s& a) {
+  const int32x4_t as_s32 = vmovl_s16(a);
+  return vcvt_f32_s32(vget_low_s32(as_s32));
+}
+
+// int16 -> int64: sign-extend the first two values to 64 bits.
+template <>
+struct type_casting_traits<numext::int16_t, numext::int64_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 4 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2l pcast<Packet8s, Packet2l>(const Packet8s& a) {
+  const int32x4_t as_s32 = vmovl_s16(vget_low_s16(a));
+  return vmovl_s32(vget_low_s32(as_s32));
+}
+
+// int16 -> uint64: signed widening followed by a bit reinterpretation.
+template <>
+struct type_casting_traits<numext::int16_t, numext::uint64_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 4 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2ul pcast<Packet8s, Packet2ul>(const Packet8s& a) {
+  const Packet2l widened = pcast<Packet8s, Packet2l>(a);
+  return vreinterpretq_u64_s64(widened);
+}
+
+// int16 -> int32: sign-extend the low half of the packet.
+template <>
+struct type_casting_traits<numext::int16_t, numext::int32_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4i pcast<Packet8s, Packet4i>(const Packet8s& a) {
+  const int16x4_t low_half = vget_low_s16(a);
+  return vmovl_s16(low_half);
+}
+template <>
+EIGEN_STRONG_INLINE Packet2i pcast<Packet4s, Packet2i>(const Packet4s& a) {
+  const int32x4_t as_s32 = vmovl_s16(a);
+  return vget_low_s32(as_s32);
+}
+
+// int16 -> uint32: signed widening followed by a bit reinterpretation.
+template <>
+struct type_casting_traits<numext::int16_t, numext::uint32_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4ui pcast<Packet8s, Packet4ui>(const Packet8s& a) {
+  const Packet4i widened = pcast<Packet8s, Packet4i>(a);
+  return vreinterpretq_u32_s32(widened);
+}
+template <>
+EIGEN_STRONG_INLINE Packet2ui pcast<Packet4s, Packet2ui>(const Packet4s& a) {
+  const Packet2i widened = pcast<Packet4s, Packet2i>(a);
+  return vreinterpret_u32_s32(widened);
+}
+
+// int16 -> int16: identity.
+template <>
+struct type_casting_traits<numext::int16_t, numext::int16_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet8s pcast<Packet8s, Packet8s>(const Packet8s& a) {
+  return a;
+}
+template <>
+EIGEN_STRONG_INLINE Packet4s pcast<Packet4s, Packet4s>(const Packet4s& a) {
+  return a;
+}
+
+// int16 -> uint16: same bits, reinterpreted as unsigned.
+template <>
+struct type_casting_traits<numext::int16_t, numext::uint16_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet8us pcast<Packet8s, Packet8us>(const Packet8s& a) {
+  const Packet8us as_unsigned = vreinterpretq_u16_s16(a);
+  return as_unsigned;
+}
+template <>
+EIGEN_STRONG_INLINE Packet4us pcast<Packet4s, Packet4us>(const Packet4s& a) {
+  const Packet4us as_unsigned = vreinterpret_u16_s16(a);
+  return as_unsigned;
+}
+
+// int16 -> int8: two packets are narrowed (low 8 bits kept) into one.
+template <>
+struct type_casting_traits<numext::int16_t, numext::int8_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet16c pcast<Packet8s, Packet16c>(const Packet8s& a, const Packet8s& b) {
+  const int8x8_t lo = vmovn_s16(a);
+  const int8x8_t hi = vmovn_s16(b);
+  return vcombine_s8(lo, hi);
+}
+template <>
+EIGEN_STRONG_INLINE Packet8c pcast<Packet4s, Packet8c>(const Packet4s& a, const Packet4s& b) {
+  const int16x8_t both = vcombine_s16(a, b);
+  return vmovn_s16(both);
+}
+
+// int16 -> uint8: reinterpret as unsigned, then narrow two packets into one.
+template <>
+struct type_casting_traits<numext::int16_t, numext::uint8_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet16uc pcast<Packet8s, Packet16uc>(const Packet8s& a, const Packet8s& b) {
+  const uint8x8_t lo = vmovn_u16(vreinterpretq_u16_s16(a));
+  const uint8x8_t hi = vmovn_u16(vreinterpretq_u16_s16(b));
+  return vcombine_u8(lo, hi);
+}
+template <>
+EIGEN_STRONG_INLINE Packet8uc pcast<Packet4s, Packet8uc>(const Packet4s& a, const Packet4s& b) {
+  const uint16x8_t both = vcombine_u16(vreinterpret_u16_s16(a), vreinterpret_u16_s16(b));
+  return vmovn_u16(both);
+}
+
+//==============================================================================
+// pcast, SrcType = uint16_t
+//==============================================================================
+// uint16 -> float: zero-extend the low half to 32 bits, then convert.
+template <>
+struct type_casting_traits<numext::uint16_t, float> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4f pcast<Packet8us, Packet4f>(const Packet8us& a) {
+  const uint32x4_t as_u32 = vmovl_u16(vget_low_u16(a));
+  return vcvtq_f32_u32(as_u32);
+}
+template <>
+EIGEN_STRONG_INLINE Packet2f pcast<Packet4us, Packet2f>(const Packet4us& a) {
+  const uint32x4_t as_u32 = vmovl_u16(a);
+  return vcvt_f32_u32(vget_low_u32(as_u32));
+}
+
+// uint16 -> uint64: zero-extend the first two values to 64 bits.
+template <>
+struct type_casting_traits<numext::uint16_t, numext::uint64_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 4 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2ul pcast<Packet8us, Packet2ul>(const Packet8us& a) {
+  const uint32x4_t as_u32 = vmovl_u16(vget_low_u16(a));
+  return vmovl_u32(vget_low_u32(as_u32));
+}
+
+// uint16 -> int64: unsigned widening followed by a bit reinterpretation.
+template <>
+struct type_casting_traits<numext::uint16_t, numext::int64_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 4 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2l pcast<Packet8us, Packet2l>(const Packet8us& a) {
+  const Packet2ul widened = pcast<Packet8us, Packet2ul>(a);
+  return vreinterpretq_s64_u64(widened);
+}
+
+// uint16 -> uint32: zero-extend the low half of the packet.
+template <>
+struct type_casting_traits<numext::uint16_t, numext::uint32_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4ui pcast<Packet8us, Packet4ui>(const Packet8us& a) {
+  const uint16x4_t low_half = vget_low_u16(a);
+  return vmovl_u16(low_half);
+}
+template <>
+EIGEN_STRONG_INLINE Packet2ui pcast<Packet4us, Packet2ui>(const Packet4us& a) {
+  const uint32x4_t as_u32 = vmovl_u16(a);
+  return vget_low_u32(as_u32);
+}
+
+// uint16 -> int32: unsigned widening followed by a bit reinterpretation.
+template <>
+struct type_casting_traits<numext::uint16_t, numext::int32_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4i pcast<Packet8us, Packet4i>(const Packet8us& a) {
+  const Packet4ui widened = pcast<Packet8us, Packet4ui>(a);
+  return vreinterpretq_s32_u32(widened);
+}
+template <>
+EIGEN_STRONG_INLINE Packet2i pcast<Packet4us, Packet2i>(const Packet4us& a) {
+  const Packet2ui widened = pcast<Packet4us, Packet2ui>(a);
+  return vreinterpret_s32_u32(widened);
+}
+
+// uint16 -> uint16: identity.
+template <>
+struct type_casting_traits<numext::uint16_t, numext::uint16_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet8us pcast<Packet8us, Packet8us>(const Packet8us& a) {
+  return a;
+}
+template <>
+EIGEN_STRONG_INLINE Packet4us pcast<Packet4us, Packet4us>(const Packet4us& a) {
+  return a;
+}
+
+// uint16 -> int16: same bits, reinterpreted as signed.
+template <>
+struct type_casting_traits<numext::uint16_t, numext::int16_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet8s pcast<Packet8us, Packet8s>(const Packet8us& a) {
+  const Packet8s as_signed = vreinterpretq_s16_u16(a);
+  return as_signed;
+}
+template <>
+EIGEN_STRONG_INLINE Packet4s pcast<Packet4us, Packet4s>(const Packet4us& a) {
+  const Packet4s as_signed = vreinterpret_s16_u16(a);
+  return as_signed;
+}
+
+// uint16 -> uint8: two packets are narrowed (low 8 bits kept) into one.
+template <>
+struct type_casting_traits<numext::uint16_t, numext::uint8_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet16uc pcast<Packet8us, Packet16uc>(const Packet8us& a, const Packet8us& b) {
+  const uint8x8_t lo = vmovn_u16(a);
+  const uint8x8_t hi = vmovn_u16(b);
+  return vcombine_u8(lo, hi);
+}
+template <>
+EIGEN_STRONG_INLINE Packet8uc pcast<Packet4us, Packet8uc>(const Packet4us& a, const Packet4us& b) {
+  const uint16x8_t both = vcombine_u16(a, b);
+  return vmovn_u16(both);
+}
+
+// uint16 -> int8: narrow as unsigned, then reinterpret the bytes as signed.
+template <>
+struct type_casting_traits<numext::uint16_t, numext::int8_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet16c pcast<Packet8us, Packet16c>(const Packet8us& a, const Packet8us& b) {
+  const Packet16uc narrowed = pcast<Packet8us, Packet16uc>(a, b);
+  return vreinterpretq_s8_u8(narrowed);
+}
+template <>
+EIGEN_STRONG_INLINE Packet8c pcast<Packet4us, Packet8c>(const Packet4us& a, const Packet4us& b) {
+  const Packet8uc narrowed = pcast<Packet4us, Packet8uc>(a, b);
+  return vreinterpret_s8_u8(narrowed);
+}
+
+//==============================================================================
+// pcast, SrcType = int32_t
+//==============================================================================
+// int32 -> float: lane-for-lane conversion.
+template <>
+struct type_casting_traits<numext::int32_t, float> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4f pcast<Packet4i, Packet4f>(const Packet4i& a) {
+  const Packet4f converted = vcvtq_f32_s32(a);
+  return converted;
+}
+template <>
+EIGEN_STRONG_INLINE Packet2f pcast<Packet2i, Packet2f>(const Packet2i& a) {
+  const Packet2f converted = vcvt_f32_s32(a);
+  return converted;
+}
+
+// int32 -> int64: sign-extend the low two lanes; the high two are dropped.
+template <>
+struct type_casting_traits<numext::int32_t, numext::int64_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2l pcast<Packet4i, Packet2l>(const Packet4i& a) {
+  const int32x2_t low_half = vget_low_s32(a);
+  return vmovl_s32(low_half);
+}
+
+// int32 -> uint64: signed widening followed by a bit reinterpretation.
+template <>
+struct type_casting_traits<numext::int32_t, numext::uint64_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2ul pcast<Packet4i, Packet2ul>(const Packet4i& a) {
+  const Packet2l widened = pcast<Packet4i, Packet2l>(a);
+  return vreinterpretq_u64_s64(widened);
+}
+
+// int32 -> int32: identity.
+template <>
+struct type_casting_traits<numext::int32_t, numext::int32_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4i pcast<Packet4i, Packet4i>(const Packet4i& a) {
+  return a;
+}
+template <>
+EIGEN_STRONG_INLINE Packet2i pcast<Packet2i, Packet2i>(const Packet2i& a) {
+  return a;
+}
+
+// int32 -> uint32: same bits, reinterpreted as unsigned.
+template <>
+struct type_casting_traits<numext::int32_t, numext::uint32_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4ui pcast<Packet4i, Packet4ui>(const Packet4i& a) {
+  const Packet4ui as_unsigned = vreinterpretq_u32_s32(a);
+  return as_unsigned;
+}
+template <>
+EIGEN_STRONG_INLINE Packet2ui pcast<Packet2i, Packet2ui>(const Packet2i& a) {
+  const Packet2ui as_unsigned = vreinterpret_u32_s32(a);
+  return as_unsigned;
+}
+
+// int32 -> int16: two packets are narrowed (low 16 bits kept) into one.
+template <>
+struct type_casting_traits<numext::int32_t, numext::int16_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet8s pcast<Packet4i, Packet8s>(const Packet4i& a, const Packet4i& b) {
+  const int16x4_t lo = vmovn_s32(a);
+  const int16x4_t hi = vmovn_s32(b);
+  return vcombine_s16(lo, hi);
+}
+template <>
+EIGEN_STRONG_INLINE Packet4s pcast<Packet2i, Packet4s>(const Packet2i& a, const Packet2i& b) {
+  const int32x4_t both = vcombine_s32(a, b);
+  return vmovn_s32(both);
+}
+
+// int32 -> uint16: reinterpret as unsigned, then narrow two packets into one.
+template <>
+struct type_casting_traits<numext::int32_t, numext::uint16_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet8us pcast<Packet4i, Packet8us>(const Packet4i& a, const Packet4i& b) {
+  const uint16x4_t lo = vmovn_u32(vreinterpretq_u32_s32(a));
+  const uint16x4_t hi = vmovn_u32(vreinterpretq_u32_s32(b));
+  return vcombine_u16(lo, hi);
+}
+template <>
+EIGEN_STRONG_INLINE Packet4us pcast<Packet2i, Packet4us>(const Packet2i& a, const Packet2i& b) {
+  const uint32x4_t both = vreinterpretq_u32_s32(vcombine_s32(a, b));
+  return vmovn_u32(both);
+}
+
+// int32 -> int8: four packets are narrowed 32 -> 16 -> 8 bits into one.
+template <>
+struct type_casting_traits<numext::int32_t, numext::int8_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 4, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet16c pcast<Packet4i, Packet16c>(const Packet4i& a, const Packet4i& b, const Packet4i& c,
+                                                         const Packet4i& d) {
+  const Packet8s first_half = pcast<Packet4i, Packet8s>(a, b);
+  const Packet8s second_half = pcast<Packet4i, Packet8s>(c, d);
+  return vcombine_s8(vmovn_s16(first_half), vmovn_s16(second_half));
+}
+template <>
+EIGEN_STRONG_INLINE Packet8c pcast<Packet2i, Packet8c>(const Packet2i& a, const Packet2i& b, const Packet2i& c,
+                                                       const Packet2i& d) {
+  const Packet4s first_half = pcast<Packet2i, Packet4s>(a, b);
+  const Packet4s second_half = pcast<Packet2i, Packet4s>(c, d);
+  return vmovn_s16(vcombine_s16(first_half, second_half));
+}
+
+// int32 -> uint8: four packets are narrowed 32 -> 16 -> 8 bits into one.
+template <>
+struct type_casting_traits<numext::int32_t, numext::uint8_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 4, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet16uc pcast<Packet4i, Packet16uc>(const Packet4i& a, const Packet4i& b, const Packet4i& c,
+                                                           const Packet4i& d) {
+  const Packet8us first_half = pcast<Packet4i, Packet8us>(a, b);
+  const Packet8us second_half = pcast<Packet4i, Packet8us>(c, d);
+  return vcombine_u8(vmovn_u16(first_half), vmovn_u16(second_half));
+}
+template <>
+EIGEN_STRONG_INLINE Packet8uc pcast<Packet2i, Packet8uc>(const Packet2i& a, const Packet2i& b, const Packet2i& c,
+                                                         const Packet2i& d) {
+  const Packet4us first_half = pcast<Packet2i, Packet4us>(a, b);
+  const Packet4us second_half = pcast<Packet2i, Packet4us>(c, d);
+  return vmovn_u16(vcombine_u16(first_half, second_half));
+}
+
+//==============================================================================
+// pcast, SrcType = uint32_t
+//==============================================================================
+// uint32 -> float: lane-for-lane conversion.
+template <>
+struct type_casting_traits<numext::uint32_t, float> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4f pcast<Packet4ui, Packet4f>(const Packet4ui& a) {
+  const Packet4f converted = vcvtq_f32_u32(a);
+  return converted;
+}
+template <>
+EIGEN_STRONG_INLINE Packet2f pcast<Packet2ui, Packet2f>(const Packet2ui& a) {
+  const Packet2f converted = vcvt_f32_u32(a);
+  return converted;
+}
+
+// uint32 -> uint64: zero-extend the low two lanes; the high two are dropped.
+template <>
+struct type_casting_traits<numext::uint32_t, numext::uint64_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2ul pcast<Packet4ui, Packet2ul>(const Packet4ui& a) {
+  const uint32x2_t low_half = vget_low_u32(a);
+  return vmovl_u32(low_half);
+}
+
+// uint32 -> int64: unsigned widening followed by a bit reinterpretation.
+template <>
+struct type_casting_traits<numext::uint32_t, numext::int64_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2l pcast<Packet4ui, Packet2l>(const Packet4ui& a) {
+  const Packet2ul widened = pcast<Packet4ui, Packet2ul>(a);
+  return vreinterpretq_s64_u64(widened);
+}
+
+// uint32 -> uint32: identity.
+template <>
+struct type_casting_traits<numext::uint32_t, numext::uint32_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4ui pcast<Packet4ui, Packet4ui>(const Packet4ui& a) {
+  return a;
+}
+template <>
+EIGEN_STRONG_INLINE Packet2ui pcast<Packet2ui, Packet2ui>(const Packet2ui& a) {
+  return a;
+}
+
+// uint32 -> int32: same bits, reinterpreted as signed.
+template <>
+struct type_casting_traits<numext::uint32_t, numext::int32_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4i pcast<Packet4ui, Packet4i>(const Packet4ui& a) {
+  const Packet4i as_signed = vreinterpretq_s32_u32(a);
+  return as_signed;
+}
+template <>
+EIGEN_STRONG_INLINE Packet2i pcast<Packet2ui, Packet2i>(const Packet2ui& a) {
+  const Packet2i as_signed = vreinterpret_s32_u32(a);
+  return as_signed;
+}
+
+// uint32 -> uint16: two packets are narrowed (low 16 bits kept) into one.
+template <>
+struct type_casting_traits<numext::uint32_t, numext::uint16_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet8us pcast<Packet4ui, Packet8us>(const Packet4ui& a, const Packet4ui& b) {
+  const uint16x4_t lo = vmovn_u32(a);
+  const uint16x4_t hi = vmovn_u32(b);
+  return vcombine_u16(lo, hi);
+}
+template <>
+EIGEN_STRONG_INLINE Packet4us pcast<Packet2ui, Packet4us>(const Packet2ui& a, const Packet2ui& b) {
+  const uint32x4_t both = vcombine_u32(a, b);
+  return vmovn_u32(both);
+}
+
+// uint32 -> int16: narrow as unsigned, then reinterpret as signed.
+template <>
+struct type_casting_traits<numext::uint32_t, numext::int16_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet8s pcast<Packet4ui, Packet8s>(const Packet4ui& a, const Packet4ui& b) {
+  const Packet8us narrowed = pcast<Packet4ui, Packet8us>(a, b);
+  return vreinterpretq_s16_u16(narrowed);
+}
+template <>
+EIGEN_STRONG_INLINE Packet4s pcast<Packet2ui, Packet4s>(const Packet2ui& a, const Packet2ui& b) {
+  const Packet4us narrowed = pcast<Packet2ui, Packet4us>(a, b);
+  return vreinterpret_s16_u16(narrowed);
+}
+
+// uint32 -> uint8: four packets are narrowed 32 -> 16 -> 8 bits into one.
+template <>
+struct type_casting_traits<numext::uint32_t, numext::uint8_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 4, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet16uc pcast<Packet4ui, Packet16uc>(const Packet4ui& a, const Packet4ui& b, const Packet4ui& c,
+                                                            const Packet4ui& d) {
+  const Packet8us first_half = pcast<Packet4ui, Packet8us>(a, b);
+  const Packet8us second_half = pcast<Packet4ui, Packet8us>(c, d);
+  return vcombine_u8(vmovn_u16(first_half), vmovn_u16(second_half));
+}
+template <>
+EIGEN_STRONG_INLINE Packet8uc pcast<Packet2ui, Packet8uc>(const Packet2ui& a, const Packet2ui& b, const Packet2ui& c,
+                                                          const Packet2ui& d) {
+  const Packet4us first_half = pcast<Packet2ui, Packet4us>(a, b);
+  const Packet4us second_half = pcast<Packet2ui, Packet4us>(c, d);
+  return vmovn_u16(vcombine_u16(first_half, second_half));
+}
+
+// uint32 -> int8: narrow as unsigned, then reinterpret the bytes as signed.
+template <>
+struct type_casting_traits<numext::uint32_t, numext::int8_t> {
+  enum { VectorizedCast = 1, SrcCoeffRatio = 4, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet16c pcast<Packet4ui, Packet16c>(const Packet4ui& a, const Packet4ui& b, const Packet4ui& c,
+                                                          const Packet4ui& d) {
+  const Packet16uc narrowed = pcast<Packet4ui, Packet16uc>(a, b, c, d);
+  return vreinterpretq_s8_u8(narrowed);
+}
+template <>
+EIGEN_STRONG_INLINE Packet8c pcast<Packet2ui, Packet8c>(const Packet2ui& a, const Packet2ui& b, const Packet2ui& c,
+                                                        const Packet2ui& d) {
+  const Packet8uc narrowed = pcast<Packet2ui, Packet8uc>(a, b, c, d);
+  return vreinterpret_s8_u8(narrowed);
+}
+
+//==============================================================================
+// pcast, SrcType = int64_t
+//==============================================================================
+// int64 -> float: two Packet2l inputs fill one Packet4f.
+template <>
+struct type_casting_traits<numext::int64_t, float> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4f pcast<Packet2l, Packet4f>(const Packet2l& a, const Packet2l& b) {
+ // NOTE(review): each int64 lane is first narrowed to int32 (vmovn_s64 keeps the
+ // low 32 bits), so inputs outside the int32 range wrap before the float
+ // conversion and differ from static_cast<float>(x). Confirm this is acceptable
+ // for callers, or report upstream.
+ return vcvtq_f32_s32(vcombine_s32(vmovn_s64(a), vmovn_s64(b)));
+}
+
+// int64 -> int64: identity.
+template <>
+struct type_casting_traits<numext::int64_t, numext::int64_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2l pcast<Packet2l, Packet2l>(const Packet2l& a) {
+ return a;
+}
+
+// int64 -> uint64: same lane width, so the bits are only reinterpreted.
+template <>
+struct type_casting_traits<numext::int64_t, numext::uint64_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2ul pcast<Packet2l, Packet2ul>(const Packet2l& a) {
+ return vreinterpretq_u64_s64(a);
+}
+
+// int64 -> int32: two source packets fill one target; each lane keeps its low 32 bits.
+template <>
+struct type_casting_traits<numext::int64_t, numext::int32_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4i pcast<Packet2l, Packet4i>(const Packet2l& a, const Packet2l& b) {
+ return vcombine_s32(vmovn_s64(a), vmovn_s64(b));
+}
+
+// int64 -> uint32: reinterpret as unsigned, then keep the low 32 bits of each lane.
+template <>
+struct type_casting_traits<numext::int64_t, numext::uint32_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4ui pcast<Packet2l, Packet4ui>(const Packet2l& a, const Packet2l& b) {
+ return vcombine_u32(vmovn_u64(vreinterpretq_u64_s64(a)), vmovn_u64(vreinterpretq_u64_s64(b)));
+}
+
+// int64 -> int16: four source packets; chain the int64 -> int32 cast with a
+// 32 -> 16 bit narrowing.
+template <>
+struct type_casting_traits<numext::int64_t, numext::int16_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 4, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet8s pcast<Packet2l, Packet8s>(const Packet2l& a, const Packet2l& b, const Packet2l& c,
+ const Packet2l& d) {
+ const int32x4_t ab_s32 = pcast<Packet2l, Packet4i>(a, b);
+ const int32x4_t cd_s32 = pcast<Packet2l, Packet4i>(c, d);
+ return vcombine_s16(vmovn_s32(ab_s32), vmovn_s32(cd_s32));
+}
+
+// int64 -> uint16: four source packets; chain the int64 -> uint32 cast with a
+// 32 -> 16 bit narrowing.
+template <>
+struct type_casting_traits<numext::int64_t, numext::uint16_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 4, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet8us pcast<Packet2l, Packet8us>(const Packet2l& a, const Packet2l& b, const Packet2l& c,
+ const Packet2l& d) {
+ const uint32x4_t ab_u32 = pcast<Packet2l, Packet4ui>(a, b);
+ const uint32x4_t cd_u32 = pcast<Packet2l, Packet4ui>(c, d);
+ return vcombine_u16(vmovn_u32(ab_u32), vmovn_u32(cd_u32));
+}
+
+// int64 -> int8: eight source packets; chain int64 -> int16, then narrow 16 -> 8 bits.
+template <>
+struct type_casting_traits<numext::int64_t, numext::int8_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 8, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet16c pcast<Packet2l, Packet16c>(const Packet2l& a, const Packet2l& b, const Packet2l& c,
+ const Packet2l& d, const Packet2l& e, const Packet2l& f,
+ const Packet2l& g, const Packet2l& h) {
+ const int16x8_t abcd_s16 = pcast<Packet2l, Packet8s>(a, b, c, d);
+ const int16x8_t efgh_s16 = pcast<Packet2l, Packet8s>(e, f, g, h);
+ return vcombine_s8(vmovn_s16(abcd_s16), vmovn_s16(efgh_s16));
+}
+
+// int64 -> uint8: eight source packets; chain int64 -> uint16, then narrow 16 -> 8 bits.
+template <>
+struct type_casting_traits<numext::int64_t, numext::uint8_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 8, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet16uc pcast<Packet2l, Packet16uc>(const Packet2l& a, const Packet2l& b, const Packet2l& c,
+ const Packet2l& d, const Packet2l& e, const Packet2l& f,
+ const Packet2l& g, const Packet2l& h) {
+ const uint16x8_t abcd_u16 = pcast<Packet2l, Packet8us>(a, b, c, d);
+ const uint16x8_t efgh_u16 = pcast<Packet2l, Packet8us>(e, f, g, h);
+ return vcombine_u8(vmovn_u16(abcd_u16), vmovn_u16(efgh_u16));
+}
+
+//==============================================================================
+// pcast, SrcType = uint64_t
+//==============================================================================
+// uint64 -> float: two Packet2ul inputs fill one Packet4f.
+template <>
+struct type_casting_traits<numext::uint64_t, float> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4f pcast<Packet2ul, Packet4f>(const Packet2ul& a, const Packet2ul& b) {
+ // NOTE(review): lanes are narrowed to uint32 (vmovn_u64 keeps the low 32 bits)
+ // before conversion, so inputs above UINT32_MAX wrap and differ from
+ // static_cast<float>(x). Confirm this is acceptable for callers.
+ return vcvtq_f32_u32(vcombine_u32(vmovn_u64(a), vmovn_u64(b)));
+}
+
+// uint64 -> uint64: identity.
+template <>
+struct type_casting_traits<numext::uint64_t, numext::uint64_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2ul pcast<Packet2ul, Packet2ul>(const Packet2ul& a) {
+ return a;
+}
+
+// uint64 -> int64: same lane width, so the bits are only reinterpreted.
+template <>
+struct type_casting_traits<numext::uint64_t, numext::int64_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2l pcast<Packet2ul, Packet2l>(const Packet2ul& a) {
+ return vreinterpretq_s64_u64(a);
+}
+
+// uint64 -> uint32: two source packets fill one target; each lane keeps its low 32 bits.
+template <>
+struct type_casting_traits<numext::uint64_t, numext::uint32_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4ui pcast<Packet2ul, Packet4ui>(const Packet2ul& a, const Packet2ul& b) {
+ return vcombine_u32(vmovn_u64(a), vmovn_u64(b));
+}
+
+// uint64 -> int32: the uint32 narrowing above, reinterpreted as signed.
+template <>
+struct type_casting_traits<numext::uint64_t, numext::int32_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4i pcast<Packet2ul, Packet4i>(const Packet2ul& a, const Packet2ul& b) {
+ return vreinterpretq_s32_u32(pcast<Packet2ul, Packet4ui>(a, b));
+}
+
+// uint64 -> uint16: four source packets; narrow 64 -> 32 -> 16 bits.
+template <>
+struct type_casting_traits<numext::uint64_t, numext::uint16_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 4, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet8us pcast<Packet2ul, Packet8us>(const Packet2ul& a, const Packet2ul& b, const Packet2ul& c,
+ const Packet2ul& d) {
+ const uint16x4_t ab_u16 = vmovn_u32(vcombine_u32(vmovn_u64(a), vmovn_u64(b)));
+ const uint16x4_t cd_u16 = vmovn_u32(vcombine_u32(vmovn_u64(c), vmovn_u64(d)));
+ return vcombine_u16(ab_u16, cd_u16);
+}
+
+// uint64 -> int16: the uint16 narrowing above, reinterpreted as signed.
+template <>
+struct type_casting_traits<numext::uint64_t, numext::int16_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 4, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet8s pcast<Packet2ul, Packet8s>(const Packet2ul& a, const Packet2ul& b, const Packet2ul& c,
+ const Packet2ul& d) {
+ return vreinterpretq_s16_u16(pcast<Packet2ul, Packet8us>(a, b, c, d));
+}
+
+// uint64 -> uint8: eight source packets; chain uint64 -> uint16, then narrow 16 -> 8 bits.
+template <>
+struct type_casting_traits<numext::uint64_t, numext::uint8_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 8, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet16uc pcast<Packet2ul, Packet16uc>(const Packet2ul& a, const Packet2ul& b, const Packet2ul& c,
+ const Packet2ul& d, const Packet2ul& e, const Packet2ul& f,
+ const Packet2ul& g, const Packet2ul& h) {
+ const uint16x8_t abcd_u16 = pcast<Packet2ul, Packet8us>(a, b, c, d);
+ const uint16x8_t efgh_u16 = pcast<Packet2ul, Packet8us>(e, f, g, h);
+ return vcombine_u8(vmovn_u16(abcd_u16), vmovn_u16(efgh_u16));
+}
+
+// uint64 -> int8: the uint8 narrowing above, reinterpreted as signed.
+template <>
+struct type_casting_traits<numext::uint64_t, numext::int8_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 8, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet16c pcast<Packet2ul, Packet16c>(const Packet2ul& a, const Packet2ul& b, const Packet2ul& c,
+ const Packet2ul& d, const Packet2ul& e, const Packet2ul& f,
+ const Packet2ul& g, const Packet2ul& h) {
+ return vreinterpretq_s8_u8(pcast<Packet2ul, Packet16uc>(a, b, c, d, e, f, g, h));
+}
+
+//==============================================================================
+// preinterpret
+//==============================================================================
+// preinterpret<To, From> reinterprets the bits of a packet as another packet type
+// of the same width (vreinterpret*); unlike pcast, no numeric conversion happens.
+template <>
+EIGEN_STRONG_INLINE Packet2f preinterpret<Packet2f, Packet2i>(const Packet2i& a) {
+ return vreinterpret_f32_s32(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet2f preinterpret<Packet2f, Packet2ui>(const Packet2ui& a) {
+ return vreinterpret_f32_u32(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet4f preinterpret<Packet4f, Packet4i>(const Packet4i& a) {
+ return vreinterpretq_f32_s32(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet4f preinterpret<Packet4f, Packet4ui>(const Packet4ui& a) {
+ return vreinterpretq_f32_u32(a);
+}
+
+// NOTE(review): the 4-lane 8-bit packets use static_cast rather than a vreinterpret
+// intrinsic; presumably Packet4c/Packet4uc are scalar-backed wrapper types declared
+// outside this chunk. Verify the cast preserves the bit pattern.
+template <>
+EIGEN_STRONG_INLINE Packet4c preinterpret<Packet4c, Packet4uc>(const Packet4uc& a) {
+ return static_cast<Packet4c>(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet8c preinterpret<Packet8c, Packet8uc>(const Packet8uc& a) {
+ return vreinterpret_s8_u8(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet16c preinterpret<Packet16c, Packet16uc>(const Packet16uc& a) {
+ return vreinterpretq_s8_u8(a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4uc preinterpret<Packet4uc, Packet4c>(const Packet4c& a) {
+ return static_cast<Packet4uc>(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet8uc preinterpret<Packet8uc, Packet8c>(const Packet8c& a) {
+ return vreinterpret_u8_s8(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet16uc preinterpret<Packet16uc, Packet16c>(const Packet16c& a) {
+ return vreinterpretq_u8_s8(a);
+}
+
+// 16-bit lanes: signed <-> unsigned.
+template <>
+EIGEN_STRONG_INLINE Packet4s preinterpret<Packet4s, Packet4us>(const Packet4us& a) {
+ return vreinterpret_s16_u16(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet8s preinterpret<Packet8s, Packet8us>(const Packet8us& a) {
+ return vreinterpretq_s16_u16(a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4us preinterpret<Packet4us, Packet4s>(const Packet4s& a) {
+ return vreinterpret_u16_s16(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet8us preinterpret<Packet8us, Packet8s>(const Packet8s& a) {
+ return vreinterpretq_u16_s16(a);
+}
+
+// 32-bit lanes: float / int32 / uint32 in every direction.
+template <>
+EIGEN_STRONG_INLINE Packet2i preinterpret<Packet2i, Packet2f>(const Packet2f& a) {
+ return vreinterpret_s32_f32(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet2i preinterpret<Packet2i, Packet2ui>(const Packet2ui& a) {
+ return vreinterpret_s32_u32(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet4i preinterpret<Packet4i, Packet4f>(const Packet4f& a) {
+ return vreinterpretq_s32_f32(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet4i preinterpret<Packet4i, Packet4ui>(const Packet4ui& a) {
+ return vreinterpretq_s32_u32(a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2ui preinterpret<Packet2ui, Packet2f>(const Packet2f& a) {
+ return vreinterpret_u32_f32(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet2ui preinterpret<Packet2ui, Packet2i>(const Packet2i& a) {
+ return vreinterpret_u32_s32(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet4ui preinterpret<Packet4ui, Packet4f>(const Packet4f& a) {
+ return vreinterpretq_u32_f32(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet4ui preinterpret<Packet4ui, Packet4i>(const Packet4i& a) {
+ return vreinterpretq_u32_s32(a);
+}
+
+// 64-bit lanes: int64 <-> uint64.
+template <>
+EIGEN_STRONG_INLINE Packet2l preinterpret<Packet2l, Packet2ul>(const Packet2ul& a) {
+ return vreinterpretq_s64_u64(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet2ul preinterpret<Packet2ul, Packet2l>(const Packet2l& a) {
+ return vreinterpretq_u64_s64(a);
+}
+
+#if EIGEN_ARCH_ARM64
+
+//==============================================================================
+// pcast/preinterpret, Double
+//==============================================================================
+// This section is compiled only when EIGEN_ARCH_ARM64 is set; it covers casts
+// from and to Packet2d (two doubles) using float64 vector intrinsics.
+
+// double -> double: identity.
+template <>
+struct type_casting_traits<double, double> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2d pcast<Packet2d, Packet2d>(const Packet2d& a) {
+ return a;
+}
+
+// double -> float: each vcvt_f32_f64 turns one Packet2d into two floats, so two
+// source packets fill one Packet4f.
+template <>
+struct type_casting_traits<double, float> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4f pcast<Packet2d, Packet4f>(const Packet2d& a, const Packet2d& b) {
+ return vcombine_f32(vcvt_f32_f64(a), vcvt_f32_f64(b));
+}
+
+// double -> int64: lane-wise conversion (vcvtq_s64_f64).
+template <>
+struct type_casting_traits<double, numext::int64_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2l pcast<Packet2d, Packet2l>(const Packet2d& a) {
+ return vcvtq_s64_f64(a);
+}
+
+// double -> uint64: lane-wise conversion (vcvtq_u64_f64).
+template <>
+struct type_casting_traits<double, numext::uint64_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2ul pcast<Packet2d, Packet2ul>(const Packet2d& a) {
+ return vcvtq_u64_f64(a);
+}
+
+// double -> int32: convert to int64, then keep the low 32 bits of each lane.
+template <>
+struct type_casting_traits<double, numext::int32_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4i pcast<Packet2d, Packet4i>(const Packet2d& a, const Packet2d& b) {
+ return vcombine_s32(vmovn_s64(vcvtq_s64_f64(a)), vmovn_s64(vcvtq_s64_f64(b)));
+}
+
+// double -> uint32: convert to uint64, then keep the low 32 bits of each lane.
+template <>
+struct type_casting_traits<double, numext::uint32_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4ui pcast<Packet2d, Packet4ui>(const Packet2d& a, const Packet2d& b) {
+ return vcombine_u32(vmovn_u64(vcvtq_u64_f64(a)), vmovn_u64(vcvtq_u64_f64(b)));
+}
+
+// double -> int16: chain double -> int32, then narrow 32 -> 16 bits.
+template <>
+struct type_casting_traits<double, numext::int16_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 4, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet8s pcast<Packet2d, Packet8s>(const Packet2d& a, const Packet2d& b, const Packet2d& c,
+ const Packet2d& d) {
+ const int32x4_t ab_s32 = pcast<Packet2d, Packet4i>(a, b);
+ const int32x4_t cd_s32 = pcast<Packet2d, Packet4i>(c, d);
+ return vcombine_s16(vmovn_s32(ab_s32), vmovn_s32(cd_s32));
+}
+
+// double -> uint16: chain double -> uint32, then narrow 32 -> 16 bits.
+template <>
+struct type_casting_traits<double, numext::uint16_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 4, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet8us pcast<Packet2d, Packet8us>(const Packet2d& a, const Packet2d& b, const Packet2d& c,
+ const Packet2d& d) {
+ const uint32x4_t ab_u32 = pcast<Packet2d, Packet4ui>(a, b);
+ const uint32x4_t cd_u32 = pcast<Packet2d, Packet4ui>(c, d);
+ return vcombine_u16(vmovn_u32(ab_u32), vmovn_u32(cd_u32));
+}
+
+// double -> int8: eight source packets; chain double -> int16, then narrow 16 -> 8 bits.
+template <>
+struct type_casting_traits<double, numext::int8_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 8, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet16c pcast<Packet2d, Packet16c>(const Packet2d& a, const Packet2d& b, const Packet2d& c,
+ const Packet2d& d, const Packet2d& e, const Packet2d& f,
+ const Packet2d& g, const Packet2d& h) {
+ const int16x8_t abcd_s16 = pcast<Packet2d, Packet8s>(a, b, c, d);
+ const int16x8_t efgh_s16 = pcast<Packet2d, Packet8s>(e, f, g, h);
+ return vcombine_s8(vmovn_s16(abcd_s16), vmovn_s16(efgh_s16));
+}
+
+// double -> uint8: eight source packets; chain double -> uint16, then narrow 16 -> 8 bits.
+template <>
+struct type_casting_traits<double, numext::uint8_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 8, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet16uc pcast<Packet2d, Packet16uc>(const Packet2d& a, const Packet2d& b, const Packet2d& c,
+ const Packet2d& d, const Packet2d& e, const Packet2d& f,
+ const Packet2d& g, const Packet2d& h) {
+ const uint16x8_t abcd_u16 = pcast<Packet2d, Packet8us>(a, b, c, d);
+ const uint16x8_t efgh_u16 = pcast<Packet2d, Packet8us>(e, f, g, h);
+ return vcombine_u8(vmovn_u16(abcd_u16), vmovn_u16(efgh_u16));
+}
+
+// float -> double: TgtCoeffRatio = 2, and each call converts only the low two
+// floats of the input.
+template <>
+struct type_casting_traits<float, double> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2d pcast<Packet4f, Packet2d>(const Packet4f& a) {
+ // Discard second half of input.
+ return vcvt_f64_f32(vget_low_f32(a));
+}
+
+// int8 -> double: converts via float (which represents every int8 exactly), then widens.
+template <>
+struct type_casting_traits<numext::int8_t, double> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 8 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2d pcast<Packet16c, Packet2d>(const Packet16c& a) {
+ // Discard all but first two values.
+ return vcvt_f64_f32(pcast<Packet8c, Packet2f>(vget_low_s8(a)));
+}
+
+// uint8 -> double: converts via float, then widens.
+template <>
+struct type_casting_traits<numext::uint8_t, double> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 8 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2d pcast<Packet16uc, Packet2d>(const Packet16uc& a) {
+ // Discard all but first two values.
+ return vcvt_f64_f32(pcast<Packet8uc, Packet2f>(vget_low_u8(a)));
+}
+
+// int16 -> double: converts via float (exact for 16-bit integers), then widens.
+template <>
+struct type_casting_traits<numext::int16_t, double> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 4 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2d pcast<Packet8s, Packet2d>(const Packet8s& a) {
+ // Discard all but first two values.
+ return vcvt_f64_f32(pcast<Packet4s, Packet2f>(vget_low_s16(a)));
+}
+
+// uint16 -> double: converts via float (exact for 16-bit integers), then widens.
+template <>
+struct type_casting_traits<numext::uint16_t, double> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 4 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2d pcast<Packet8us, Packet2d>(const Packet8us& a) {
+ // Discard all but first two values.
+ return vcvt_f64_f32(pcast<Packet4us, Packet2f>(vget_low_u16(a)));
+}
+
+// int32 -> double: sign-extend the low two lanes to int64 (vmovl_s32), then convert.
+template <>
+struct type_casting_traits<numext::int32_t, double> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2d pcast<Packet4i, Packet2d>(const Packet4i& a) {
+ // Discard second half of input.
+ return vcvtq_f64_s64(vmovl_s32(vget_low_s32(a)));
+}
+
+// uint32 -> double: zero-extend the low two lanes to uint64 (vmovl_u32), then convert.
+template <>
+struct type_casting_traits<numext::uint32_t, double> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2d pcast<Packet4ui, Packet2d>(const Packet4ui& a) {
+ // Discard second half of input.
+ return vcvtq_f64_u64(vmovl_u32(vget_low_u32(a)));
+}
+
+// int64 -> double: direct lane-wise conversion.
+template <>
+struct type_casting_traits<numext::int64_t, double> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2d pcast<Packet2l, Packet2d>(const Packet2l& a) {
+ return vcvtq_f64_s64(a);
+}
+
+// uint64 -> double: direct lane-wise conversion.
+template <>
+struct type_casting_traits<numext::uint64_t, double> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2d pcast<Packet2ul, Packet2d>(const Packet2ul& a) {
+ return vcvtq_f64_u64(a);
+}
+
+// Bit reinterpretations between Packet2d and same-width (128-bit) integer packets;
+// no numeric conversion is performed.
+template <>
+EIGEN_STRONG_INLINE Packet2d preinterpret<Packet2d, Packet2l>(const Packet2l& a) {
+ return vreinterpretq_f64_s64(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet2d preinterpret<Packet2d, Packet2ul>(const Packet2ul& a) {
+ return vreinterpretq_f64_u64(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet2l preinterpret<Packet2l, Packet2d>(const Packet2d& a) {
+ return vreinterpretq_s64_f64(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet2ul preinterpret<Packet2ul, Packet2d>(const Packet2d& a) {
+ return vreinterpretq_u64_f64(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet2d preinterpret<Packet2d, Packet4i>(const Packet4i& a) {
+ return vreinterpretq_f64_s32(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet4i preinterpret<Packet4i, Packet2d>(const Packet2d& a) {
+ return vreinterpretq_s32_f64(a);
+}
+
+#endif // EIGEN_ARCH_ARM64
+
+} // end namespace internal
+
+} // end namespace Eigen
+
+#endif // EIGEN_TYPE_CASTING_NEON_H
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/SSE/Complex.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/SSE/Complex.h
index d075043..8fe22da 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/SSE/Complex.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/SSE/Complex.h
@@ -19,7 +19,7 @@
{
EIGEN_STRONG_INLINE Packet2cf() {}
EIGEN_STRONG_INLINE explicit Packet2cf(const __m128& a) : v(a) {}
- __m128 v;
+ Packet4f v;
};
// Use the packet_traits defined in AVX/PacketMath.h instead if we're going
@@ -40,20 +40,33 @@
HasMul = 1,
HasDiv = 1,
HasNegate = 1,
+ HasSqrt = 1,
HasAbs = 0,
HasAbs2 = 0,
HasMin = 0,
HasMax = 0,
HasSetLinear = 0,
- HasBlend = 1
+ HasBlend = 1
};
};
#endif
-template<> struct unpacket_traits<Packet2cf> { typedef std::complex<float> type; enum {size=2, alignment=Aligned16}; typedef Packet2cf half; };
+// Packet2cf holds two std::complex<float> values; as_real names the Packet4f type
+// that stores their real and imaginary parts.
+template<> struct unpacket_traits<Packet2cf> {
+ typedef std::complex<float> type;
+ typedef Packet2cf half;
+ typedef Packet4f as_real;
+ enum {
+ size=2,
+ alignment=Aligned16,
+ vectorizable=true,
+ masked_load_available=false,
+ masked_store_available=false
+ };
+};
+// Complex add/subtract are component-wise on the underlying floats.
template<> EIGEN_STRONG_INLINE Packet2cf padd<Packet2cf>(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(_mm_add_ps(a.v,b.v)); }
template<> EIGEN_STRONG_INLINE Packet2cf psub<Packet2cf>(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(_mm_sub_ps(a.v,b.v)); }
+
template<> EIGEN_STRONG_INLINE Packet2cf pnegate(const Packet2cf& a)
{
const __m128 mask = _mm_castsi128_ps(_mm_setr_epi32(0x80000000,0x80000000,0x80000000,0x80000000));
@@ -82,10 +95,11 @@
#endif
}
+// ptrue delegates to the Packet4f overload on the underlying register.
+template<> EIGEN_STRONG_INLINE Packet2cf ptrue <Packet2cf>(const Packet2cf& a) { return Packet2cf(ptrue(Packet4f(a.v))); }
template<> EIGEN_STRONG_INLINE Packet2cf pand <Packet2cf>(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(_mm_and_ps(a.v,b.v)); }
template<> EIGEN_STRONG_INLINE Packet2cf por <Packet2cf>(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(_mm_or_ps(a.v,b.v)); }
template<> EIGEN_STRONG_INLINE Packet2cf pxor <Packet2cf>(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(_mm_xor_ps(a.v,b.v)); }
-template<> EIGEN_STRONG_INLINE Packet2cf pandnot<Packet2cf>(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(_mm_andnot_ps(a.v,b.v)); }
+// pandnot(a, b) is a & ~b, while _mm_andnot_ps(x, y) computes ~x & y; hence b is passed first.
+template<> EIGEN_STRONG_INLINE Packet2cf pandnot<Packet2cf>(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(_mm_andnot_ps(b.v,a.v)); }
template<> EIGEN_STRONG_INLINE Packet2cf pload <Packet2cf>(const std::complex<float>* from) { EIGEN_DEBUG_ALIGNED_LOAD return Packet2cf(pload<Packet4f>(&numext::real_ref(*from))); }
template<> EIGEN_STRONG_INLINE Packet2cf ploadu<Packet2cf>(const std::complex<float>* from) { EIGEN_DEBUG_UNALIGNED_LOAD return Packet2cf(ploadu<Packet4f>(&numext::real_ref(*from))); }
@@ -93,19 +107,13 @@
+// Broadcasts one std::complex<float> into both complex lanes of a Packet2cf.
template<> EIGEN_STRONG_INLINE Packet2cf pset1<Packet2cf>(const std::complex<float>& from)
{
Packet2cf res;
-#if EIGEN_GNUC_AT_MOST(4,2)
- // Workaround annoying "may be used uninitialized in this function" warning with gcc 4.2
- res.v = _mm_loadl_pi(_mm_set1_ps(0.0f), reinterpret_cast<const __m64*>(&from));
-#elif EIGEN_GNUC_AT_LEAST(4,6)
- // Suppress annoying "may be used uninitialized in this function" warning with gcc >= 4.6
- #pragma GCC diagnostic push
- #pragma GCC diagnostic ignored "-Wuninitialized"
- res.v = _mm_loadl_pi(res.v, (const __m64*)&from);
- #pragma GCC diagnostic pop
+ // The 8-byte complex value is loaded as one double so it moves as a unit.
+#ifdef EIGEN_VECTORIZE_SSE3
+ res.v = _mm_castpd_ps(_mm_loaddup_pd(reinterpret_cast<double const*>(&from)));
#else
- res.v = _mm_loadl_pi(res.v, (const __m64*)&from);
+ // Without SSE3: load into the low 64 bits, then copy the low half into the high half.
+ res.v = _mm_castpd_ps(_mm_load_sd(reinterpret_cast<double const*>(&from)));
+ res.v = _mm_movelh_ps(res.v, res.v);
#endif
- return Packet2cf(_mm_movelh_ps(res.v,res.v));
+ return res;
}
template<> EIGEN_STRONG_INLINE Packet2cf ploaddup<Packet2cf>(const std::complex<float>* from) { return pset1<Packet2cf>(*from); }
@@ -152,105 +160,34 @@
return pfirst(Packet2cf(_mm_add_ps(a.v, _mm_movehl_ps(a.v,a.v))));
}
-template<> EIGEN_STRONG_INLINE Packet2cf preduxp<Packet2cf>(const Packet2cf* vecs)
-{
- return Packet2cf(_mm_add_ps(_mm_movelh_ps(vecs[0].v,vecs[1].v), _mm_movehl_ps(vecs[1].v,vecs[0].v)));
-}
-
+// Product of the two complex lanes: multiply by a copy whose high half was moved
+// into the low half (_mm_movehl_ps), then return the first lane.
template<> EIGEN_STRONG_INLINE std::complex<float> predux_mul<Packet2cf>(const Packet2cf& a)
{
return pfirst(pmul(a, Packet2cf(_mm_movehl_ps(a.v,a.v))));
}
-template<int Offset>
-struct palign_impl<Offset,Packet2cf>
-{
- static EIGEN_STRONG_INLINE void run(Packet2cf& first, const Packet2cf& second)
- {
- if (Offset==1)
- {
- first.v = _mm_movehl_ps(first.v, first.v);
- first.v = _mm_movelh_ps(first.v, second.v);
- }
- }
-};
-
-template<> struct conj_helper<Packet2cf, Packet2cf, false,true>
-{
- EIGEN_STRONG_INLINE Packet2cf pmadd(const Packet2cf& x, const Packet2cf& y, const Packet2cf& c) const
- { return padd(pmul(x,y),c); }
-
- EIGEN_STRONG_INLINE Packet2cf pmul(const Packet2cf& a, const Packet2cf& b) const
- {
- #ifdef EIGEN_VECTORIZE_SSE3
- return internal::pmul(a, pconj(b));
- #else
- const __m128 mask = _mm_castsi128_ps(_mm_setr_epi32(0x00000000,0x80000000,0x00000000,0x80000000));
- return Packet2cf(_mm_add_ps(_mm_xor_ps(_mm_mul_ps(vec4f_swizzle1(a.v, 0, 0, 2, 2), b.v), mask),
- _mm_mul_ps(vec4f_swizzle1(a.v, 1, 1, 3, 3),
- vec4f_swizzle1(b.v, 1, 0, 3, 2))));
- #endif
- }
-};
-
-template<> struct conj_helper<Packet2cf, Packet2cf, true,false>
-{
- EIGEN_STRONG_INLINE Packet2cf pmadd(const Packet2cf& x, const Packet2cf& y, const Packet2cf& c) const
- { return padd(pmul(x,y),c); }
-
- EIGEN_STRONG_INLINE Packet2cf pmul(const Packet2cf& a, const Packet2cf& b) const
- {
- #ifdef EIGEN_VECTORIZE_SSE3
- return internal::pmul(pconj(a), b);
- #else
- const __m128 mask = _mm_castsi128_ps(_mm_setr_epi32(0x00000000,0x80000000,0x00000000,0x80000000));
- return Packet2cf(_mm_add_ps(_mm_mul_ps(vec4f_swizzle1(a.v, 0, 0, 2, 2), b.v),
- _mm_xor_ps(_mm_mul_ps(vec4f_swizzle1(a.v, 1, 1, 3, 3),
- vec4f_swizzle1(b.v, 1, 0, 3, 2)), mask)));
- #endif
- }
-};
-
-template<> struct conj_helper<Packet2cf, Packet2cf, true,true>
-{
- EIGEN_STRONG_INLINE Packet2cf pmadd(const Packet2cf& x, const Packet2cf& y, const Packet2cf& c) const
- { return padd(pmul(x,y),c); }
-
- EIGEN_STRONG_INLINE Packet2cf pmul(const Packet2cf& a, const Packet2cf& b) const
- {
- #ifdef EIGEN_VECTORIZE_SSE3
- return pconj(internal::pmul(a, b));
- #else
- const __m128 mask = _mm_castsi128_ps(_mm_setr_epi32(0x00000000,0x80000000,0x00000000,0x80000000));
- return Packet2cf(_mm_sub_ps(_mm_xor_ps(_mm_mul_ps(vec4f_swizzle1(a.v, 0, 0, 2, 2), b.v), mask),
- _mm_mul_ps(vec4f_swizzle1(a.v, 1, 1, 3, 3),
- vec4f_swizzle1(b.v, 1, 0, 3, 2))));
- #endif
- }
-};
-
-EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet2cf,Packet4f)
-
-template<> EIGEN_STRONG_INLINE Packet2cf pdiv<Packet2cf>(const Packet2cf& a, const Packet2cf& b)
-{
- // TODO optimize it for SSE3 and 4
- Packet2cf res = conj_helper<Packet2cf,Packet2cf,false,true>().pmul(a,b);
- __m128 s = _mm_mul_ps(b.v,b.v);
- return Packet2cf(_mm_div_ps(res.v,_mm_add_ps(s,_mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(s), 0xb1)))));
-}
-
+// Swaps adjacent float pairs (order 1,0,3,2), presumably exchanging the real and
+// imaginary parts of each complex lane (vec4f_swizzle1 is defined elsewhere).
EIGEN_STRONG_INLINE Packet2cf pcplxflip/* <Packet2cf> */(const Packet2cf& x)
{
return Packet2cf(vec4f_swizzle1(x.v, 1, 0, 3, 2));
}
+EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet2cf,Packet4f)
+
+// Complex division a / b = a * conj(b) / |b|^2, applied to both complex lanes.
+template<> EIGEN_STRONG_INLINE Packet2cf pdiv<Packet2cf>(const Packet2cf& a, const Packet2cf& b)
+{
+ // TODO optimize it for SSE3 and 4
+ Packet2cf res = pmul(a, pconj(b));
+ // s holds the squared components of b; adding its pairwise swap puts |b|^2 in
+ // both float slots of each complex lane.
+ __m128 s = _mm_mul_ps(b.v,b.v);
+ return Packet2cf(_mm_div_ps(res.v,_mm_add_ps(s,vec4f_swizzle1(s, 1, 0, 3, 2))));
+}
+
+
//---------- double ----------
+// Packet1cd: a single std::complex<double> held in one Packet2d.
struct Packet1cd
{
EIGEN_STRONG_INLINE Packet1cd() {}
EIGEN_STRONG_INLINE explicit Packet1cd(const __m128d& a) : v(a) {}
- __m128d v;
+ Packet2d v;
};
// Use the packet_traits defined in AVX/PacketMath.h instead if we're going
@@ -271,6 +208,7 @@
HasMul = 1,
HasDiv = 1,
HasNegate = 1,
+ HasSqrt = 1,
HasAbs = 0,
HasAbs2 = 0,
HasMin = 0,
@@ -280,7 +218,18 @@
};
#endif
-template<> struct unpacket_traits<Packet1cd> { typedef std::complex<double> type; enum {size=1, alignment=Aligned16}; typedef Packet1cd half; };
+// Packet1cd holds one std::complex<double>; as_real names the Packet2d type that
+// stores its real and imaginary parts.
+template<> struct unpacket_traits<Packet1cd> {
+ typedef std::complex<double> type;
+ typedef Packet1cd half;
+ typedef Packet2d as_real;
+ enum {
+ size=1,
+ alignment=Aligned16,
+ vectorizable=true,
+ masked_load_available=false,
+ masked_store_available=false
+ };
+};
+// Complex add/subtract are component-wise on the underlying doubles.
template<> EIGEN_STRONG_INLINE Packet1cd padd<Packet1cd>(const Packet1cd& a, const Packet1cd& b) { return Packet1cd(_mm_add_pd(a.v,b.v)); }
template<> EIGEN_STRONG_INLINE Packet1cd psub<Packet1cd>(const Packet1cd& a, const Packet1cd& b) { return Packet1cd(_mm_sub_pd(a.v,b.v)); }
@@ -305,10 +254,11 @@
#endif
}
+// ptrue delegates to the Packet2d overload on the underlying register.
+template<> EIGEN_STRONG_INLINE Packet1cd ptrue <Packet1cd>(const Packet1cd& a) { return Packet1cd(ptrue(Packet2d(a.v))); }
template<> EIGEN_STRONG_INLINE Packet1cd pand <Packet1cd>(const Packet1cd& a, const Packet1cd& b) { return Packet1cd(_mm_and_pd(a.v,b.v)); }
template<> EIGEN_STRONG_INLINE Packet1cd por <Packet1cd>(const Packet1cd& a, const Packet1cd& b) { return Packet1cd(_mm_or_pd(a.v,b.v)); }
template<> EIGEN_STRONG_INLINE Packet1cd pxor <Packet1cd>(const Packet1cd& a, const Packet1cd& b) { return Packet1cd(_mm_xor_pd(a.v,b.v)); }
-template<> EIGEN_STRONG_INLINE Packet1cd pandnot<Packet1cd>(const Packet1cd& a, const Packet1cd& b) { return Packet1cd(_mm_andnot_pd(a.v,b.v)); }
+// pandnot(a, b) is a & ~b, while _mm_andnot_pd(x, y) computes ~x & y; hence b is passed first.
+template<> EIGEN_STRONG_INLINE Packet1cd pandnot<Packet1cd>(const Packet1cd& a, const Packet1cd& b) { return Packet1cd(_mm_andnot_pd(b.v,a.v)); }
// FIXME force unaligned load, this is a temporary fix
template<> EIGEN_STRONG_INLINE Packet1cd pload <Packet1cd>(const std::complex<double>* from)
@@ -340,86 +290,17 @@
return pfirst(a);
}
-template<> EIGEN_STRONG_INLINE Packet1cd preduxp<Packet1cd>(const Packet1cd* vecs)
-{
- return vecs[0];
-}
-
+// With a single complex lane, the product reduction is that lane itself.
template<> EIGEN_STRONG_INLINE std::complex<double> predux_mul<Packet1cd>(const Packet1cd& a)
{
return pfirst(a);
}
-template<int Offset>
-struct palign_impl<Offset,Packet1cd>
-{
- static EIGEN_STRONG_INLINE void run(Packet1cd& /*first*/, const Packet1cd& /*second*/)
- {
- // FIXME is it sure we never have to align a Packet1cd?
- // Even though a std::complex<double> has 16 bytes, it is not necessarily aligned on a 16 bytes boundary...
- }
-};
-
-template<> struct conj_helper<Packet1cd, Packet1cd, false,true>
-{
- EIGEN_STRONG_INLINE Packet1cd pmadd(const Packet1cd& x, const Packet1cd& y, const Packet1cd& c) const
- { return padd(pmul(x,y),c); }
-
- EIGEN_STRONG_INLINE Packet1cd pmul(const Packet1cd& a, const Packet1cd& b) const
- {
- #ifdef EIGEN_VECTORIZE_SSE3
- return internal::pmul(a, pconj(b));
- #else
- const __m128d mask = _mm_castsi128_pd(_mm_set_epi32(0x80000000,0x0,0x0,0x0));
- return Packet1cd(_mm_add_pd(_mm_xor_pd(_mm_mul_pd(vec2d_swizzle1(a.v, 0, 0), b.v), mask),
- _mm_mul_pd(vec2d_swizzle1(a.v, 1, 1),
- vec2d_swizzle1(b.v, 1, 0))));
- #endif
- }
-};
-
-template<> struct conj_helper<Packet1cd, Packet1cd, true,false>
-{
- EIGEN_STRONG_INLINE Packet1cd pmadd(const Packet1cd& x, const Packet1cd& y, const Packet1cd& c) const
- { return padd(pmul(x,y),c); }
-
- EIGEN_STRONG_INLINE Packet1cd pmul(const Packet1cd& a, const Packet1cd& b) const
- {
- #ifdef EIGEN_VECTORIZE_SSE3
- return internal::pmul(pconj(a), b);
- #else
- const __m128d mask = _mm_castsi128_pd(_mm_set_epi32(0x80000000,0x0,0x0,0x0));
- return Packet1cd(_mm_add_pd(_mm_mul_pd(vec2d_swizzle1(a.v, 0, 0), b.v),
- _mm_xor_pd(_mm_mul_pd(vec2d_swizzle1(a.v, 1, 1),
- vec2d_swizzle1(b.v, 1, 0)), mask)));
- #endif
- }
-};
-
-template<> struct conj_helper<Packet1cd, Packet1cd, true,true>
-{
- EIGEN_STRONG_INLINE Packet1cd pmadd(const Packet1cd& x, const Packet1cd& y, const Packet1cd& c) const
- { return padd(pmul(x,y),c); }
-
- EIGEN_STRONG_INLINE Packet1cd pmul(const Packet1cd& a, const Packet1cd& b) const
- {
- #ifdef EIGEN_VECTORIZE_SSE3
- return pconj(internal::pmul(a, b));
- #else
- const __m128d mask = _mm_castsi128_pd(_mm_set_epi32(0x80000000,0x0,0x0,0x0));
- return Packet1cd(_mm_sub_pd(_mm_xor_pd(_mm_mul_pd(vec2d_swizzle1(a.v, 0, 0), b.v), mask),
- _mm_mul_pd(vec2d_swizzle1(a.v, 1, 1),
- vec2d_swizzle1(b.v, 1, 0))));
- #endif
- }
-};
-
EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet1cd,Packet2d)
+// Complex division a / b = a * conj(b) / |b|^2.
template<> EIGEN_STRONG_INLINE Packet1cd pdiv<Packet1cd>(const Packet1cd& a, const Packet1cd& b)
{
// TODO optimize it for SSE3 and 4
- Packet1cd res = conj_helper<Packet1cd,Packet1cd,false,true>().pmul(a,b);
+ Packet1cd res = pmul(a,pconj(b));
__m128d s = _mm_mul_pd(b.v,b.v);
+ // Adding s to its swapped copy (_mm_shuffle_pd(s, s, 0x1)) puts |b|^2 in both lanes.
return Packet1cd(_mm_div_pd(res.v, _mm_add_pd(s,_mm_shuffle_pd(s, s, 0x1))));
}
@@ -439,33 +320,32 @@
kernel.packet[1].v = tmp;
}
+// A complex lane compares equal only when both of its float components do: the
+// per-component mask is ANDed with its pairwise-swapped copy.
+template<> EIGEN_STRONG_INLINE Packet2cf pcmp_eq(const Packet2cf& a, const Packet2cf& b)
+{
+ __m128 eq = _mm_cmpeq_ps(a.v, b.v);
+ return Packet2cf(pand<Packet4f>(eq, vec4f_swizzle1(eq, 1, 0, 3, 2)));
+}
+
+// Same idea for the single double-precision complex lane.
+template<> EIGEN_STRONG_INLINE Packet1cd pcmp_eq(const Packet1cd& a, const Packet1cd& b)
+{
+ __m128d eq = _mm_cmpeq_pd(a.v, b.v);
+ return Packet1cd(pand<Packet2d>(eq, vec2d_swizzle1(eq, 1, 0)));
+}
+
template<> EIGEN_STRONG_INLINE Packet2cf pblend(const Selector<2>& ifPacket, const Packet2cf& thenPacket, const Packet2cf& elsePacket) {
__m128d result = pblend<Packet2d>(ifPacket, _mm_castps_pd(thenPacket.v), _mm_castps_pd(elsePacket.v));
return Packet2cf(_mm_castpd_ps(result));
}
-template<> EIGEN_STRONG_INLINE Packet2cf pinsertfirst(const Packet2cf& a, std::complex<float> b)
-{
- return Packet2cf(_mm_loadl_pi(a.v, reinterpret_cast<const __m64*>(&b)));
+template<> EIGEN_STRONG_INLINE Packet1cd psqrt<Packet1cd>(const Packet1cd& a) {
+ return psqrt_complex<Packet1cd>(a);
}
-template<> EIGEN_STRONG_INLINE Packet1cd pinsertfirst(const Packet1cd&, std::complex<double> b)
-{
- return pset1<Packet1cd>(b);
-}
-
-template<> EIGEN_STRONG_INLINE Packet2cf pinsertlast(const Packet2cf& a, std::complex<float> b)
-{
- return Packet2cf(_mm_loadh_pi(a.v, reinterpret_cast<const __m64*>(&b)));
-}
-
-template<> EIGEN_STRONG_INLINE Packet1cd pinsertlast(const Packet1cd&, std::complex<double> b)
-{
- return pset1<Packet1cd>(b);
+template<> EIGEN_STRONG_INLINE Packet2cf psqrt<Packet2cf>(const Packet2cf& a) {
+ return psqrt_complex<Packet2cf>(a);
}
} // end namespace internal
-
} // end namespace Eigen
#endif // EIGEN_COMPLEX_SSE_H
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/SSE/MathFunctions.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/SSE/MathFunctions.h
index 7b5f948..8736d0d 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/SSE/MathFunctions.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/SSE/MathFunctions.h
@@ -8,7 +8,7 @@
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
-/* The sin, cos, exp, and log functions of this file come from
+/* The sin and cos functions of this file come from
* Julien Pommier's sse math library: http://gruntthepeon.free.fr/ssemath/
*/
@@ -20,426 +20,57 @@
namespace internal {
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
-Packet4f plog<Packet4f>(const Packet4f& _x)
-{
- Packet4f x = _x;
- _EIGEN_DECLARE_CONST_Packet4f(1 , 1.0f);
- _EIGEN_DECLARE_CONST_Packet4f(half, 0.5f);
- _EIGEN_DECLARE_CONST_Packet4i(0x7f, 0x7f);
+Packet4f plog<Packet4f>(const Packet4f& _x) {
+ return plog_float(_x);
+}
- _EIGEN_DECLARE_CONST_Packet4f_FROM_INT(inv_mant_mask, ~0x7f800000);
+template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
+Packet2d plog<Packet2d>(const Packet2d& _x) {
+ return plog_double(_x);
+}
- /* the smallest non denormalized float number */
- _EIGEN_DECLARE_CONST_Packet4f_FROM_INT(min_norm_pos, 0x00800000);
- _EIGEN_DECLARE_CONST_Packet4f_FROM_INT(minus_inf, 0xff800000);//-1.f/0.f);
+template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
+Packet4f plog2<Packet4f>(const Packet4f& _x) {
+ return plog2_float(_x);
+}
- /* natural logarithm computed for 4 simultaneous float
- return NaN for x <= 0
- */
- _EIGEN_DECLARE_CONST_Packet4f(cephes_SQRTHF, 0.707106781186547524f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p0, 7.0376836292E-2f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p1, - 1.1514610310E-1f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p2, 1.1676998740E-1f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p3, - 1.2420140846E-1f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p4, + 1.4249322787E-1f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p5, - 1.6668057665E-1f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p6, + 2.0000714765E-1f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p7, - 2.4999993993E-1f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p8, + 3.3333331174E-1f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_log_q1, -2.12194440e-4f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_log_q2, 0.693359375f);
+template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
+Packet2d plog2<Packet2d>(const Packet2d& _x) {
+ return plog2_double(_x);
+}
+template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
+Packet4f plog1p<Packet4f>(const Packet4f& _x) {
+ return generic_plog1p(_x);
+}
- Packet4i emm0;
-
- Packet4f invalid_mask = _mm_cmpnge_ps(x, _mm_setzero_ps()); // not greater equal is true if x is NaN
- Packet4f iszero_mask = _mm_cmpeq_ps(x, _mm_setzero_ps());
-
- x = pmax(x, p4f_min_norm_pos); /* cut off denormalized stuff */
- emm0 = _mm_srli_epi32(_mm_castps_si128(x), 23);
-
- /* keep only the fractional part */
- x = _mm_and_ps(x, p4f_inv_mant_mask);
- x = _mm_or_ps(x, p4f_half);
-
- emm0 = _mm_sub_epi32(emm0, p4i_0x7f);
- Packet4f e = padd(Packet4f(_mm_cvtepi32_ps(emm0)), p4f_1);
-
- /* part2:
- if( x < SQRTHF ) {
- e -= 1;
- x = x + x - 1.0;
- } else { x = x - 1.0; }
- */
- Packet4f mask = _mm_cmplt_ps(x, p4f_cephes_SQRTHF);
- Packet4f tmp = pand(x, mask);
- x = psub(x, p4f_1);
- e = psub(e, pand(p4f_1, mask));
- x = padd(x, tmp);
-
- Packet4f x2 = pmul(x,x);
- Packet4f x3 = pmul(x2,x);
-
- Packet4f y, y1, y2;
- y = pmadd(p4f_cephes_log_p0, x, p4f_cephes_log_p1);
- y1 = pmadd(p4f_cephes_log_p3, x, p4f_cephes_log_p4);
- y2 = pmadd(p4f_cephes_log_p6, x, p4f_cephes_log_p7);
- y = pmadd(y , x, p4f_cephes_log_p2);
- y1 = pmadd(y1, x, p4f_cephes_log_p5);
- y2 = pmadd(y2, x, p4f_cephes_log_p8);
- y = pmadd(y, x3, y1);
- y = pmadd(y, x3, y2);
- y = pmul(y, x3);
-
- y1 = pmul(e, p4f_cephes_log_q1);
- tmp = pmul(x2, p4f_half);
- y = padd(y, y1);
- x = psub(x, tmp);
- y2 = pmul(e, p4f_cephes_log_q2);
- x = padd(x, y);
- x = padd(x, y2);
- // negative arg will be NAN, 0 will be -INF
- return _mm_or_ps(_mm_andnot_ps(iszero_mask, _mm_or_ps(x, invalid_mask)),
- _mm_and_ps(iszero_mask, p4f_minus_inf));
+template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
+Packet4f pexpm1<Packet4f>(const Packet4f& _x) {
+ return generic_expm1(_x);
}
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet4f pexp<Packet4f>(const Packet4f& _x)
{
- Packet4f x = _x;
- _EIGEN_DECLARE_CONST_Packet4f(1 , 1.0f);
- _EIGEN_DECLARE_CONST_Packet4f(half, 0.5f);
- _EIGEN_DECLARE_CONST_Packet4i(0x7f, 0x7f);
-
-
- _EIGEN_DECLARE_CONST_Packet4f(exp_hi, 88.3762626647950f);
- _EIGEN_DECLARE_CONST_Packet4f(exp_lo, -88.3762626647949f);
-
- _EIGEN_DECLARE_CONST_Packet4f(cephes_LOG2EF, 1.44269504088896341f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_C1, 0.693359375f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_C2, -2.12194440e-4f);
-
- _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p0, 1.9875691500E-4f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p1, 1.3981999507E-3f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p2, 8.3334519073E-3f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p3, 4.1665795894E-2f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p4, 1.6666665459E-1f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p5, 5.0000001201E-1f);
-
- Packet4f tmp, fx;
- Packet4i emm0;
-
- // clamp x
- x = pmax(pmin(x, p4f_exp_hi), p4f_exp_lo);
-
- /* express exp(x) as exp(g + n*log(2)) */
- fx = pmadd(x, p4f_cephes_LOG2EF, p4f_half);
-
-#ifdef EIGEN_VECTORIZE_SSE4_1
- fx = _mm_floor_ps(fx);
-#else
- emm0 = _mm_cvttps_epi32(fx);
- tmp = _mm_cvtepi32_ps(emm0);
- /* if greater, substract 1 */
- Packet4f mask = _mm_cmpgt_ps(tmp, fx);
- mask = _mm_and_ps(mask, p4f_1);
- fx = psub(tmp, mask);
-#endif
-
- tmp = pmul(fx, p4f_cephes_exp_C1);
- Packet4f z = pmul(fx, p4f_cephes_exp_C2);
- x = psub(x, tmp);
- x = psub(x, z);
-
- z = pmul(x,x);
-
- Packet4f y = p4f_cephes_exp_p0;
- y = pmadd(y, x, p4f_cephes_exp_p1);
- y = pmadd(y, x, p4f_cephes_exp_p2);
- y = pmadd(y, x, p4f_cephes_exp_p3);
- y = pmadd(y, x, p4f_cephes_exp_p4);
- y = pmadd(y, x, p4f_cephes_exp_p5);
- y = pmadd(y, z, x);
- y = padd(y, p4f_1);
-
- // build 2^n
- emm0 = _mm_cvttps_epi32(fx);
- emm0 = _mm_add_epi32(emm0, p4i_0x7f);
- emm0 = _mm_slli_epi32(emm0, 23);
- return pmax(pmul(y, Packet4f(_mm_castsi128_ps(emm0))), _x);
+ return pexp_float(_x);
}
+
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
-Packet2d pexp<Packet2d>(const Packet2d& _x)
+Packet2d pexp<Packet2d>(const Packet2d& x)
{
- Packet2d x = _x;
-
- _EIGEN_DECLARE_CONST_Packet2d(1 , 1.0);
- _EIGEN_DECLARE_CONST_Packet2d(2 , 2.0);
- _EIGEN_DECLARE_CONST_Packet2d(half, 0.5);
-
- _EIGEN_DECLARE_CONST_Packet2d(exp_hi, 709.437);
- _EIGEN_DECLARE_CONST_Packet2d(exp_lo, -709.436139303);
-
- _EIGEN_DECLARE_CONST_Packet2d(cephes_LOG2EF, 1.4426950408889634073599);
-
- _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_p0, 1.26177193074810590878e-4);
- _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_p1, 3.02994407707441961300e-2);
- _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_p2, 9.99999999999999999910e-1);
-
- _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_q0, 3.00198505138664455042e-6);
- _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_q1, 2.52448340349684104192e-3);
- _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_q2, 2.27265548208155028766e-1);
- _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_q3, 2.00000000000000000009e0);
-
- _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_C1, 0.693145751953125);
- _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_C2, 1.42860682030941723212e-6);
- static const __m128i p4i_1023_0 = _mm_setr_epi32(1023, 1023, 0, 0);
-
- Packet2d tmp, fx;
- Packet4i emm0;
-
- // clamp x
- x = pmax(pmin(x, p2d_exp_hi), p2d_exp_lo);
- /* express exp(x) as exp(g + n*log(2)) */
- fx = pmadd(p2d_cephes_LOG2EF, x, p2d_half);
-
-#ifdef EIGEN_VECTORIZE_SSE4_1
- fx = _mm_floor_pd(fx);
-#else
- emm0 = _mm_cvttpd_epi32(fx);
- tmp = _mm_cvtepi32_pd(emm0);
- /* if greater, substract 1 */
- Packet2d mask = _mm_cmpgt_pd(tmp, fx);
- mask = _mm_and_pd(mask, p2d_1);
- fx = psub(tmp, mask);
-#endif
-
- tmp = pmul(fx, p2d_cephes_exp_C1);
- Packet2d z = pmul(fx, p2d_cephes_exp_C2);
- x = psub(x, tmp);
- x = psub(x, z);
-
- Packet2d x2 = pmul(x,x);
-
- Packet2d px = p2d_cephes_exp_p0;
- px = pmadd(px, x2, p2d_cephes_exp_p1);
- px = pmadd(px, x2, p2d_cephes_exp_p2);
- px = pmul (px, x);
-
- Packet2d qx = p2d_cephes_exp_q0;
- qx = pmadd(qx, x2, p2d_cephes_exp_q1);
- qx = pmadd(qx, x2, p2d_cephes_exp_q2);
- qx = pmadd(qx, x2, p2d_cephes_exp_q3);
-
- x = pdiv(px,psub(qx,px));
- x = pmadd(p2d_2,x,p2d_1);
-
- // build 2^n
- emm0 = _mm_cvttpd_epi32(fx);
- emm0 = _mm_add_epi32(emm0, p4i_1023_0);
- emm0 = _mm_slli_epi32(emm0, 20);
- emm0 = _mm_shuffle_epi32(emm0, _MM_SHUFFLE(1,2,0,3));
- return pmax(pmul(x, Packet2d(_mm_castsi128_pd(emm0))), _x);
+ return pexp_double(x);
}
-/* evaluation of 4 sines at onces, using SSE2 intrinsics.
-
- The code is the exact rewriting of the cephes sinf function.
- Precision is excellent as long as x < 8192 (I did not bother to
- take into account the special handling they have for greater values
- -- it does not return garbage for arguments over 8192, though, but
- the extra precision is missing).
-
- Note that it is such that sinf((float)M_PI) = 8.74e-8, which is the
- surprising but correct result.
-*/
-
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet4f psin<Packet4f>(const Packet4f& _x)
{
- Packet4f x = _x;
- _EIGEN_DECLARE_CONST_Packet4f(1 , 1.0f);
- _EIGEN_DECLARE_CONST_Packet4f(half, 0.5f);
-
- _EIGEN_DECLARE_CONST_Packet4i(1, 1);
- _EIGEN_DECLARE_CONST_Packet4i(not1, ~1);
- _EIGEN_DECLARE_CONST_Packet4i(2, 2);
- _EIGEN_DECLARE_CONST_Packet4i(4, 4);
-
- _EIGEN_DECLARE_CONST_Packet4f_FROM_INT(sign_mask, 0x80000000);
-
- _EIGEN_DECLARE_CONST_Packet4f(minus_cephes_DP1,-0.78515625f);
- _EIGEN_DECLARE_CONST_Packet4f(minus_cephes_DP2, -2.4187564849853515625e-4f);
- _EIGEN_DECLARE_CONST_Packet4f(minus_cephes_DP3, -3.77489497744594108e-8f);
- _EIGEN_DECLARE_CONST_Packet4f(sincof_p0, -1.9515295891E-4f);
- _EIGEN_DECLARE_CONST_Packet4f(sincof_p1, 8.3321608736E-3f);
- _EIGEN_DECLARE_CONST_Packet4f(sincof_p2, -1.6666654611E-1f);
- _EIGEN_DECLARE_CONST_Packet4f(coscof_p0, 2.443315711809948E-005f);
- _EIGEN_DECLARE_CONST_Packet4f(coscof_p1, -1.388731625493765E-003f);
- _EIGEN_DECLARE_CONST_Packet4f(coscof_p2, 4.166664568298827E-002f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_FOPI, 1.27323954473516f); // 4 / M_PI
-
- Packet4f xmm1, xmm2, xmm3, sign_bit, y;
-
- Packet4i emm0, emm2;
- sign_bit = x;
- /* take the absolute value */
- x = pabs(x);
-
- /* take the modulo */
-
- /* extract the sign bit (upper one) */
- sign_bit = _mm_and_ps(sign_bit, p4f_sign_mask);
-
- /* scale by 4/Pi */
- y = pmul(x, p4f_cephes_FOPI);
-
- /* store the integer part of y in mm0 */
- emm2 = _mm_cvttps_epi32(y);
- /* j=(j+1) & (~1) (see the cephes sources) */
- emm2 = _mm_add_epi32(emm2, p4i_1);
- emm2 = _mm_and_si128(emm2, p4i_not1);
- y = _mm_cvtepi32_ps(emm2);
- /* get the swap sign flag */
- emm0 = _mm_and_si128(emm2, p4i_4);
- emm0 = _mm_slli_epi32(emm0, 29);
- /* get the polynom selection mask
- there is one polynom for 0 <= x <= Pi/4
- and another one for Pi/4<x<=Pi/2
-
- Both branches will be computed.
- */
- emm2 = _mm_and_si128(emm2, p4i_2);
- emm2 = _mm_cmpeq_epi32(emm2, _mm_setzero_si128());
-
- Packet4f swap_sign_bit = _mm_castsi128_ps(emm0);
- Packet4f poly_mask = _mm_castsi128_ps(emm2);
- sign_bit = _mm_xor_ps(sign_bit, swap_sign_bit);
-
- /* The magic pass: "Extended precision modular arithmetic"
- x = ((x - y * DP1) - y * DP2) - y * DP3; */
- xmm1 = pmul(y, p4f_minus_cephes_DP1);
- xmm2 = pmul(y, p4f_minus_cephes_DP2);
- xmm3 = pmul(y, p4f_minus_cephes_DP3);
- x = padd(x, xmm1);
- x = padd(x, xmm2);
- x = padd(x, xmm3);
-
- /* Evaluate the first polynom (0 <= x <= Pi/4) */
- y = p4f_coscof_p0;
- Packet4f z = _mm_mul_ps(x,x);
-
- y = pmadd(y, z, p4f_coscof_p1);
- y = pmadd(y, z, p4f_coscof_p2);
- y = pmul(y, z);
- y = pmul(y, z);
- Packet4f tmp = pmul(z, p4f_half);
- y = psub(y, tmp);
- y = padd(y, p4f_1);
-
- /* Evaluate the second polynom (Pi/4 <= x <= 0) */
-
- Packet4f y2 = p4f_sincof_p0;
- y2 = pmadd(y2, z, p4f_sincof_p1);
- y2 = pmadd(y2, z, p4f_sincof_p2);
- y2 = pmul(y2, z);
- y2 = pmul(y2, x);
- y2 = padd(y2, x);
-
- /* select the correct result from the two polynoms */
- y2 = _mm_and_ps(poly_mask, y2);
- y = _mm_andnot_ps(poly_mask, y);
- y = _mm_or_ps(y,y2);
- /* update the sign */
- return _mm_xor_ps(y, sign_bit);
+ return psin_float(_x);
}
-/* almost the same as psin */
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet4f pcos<Packet4f>(const Packet4f& _x)
{
- Packet4f x = _x;
- _EIGEN_DECLARE_CONST_Packet4f(1 , 1.0f);
- _EIGEN_DECLARE_CONST_Packet4f(half, 0.5f);
-
- _EIGEN_DECLARE_CONST_Packet4i(1, 1);
- _EIGEN_DECLARE_CONST_Packet4i(not1, ~1);
- _EIGEN_DECLARE_CONST_Packet4i(2, 2);
- _EIGEN_DECLARE_CONST_Packet4i(4, 4);
-
- _EIGEN_DECLARE_CONST_Packet4f(minus_cephes_DP1,-0.78515625f);
- _EIGEN_DECLARE_CONST_Packet4f(minus_cephes_DP2, -2.4187564849853515625e-4f);
- _EIGEN_DECLARE_CONST_Packet4f(minus_cephes_DP3, -3.77489497744594108e-8f);
- _EIGEN_DECLARE_CONST_Packet4f(sincof_p0, -1.9515295891E-4f);
- _EIGEN_DECLARE_CONST_Packet4f(sincof_p1, 8.3321608736E-3f);
- _EIGEN_DECLARE_CONST_Packet4f(sincof_p2, -1.6666654611E-1f);
- _EIGEN_DECLARE_CONST_Packet4f(coscof_p0, 2.443315711809948E-005f);
- _EIGEN_DECLARE_CONST_Packet4f(coscof_p1, -1.388731625493765E-003f);
- _EIGEN_DECLARE_CONST_Packet4f(coscof_p2, 4.166664568298827E-002f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_FOPI, 1.27323954473516f); // 4 / M_PI
-
- Packet4f xmm1, xmm2, xmm3, y;
- Packet4i emm0, emm2;
-
- x = pabs(x);
-
- /* scale by 4/Pi */
- y = pmul(x, p4f_cephes_FOPI);
-
- /* get the integer part of y */
- emm2 = _mm_cvttps_epi32(y);
- /* j=(j+1) & (~1) (see the cephes sources) */
- emm2 = _mm_add_epi32(emm2, p4i_1);
- emm2 = _mm_and_si128(emm2, p4i_not1);
- y = _mm_cvtepi32_ps(emm2);
-
- emm2 = _mm_sub_epi32(emm2, p4i_2);
-
- /* get the swap sign flag */
- emm0 = _mm_andnot_si128(emm2, p4i_4);
- emm0 = _mm_slli_epi32(emm0, 29);
- /* get the polynom selection mask */
- emm2 = _mm_and_si128(emm2, p4i_2);
- emm2 = _mm_cmpeq_epi32(emm2, _mm_setzero_si128());
-
- Packet4f sign_bit = _mm_castsi128_ps(emm0);
- Packet4f poly_mask = _mm_castsi128_ps(emm2);
-
- /* The magic pass: "Extended precision modular arithmetic"
- x = ((x - y * DP1) - y * DP2) - y * DP3; */
- xmm1 = pmul(y, p4f_minus_cephes_DP1);
- xmm2 = pmul(y, p4f_minus_cephes_DP2);
- xmm3 = pmul(y, p4f_minus_cephes_DP3);
- x = padd(x, xmm1);
- x = padd(x, xmm2);
- x = padd(x, xmm3);
-
- /* Evaluate the first polynom (0 <= x <= Pi/4) */
- y = p4f_coscof_p0;
- Packet4f z = pmul(x,x);
-
- y = pmadd(y,z,p4f_coscof_p1);
- y = pmadd(y,z,p4f_coscof_p2);
- y = pmul(y, z);
- y = pmul(y, z);
- Packet4f tmp = _mm_mul_ps(z, p4f_half);
- y = psub(y, tmp);
- y = padd(y, p4f_1);
-
- /* Evaluate the second polynom (Pi/4 <= x <= 0) */
- Packet4f y2 = p4f_sincof_p0;
- y2 = pmadd(y2, z, p4f_sincof_p1);
- y2 = pmadd(y2, z, p4f_sincof_p2);
- y2 = pmul(y2, z);
- y2 = pmadd(y2, x, x);
-
- /* select the correct result from the two polynoms */
- y2 = _mm_and_ps(poly_mask, y2);
- y = _mm_andnot_ps(poly_mask, y);
- y = _mm_or_ps(y,y2);
-
- /* update the sign */
- return _mm_xor_ps(y, sign_bit);
+ return pcos_float(_x);
}
#if EIGEN_FAST_MATH
@@ -455,17 +86,17 @@
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet4f psqrt<Packet4f>(const Packet4f& _x)
{
- Packet4f half = pmul(_x, pset1<Packet4f>(.5f));
- Packet4f denormal_mask = _mm_and_ps(
- _mm_cmpge_ps(_x, _mm_setzero_ps()),
- _mm_cmplt_ps(_x, pset1<Packet4f>((std::numeric_limits<float>::min)())));
+ Packet4f minus_half_x = pmul(_x, pset1<Packet4f>(-0.5f));
+ Packet4f denormal_mask = pandnot(
+ pcmp_lt(_x, pset1<Packet4f>((std::numeric_limits<float>::min)())),
+ pcmp_lt(_x, pzero(_x)));
// Compute approximate reciprocal sqrt.
Packet4f x = _mm_rsqrt_ps(_x);
// Do a single step of Newton's iteration.
- x = pmul(x, psub(pset1<Packet4f>(1.5f), pmul(half, pmul(x,x))));
+ x = pmul(x, pmadd(minus_half_x, pmul(x,x), pset1<Packet4f>(1.5f)));
// Flush results for denormals to zero.
- return _mm_andnot_ps(denormal_mask, pmul(_x,x));
+ return pandnot(pmul(_x,x), denormal_mask);
}
#else
@@ -478,41 +109,48 @@
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet2d psqrt<Packet2d>(const Packet2d& x) { return _mm_sqrt_pd(x); }
+template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
+Packet16b psqrt<Packet16b>(const Packet16b& x) { return x; }
+
#if EIGEN_FAST_MATH
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet4f prsqrt<Packet4f>(const Packet4f& _x) {
- _EIGEN_DECLARE_CONST_Packet4f_FROM_INT(inf, 0x7f800000);
- _EIGEN_DECLARE_CONST_Packet4f_FROM_INT(nan, 0x7fc00000);
_EIGEN_DECLARE_CONST_Packet4f(one_point_five, 1.5f);
_EIGEN_DECLARE_CONST_Packet4f(minus_half, -0.5f);
- _EIGEN_DECLARE_CONST_Packet4f_FROM_INT(flt_min, 0x00800000);
+ _EIGEN_DECLARE_CONST_Packet4f_FROM_INT(inf, 0x7f800000u);
+ _EIGEN_DECLARE_CONST_Packet4f_FROM_INT(flt_min, 0x00800000u);
Packet4f neg_half = pmul(_x, p4f_minus_half);
- // select only the inverse sqrt of positive normal inputs (denormals are
- // flushed to zero and cause infs as well).
- Packet4f le_zero_mask = _mm_cmple_ps(_x, p4f_flt_min);
- Packet4f x = _mm_andnot_ps(le_zero_mask, _mm_rsqrt_ps(_x));
+ // Identify infinite, zero, negative and denormal arguments.
+ Packet4f lt_min_mask = _mm_cmplt_ps(_x, p4f_flt_min);
+ Packet4f inf_mask = _mm_cmpeq_ps(_x, p4f_inf);
+ Packet4f not_normal_finite_mask = _mm_or_ps(lt_min_mask, inf_mask);
- // Fill in NaNs and Infs for the negative/zero entries.
- Packet4f neg_mask = _mm_cmplt_ps(_x, _mm_setzero_ps());
- Packet4f zero_mask = _mm_andnot_ps(neg_mask, le_zero_mask);
- Packet4f infs_and_nans = _mm_or_ps(_mm_and_ps(neg_mask, p4f_nan),
- _mm_and_ps(zero_mask, p4f_inf));
+ // Compute an approximate result using the rsqrt intrinsic.
+ Packet4f y_approx = _mm_rsqrt_ps(_x);
- // Do a single step of Newton's iteration.
- x = pmul(x, pmadd(neg_half, pmul(x, x), p4f_one_point_five));
+ // Do a single step of Newton-Raphson iteration to improve the approximation.
+ // This uses the formula y_{n+1} = y_n * (1.5 - y_n * (0.5 * x) * y_n).
+ // It is essential to evaluate the inner term like this because forming
+ // y_n^2 may over- or underflow.
+ Packet4f y_newton = pmul(
+ y_approx, pmadd(y_approx, pmul(neg_half, y_approx), p4f_one_point_five));
- // Insert NaNs and Infs in all the right places.
- return _mm_or_ps(x, infs_and_nans);
+ // Select the result of the Newton-Raphson step for positive normal arguments.
+ // For other arguments, choose the output of the intrinsic. This will
+ // return rsqrt(+inf) = 0, rsqrt(x) = NaN if x < 0, and rsqrt(x) = +inf if
+ // x is zero or a positive denormalized float (equivalent to flushing positive
+ // denormalized inputs to zero).
+ return pselect<Packet4f>(not_normal_finite_mask, y_approx, y_newton);
}
#else
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet4f prsqrt<Packet4f>(const Packet4f& x) {
- // Unfortunately we can't use the much faster mm_rqsrt_ps since it only provides an approximation.
+ // Unfortunately we can't use the much faster mm_rsqrt_ps since it only provides an approximation.
return _mm_div_ps(pset1<Packet4f>(1.0f), _mm_sqrt_ps(x));
}
@@ -520,7 +158,6 @@
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet2d prsqrt<Packet2d>(const Packet2d& x) {
- // Unfortunately we can't use the much faster mm_rqsrt_pd since it only provides an approximation.
return _mm_div_pd(pset1<Packet2d>(1.0), _mm_sqrt_pd(x));
}
@@ -548,7 +185,7 @@
{
#if EIGEN_COMP_GNUC_STRICT
// This works around a GCC bug generating poor code for _mm_sqrt_pd
- // See https://bitbucket.org/eigen/eigen/commits/14f468dba4d350d7c19c9b93072e19f7b3df563b
+ // See https://gitlab.com/libeigen/eigen/commit/8dca9f97e38970
return internal::pfirst(internal::Packet2d(__builtin_ia32_sqrtsd(_mm_set_sd(x))));
#else
return internal::pfirst(internal::Packet2d(_mm_sqrt_pd(_mm_set_sd(x))));
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/SSE/PacketMath.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/SSE/PacketMath.h
index 60e2517..db102c7 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/SSE/PacketMath.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/SSE/PacketMath.h
@@ -18,13 +18,15 @@
#define EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD 8
#endif
-#ifndef EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS
+#if !defined(EIGEN_VECTORIZE_AVX) && !defined(EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS)
+// 32 bits => 8 registers
+// 64 bits => 16 registers
#define EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS (2*sizeof(void*))
#endif
-#ifdef __FMA__
+#ifdef EIGEN_VECTORIZE_FMA
#ifndef EIGEN_HAS_SINGLE_INSTRUCTION_MADD
-#define EIGEN_HAS_SINGLE_INSTRUCTION_MADD 1
+#define EIGEN_HAS_SINGLE_INSTRUCTION_MADD
#endif
#endif
@@ -34,47 +36,75 @@
// One solution is to increase ABI version using -fabi-version=4 (or greater).
// Otherwise, we workaround this inconvenience by wrapping 128bit types into the following helper
// structure:
-template<typename T>
-struct eigen_packet_wrapper
-{
- EIGEN_ALWAYS_INLINE operator T&() { return m_val; }
- EIGEN_ALWAYS_INLINE operator const T&() const { return m_val; }
- EIGEN_ALWAYS_INLINE eigen_packet_wrapper() {}
- EIGEN_ALWAYS_INLINE eigen_packet_wrapper(const T &v) : m_val(v) {}
- EIGEN_ALWAYS_INLINE eigen_packet_wrapper& operator=(const T &v) {
- m_val = v;
- return *this;
- }
-
- T m_val;
-};
typedef eigen_packet_wrapper<__m128> Packet4f;
-typedef eigen_packet_wrapper<__m128i> Packet4i;
typedef eigen_packet_wrapper<__m128d> Packet2d;
#else
typedef __m128 Packet4f;
-typedef __m128i Packet4i;
typedef __m128d Packet2d;
#endif
+typedef eigen_packet_wrapper<__m128i, 0> Packet4i;
+typedef eigen_packet_wrapper<__m128i, 1> Packet16b;
+
template<> struct is_arithmetic<__m128> { enum { value = true }; };
template<> struct is_arithmetic<__m128i> { enum { value = true }; };
template<> struct is_arithmetic<__m128d> { enum { value = true }; };
+template<> struct is_arithmetic<Packet4i> { enum { value = true }; };
+template<> struct is_arithmetic<Packet16b> { enum { value = true }; };
+template<int p, int q, int r, int s>
+struct shuffle_mask{
+ enum { mask = (s)<<6|(r)<<4|(q)<<2|(p) };
+};
+
+// TODO: change the implementation of all swizzle* ops from macro to template.
#define vec4f_swizzle1(v,p,q,r,s) \
- (_mm_castsi128_ps(_mm_shuffle_epi32( _mm_castps_si128(v), ((s)<<6|(r)<<4|(q)<<2|(p)))))
+ Packet4f(_mm_castsi128_ps(_mm_shuffle_epi32( _mm_castps_si128(v), (shuffle_mask<p,q,r,s>::mask))))
#define vec4i_swizzle1(v,p,q,r,s) \
- (_mm_shuffle_epi32( v, ((s)<<6|(r)<<4|(q)<<2|(p))))
+ Packet4i(_mm_shuffle_epi32( v, (shuffle_mask<p,q,r,s>::mask)))
#define vec2d_swizzle1(v,p,q) \
- (_mm_castsi128_pd(_mm_shuffle_epi32( _mm_castpd_si128(v), ((q*2+1)<<6|(q*2)<<4|(p*2+1)<<2|(p*2)))))
-
+ Packet2d(_mm_castsi128_pd(_mm_shuffle_epi32( _mm_castpd_si128(v), (shuffle_mask<2*p,2*p+1,2*q,2*q+1>::mask))))
+
#define vec4f_swizzle2(a,b,p,q,r,s) \
- (_mm_shuffle_ps( (a), (b), ((s)<<6|(r)<<4|(q)<<2|(p))))
+ Packet4f(_mm_shuffle_ps( (a), (b), (shuffle_mask<p,q,r,s>::mask)))
#define vec4i_swizzle2(a,b,p,q,r,s) \
- (_mm_castps_si128( (_mm_shuffle_ps( _mm_castsi128_ps(a), _mm_castsi128_ps(b), ((s)<<6|(r)<<4|(q)<<2|(p))))))
+ Packet4i(_mm_castps_si128( (_mm_shuffle_ps( _mm_castsi128_ps(a), _mm_castsi128_ps(b), (shuffle_mask<p,q,r,s>::mask)))))
+
+EIGEN_STRONG_INLINE Packet4f vec4f_movelh(const Packet4f& a, const Packet4f& b)
+{
+ return Packet4f(_mm_movelh_ps(a,b));
+}
+EIGEN_STRONG_INLINE Packet4f vec4f_movehl(const Packet4f& a, const Packet4f& b)
+{
+ return Packet4f(_mm_movehl_ps(a,b));
+}
+EIGEN_STRONG_INLINE Packet4f vec4f_unpacklo(const Packet4f& a, const Packet4f& b)
+{
+ return Packet4f(_mm_unpacklo_ps(a,b));
+}
+EIGEN_STRONG_INLINE Packet4f vec4f_unpackhi(const Packet4f& a, const Packet4f& b)
+{
+ return Packet4f(_mm_unpackhi_ps(a,b));
+}
+#define vec4f_duplane(a,p) \
+ vec4f_swizzle2(a,a,p,p,p,p)
+
+#define vec2d_swizzle2(a,b,mask) \
+ Packet2d(_mm_shuffle_pd(a,b,mask))
+
+EIGEN_STRONG_INLINE Packet2d vec2d_unpacklo(const Packet2d& a, const Packet2d& b)
+{
+ return Packet2d(_mm_unpacklo_pd(a,b));
+}
+EIGEN_STRONG_INLINE Packet2d vec2d_unpackhi(const Packet2d& a, const Packet2d& b)
+{
+ return Packet2d(_mm_unpackhi_pd(a,b));
+}
+#define vec2d_duplane(a,p) \
+ vec2d_swizzle2(a,a,(p<<1)|p)
#define _EIGEN_DECLARE_CONST_Packet4f(NAME,X) \
const Packet4f p4f_##NAME = pset1<Packet4f>(X)
@@ -83,7 +113,7 @@
const Packet2d p2d_##NAME = pset1<Packet2d>(X)
#define _EIGEN_DECLARE_CONST_Packet4f_FROM_INT(NAME,X) \
- const Packet4f p4f_##NAME = _mm_castsi128_ps(pset1<Packet4i>(X))
+ const Packet4f p4f_##NAME = pset1frombits<Packet4f>(X)
#define _EIGEN_DECLARE_CONST_Packet4i(NAME,X) \
const Packet4i p4i_##NAME = pset1<Packet4i>(X)
@@ -92,36 +122,41 @@
// Use the packet_traits defined in AVX/PacketMath.h instead if we're going
// to leverage AVX instructions.
#ifndef EIGEN_VECTORIZE_AVX
-template<> struct packet_traits<float> : default_packet_traits
-{
+template <>
+struct packet_traits<float> : default_packet_traits {
typedef Packet4f type;
typedef Packet4f half;
enum {
Vectorizable = 1,
AlignedOnScalar = 1,
- size=4,
+ size = 4,
HasHalfPacket = 0,
- HasDiv = 1,
- HasSin = EIGEN_FAST_MATH,
- HasCos = EIGEN_FAST_MATH,
- HasLog = 1,
- HasExp = 1,
+ HasCmp = 1,
+ HasDiv = 1,
+ HasSin = EIGEN_FAST_MATH,
+ HasCos = EIGEN_FAST_MATH,
+ HasLog = 1,
+ HasLog1p = 1,
+ HasExpm1 = 1,
+ HasNdtri = 1,
+ HasExp = 1,
+ HasBessel = 1,
HasSqrt = 1,
HasRsqrt = 1,
- HasTanh = EIGEN_FAST_MATH,
- HasBlend = 1
-
-#ifdef EIGEN_VECTORIZE_SSE4_1
- ,
- HasRound = 1,
+ HasTanh = EIGEN_FAST_MATH,
+ HasErf = EIGEN_FAST_MATH,
+ HasBlend = 1,
+ HasCeil = 1,
HasFloor = 1,
- HasCeil = 1
+#ifdef EIGEN_VECTORIZE_SSE4_1
+ HasRound = 1,
#endif
+ HasRint = 1
};
};
-template<> struct packet_traits<double> : default_packet_traits
-{
+template <>
+struct packet_traits<double> : default_packet_traits {
typedef Packet2d type;
typedef Packet2d half;
enum {
@@ -130,18 +165,19 @@
size=2,
HasHalfPacket = 0,
+ HasCmp = 1,
HasDiv = 1,
+ HasLog = 1,
HasExp = 1,
HasSqrt = 1,
HasRsqrt = 1,
- HasBlend = 1
-
-#ifdef EIGEN_VECTORIZE_SSE4_1
- ,
- HasRound = 1,
+ HasBlend = 1,
HasFloor = 1,
- HasCeil = 1
+ HasCeil = 1,
+#ifdef EIGEN_VECTORIZE_SSE4_1
+ HasRound = 1,
#endif
+ HasRint = 1
};
};
#endif
@@ -154,13 +190,56 @@
AlignedOnScalar = 1,
size=4,
+ HasShift = 1,
HasBlend = 1
};
};
-template<> struct unpacket_traits<Packet4f> { typedef float type; enum {size=4, alignment=Aligned16}; typedef Packet4f half; };
-template<> struct unpacket_traits<Packet2d> { typedef double type; enum {size=2, alignment=Aligned16}; typedef Packet2d half; };
-template<> struct unpacket_traits<Packet4i> { typedef int type; enum {size=4, alignment=Aligned16}; typedef Packet4i half; };
+template<> struct packet_traits<bool> : default_packet_traits
+{
+ typedef Packet16b type;
+ typedef Packet16b half;
+ enum {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ HasHalfPacket = 0,
+ size=16,
+
+ HasAdd = 1,
+ HasSub = 1,
+ HasShift = 0,
+ HasMul = 1,
+ HasNegate = 1,
+ HasAbs = 0,
+ HasAbs2 = 0,
+ HasMin = 0,
+ HasMax = 0,
+ HasConj = 0,
+ HasSqrt = 1
+ };
+};
+
+template<> struct unpacket_traits<Packet4f> {
+ typedef float type;
+ typedef Packet4f half;
+ typedef Packet4i integer_packet;
+ enum {size=4, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false};
+};
+template<> struct unpacket_traits<Packet2d> {
+ typedef double type;
+ typedef Packet2d half;
+ enum {size=2, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false};
+};
+template<> struct unpacket_traits<Packet4i> {
+ typedef int type;
+ typedef Packet4i half;
+ enum {size=4, alignment=Aligned16, vectorizable=false, masked_load_available=false, masked_store_available=false};
+};
+template<> struct unpacket_traits<Packet16b> {
+ typedef bool type;
+ typedef Packet16b half;
+ enum {size=16, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false};
+};
#ifndef EIGEN_VECTORIZE_AVX
template<> struct scalar_div_cost<float,true> { enum { value = 7 }; };
@@ -179,6 +258,18 @@
template<> EIGEN_STRONG_INLINE Packet2d pset1<Packet2d>(const double& from) { return _mm_set1_pd(from); }
template<> EIGEN_STRONG_INLINE Packet4i pset1<Packet4i>(const int& from) { return _mm_set1_epi32(from); }
#endif
+template<> EIGEN_STRONG_INLINE Packet16b pset1<Packet16b>(const bool& from) { return _mm_set1_epi8(static_cast<char>(from)); }
+
+template<> EIGEN_STRONG_INLINE Packet4f pset1frombits<Packet4f>(unsigned int from) { return _mm_castsi128_ps(pset1<Packet4i>(from)); }
+template<> EIGEN_STRONG_INLINE Packet2d pset1frombits<Packet2d>(uint64_t from) { return _mm_castsi128_pd(_mm_set1_epi64x(from)); }
+
+template<> EIGEN_STRONG_INLINE Packet4f peven_mask(const Packet4f& /*a*/) { return _mm_castsi128_ps(_mm_set_epi32(0, -1, 0, -1)); }
+template<> EIGEN_STRONG_INLINE Packet4i peven_mask(const Packet4i& /*a*/) { return _mm_set_epi32(0, -1, 0, -1); }
+template<> EIGEN_STRONG_INLINE Packet2d peven_mask(const Packet2d& /*a*/) { return _mm_castsi128_pd(_mm_set_epi32(0, 0, -1, -1)); }
+
+template<> EIGEN_STRONG_INLINE Packet4f pzero(const Packet4f& /*a*/) { return _mm_setzero_ps(); }
+template<> EIGEN_STRONG_INLINE Packet2d pzero(const Packet2d& /*a*/) { return _mm_setzero_pd(); }
+template<> EIGEN_STRONG_INLINE Packet4i pzero(const Packet4i& /*a*/) { return _mm_setzero_si128(); }
// GCC generates a shufps instruction for _mm_set1_ps/_mm_load1_ps instead of the more efficient pshufd instruction.
// However, using inrinsics for pset1 makes gcc to generate crappy code in some cases (see bug 203)
@@ -190,7 +281,7 @@
return vec4f_swizzle1(_mm_load_ss(from),0,0,0,0);
}
#endif
-
+
template<> EIGEN_STRONG_INLINE Packet4f plset<Packet4f>(const float& a) { return _mm_add_ps(pset1<Packet4f>(a), _mm_set_ps(3,2,1,0)); }
template<> EIGEN_STRONG_INLINE Packet2d plset<Packet2d>(const double& a) { return _mm_add_pd(pset1<Packet2d>(a),_mm_set_pd(1,0)); }
template<> EIGEN_STRONG_INLINE Packet4i plset<Packet4i>(const int& a) { return _mm_add_epi32(pset1<Packet4i>(a),_mm_set_epi32(3,2,1,0)); }
@@ -199,9 +290,34 @@
template<> EIGEN_STRONG_INLINE Packet2d padd<Packet2d>(const Packet2d& a, const Packet2d& b) { return _mm_add_pd(a,b); }
template<> EIGEN_STRONG_INLINE Packet4i padd<Packet4i>(const Packet4i& a, const Packet4i& b) { return _mm_add_epi32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16b padd<Packet16b>(const Packet16b& a, const Packet16b& b) { return _mm_or_si128(a,b); }
+
template<> EIGEN_STRONG_INLINE Packet4f psub<Packet4f>(const Packet4f& a, const Packet4f& b) { return _mm_sub_ps(a,b); }
template<> EIGEN_STRONG_INLINE Packet2d psub<Packet2d>(const Packet2d& a, const Packet2d& b) { return _mm_sub_pd(a,b); }
template<> EIGEN_STRONG_INLINE Packet4i psub<Packet4i>(const Packet4i& a, const Packet4i& b) { return _mm_sub_epi32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16b psub<Packet16b>(const Packet16b& a, const Packet16b& b) { return _mm_xor_si128(a,b); }
+
+template<> EIGEN_STRONG_INLINE Packet4f pxor<Packet4f>(const Packet4f& a, const Packet4f& b);
+template<> EIGEN_STRONG_INLINE Packet4f paddsub<Packet4f>(const Packet4f& a, const Packet4f& b)
+{
+#ifdef EIGEN_VECTORIZE_SSE3
+ return _mm_addsub_ps(a,b);
+#else
+ const Packet4f mask = _mm_castsi128_ps(_mm_setr_epi32(0x80000000,0x0,0x80000000,0x0));
+ return padd(a, pxor(mask, b));
+#endif
+}
+
+template<> EIGEN_STRONG_INLINE Packet2d pxor<Packet2d>(const Packet2d& , const Packet2d& );
+template<> EIGEN_STRONG_INLINE Packet2d paddsub<Packet2d>(const Packet2d& a, const Packet2d& b)
+{
+#ifdef EIGEN_VECTORIZE_SSE3
+ return _mm_addsub_pd(a,b);
+#else
+ const Packet2d mask = _mm_castsi128_pd(_mm_setr_epi32(0x0,0x80000000,0x0,0x0));
+ return padd(a, pxor(mask, b));
+#endif
+}
template<> EIGEN_STRONG_INLINE Packet4f pnegate(const Packet4f& a)
{
@@ -218,6 +334,11 @@
return psub(Packet4i(_mm_setr_epi32(0,0,0,0)), a);
}
+template<> EIGEN_STRONG_INLINE Packet16b pnegate(const Packet16b& a)
+{
+ return psub(pset1<Packet16b>(false), a);
+}
+
template<> EIGEN_STRONG_INLINE Packet4f pconj(const Packet4f& a) { return a; }
template<> EIGEN_STRONG_INLINE Packet2d pconj(const Packet2d& a) { return a; }
template<> EIGEN_STRONG_INLINE Packet4i pconj(const Packet4i& a) { return a; }
@@ -240,18 +361,126 @@
#endif
}
+template<> EIGEN_STRONG_INLINE Packet16b pmul<Packet16b>(const Packet16b& a, const Packet16b& b) { return _mm_and_si128(a,b); }
+
template<> EIGEN_STRONG_INLINE Packet4f pdiv<Packet4f>(const Packet4f& a, const Packet4f& b) { return _mm_div_ps(a,b); }
template<> EIGEN_STRONG_INLINE Packet2d pdiv<Packet2d>(const Packet2d& a, const Packet2d& b) { return _mm_div_pd(a,b); }
// for some weird raisons, it has to be overloaded for packet of integers
template<> EIGEN_STRONG_INLINE Packet4i pmadd(const Packet4i& a, const Packet4i& b, const Packet4i& c) { return padd(pmul(a,b), c); }
-#ifdef __FMA__
+#ifdef EIGEN_VECTORIZE_FMA
template<> EIGEN_STRONG_INLINE Packet4f pmadd(const Packet4f& a, const Packet4f& b, const Packet4f& c) { return _mm_fmadd_ps(a,b,c); }
template<> EIGEN_STRONG_INLINE Packet2d pmadd(const Packet2d& a, const Packet2d& b, const Packet2d& c) { return _mm_fmadd_pd(a,b,c); }
#endif
-template<> EIGEN_STRONG_INLINE Packet4f pmin<Packet4f>(const Packet4f& a, const Packet4f& b) { return _mm_min_ps(a,b); }
-template<> EIGEN_STRONG_INLINE Packet2d pmin<Packet2d>(const Packet2d& a, const Packet2d& b) { return _mm_min_pd(a,b); }
+#ifdef EIGEN_VECTORIZE_SSE4_1
+template<> EIGEN_DEVICE_FUNC inline Packet4f pselect(const Packet4f& mask, const Packet4f& a, const Packet4f& b) {
+ return _mm_blendv_ps(b,a,mask);
+}
+
+template<> EIGEN_DEVICE_FUNC inline Packet4i pselect(const Packet4i& mask, const Packet4i& a, const Packet4i& b) {
+ return _mm_castps_si128(_mm_blendv_ps(_mm_castsi128_ps(b),_mm_castsi128_ps(a),_mm_castsi128_ps(mask)));
+}
+
+template<> EIGEN_DEVICE_FUNC inline Packet2d pselect(const Packet2d& mask, const Packet2d& a, const Packet2d& b) { return _mm_blendv_pd(b,a,mask); }
+
+template<> EIGEN_DEVICE_FUNC inline Packet16b pselect(const Packet16b& mask, const Packet16b& a, const Packet16b& b) {
+ return _mm_blendv_epi8(b,a,mask);
+}
+#else
+template<> EIGEN_DEVICE_FUNC inline Packet16b pselect(const Packet16b& mask, const Packet16b& a, const Packet16b& b) {
+ Packet16b a_part = _mm_and_si128(mask, a);
+ Packet16b b_part = _mm_andnot_si128(mask, b);
+ return _mm_or_si128(a_part, b_part);
+}
+#endif
+
+template<> EIGEN_STRONG_INLINE Packet4i ptrue<Packet4i>(const Packet4i& a) { return _mm_cmpeq_epi32(a, a); }
+template<> EIGEN_STRONG_INLINE Packet16b ptrue<Packet16b>(const Packet16b& a) { return _mm_cmpeq_epi8(a, a); }
+template<> EIGEN_STRONG_INLINE Packet4f
+ptrue<Packet4f>(const Packet4f& a) {
+ Packet4i b = _mm_castps_si128(a);
+ return _mm_castsi128_ps(_mm_cmpeq_epi32(b, b));
+}
+template<> EIGEN_STRONG_INLINE Packet2d
+ptrue<Packet2d>(const Packet2d& a) {
+ Packet4i b = _mm_castpd_si128(a);
+ return _mm_castsi128_pd(_mm_cmpeq_epi32(b, b));
+}
+
+
+template<> EIGEN_STRONG_INLINE Packet4f pand<Packet4f>(const Packet4f& a, const Packet4f& b) { return _mm_and_ps(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2d pand<Packet2d>(const Packet2d& a, const Packet2d& b) { return _mm_and_pd(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4i pand<Packet4i>(const Packet4i& a, const Packet4i& b) { return _mm_and_si128(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16b pand<Packet16b>(const Packet16b& a, const Packet16b& b) { return _mm_and_si128(a,b); }
+
+template<> EIGEN_STRONG_INLINE Packet4f por<Packet4f>(const Packet4f& a, const Packet4f& b) { return _mm_or_ps(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2d por<Packet2d>(const Packet2d& a, const Packet2d& b) { return _mm_or_pd(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4i por<Packet4i>(const Packet4i& a, const Packet4i& b) { return _mm_or_si128(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16b por<Packet16b>(const Packet16b& a, const Packet16b& b) { return _mm_or_si128(a,b); }
+
+template<> EIGEN_STRONG_INLINE Packet4f pxor<Packet4f>(const Packet4f& a, const Packet4f& b) { return _mm_xor_ps(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2d pxor<Packet2d>(const Packet2d& a, const Packet2d& b) { return _mm_xor_pd(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4i pxor<Packet4i>(const Packet4i& a, const Packet4i& b) { return _mm_xor_si128(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16b pxor<Packet16b>(const Packet16b& a, const Packet16b& b) { return _mm_xor_si128(a,b); }
+
+template<> EIGEN_STRONG_INLINE Packet4f pandnot<Packet4f>(const Packet4f& a, const Packet4f& b) { return _mm_andnot_ps(b,a); }
+template<> EIGEN_STRONG_INLINE Packet2d pandnot<Packet2d>(const Packet2d& a, const Packet2d& b) { return _mm_andnot_pd(b,a); }
+template<> EIGEN_STRONG_INLINE Packet4i pandnot<Packet4i>(const Packet4i& a, const Packet4i& b) { return _mm_andnot_si128(b,a); }
+
+template<> EIGEN_STRONG_INLINE Packet4f pcmp_le(const Packet4f& a, const Packet4f& b) { return _mm_cmple_ps(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4f pcmp_lt(const Packet4f& a, const Packet4f& b) { return _mm_cmplt_ps(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4f pcmp_lt_or_nan(const Packet4f& a, const Packet4f& b) { return _mm_cmpnge_ps(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4f pcmp_eq(const Packet4f& a, const Packet4f& b) { return _mm_cmpeq_ps(a,b); }
+
+template<> EIGEN_STRONG_INLINE Packet2d pcmp_le(const Packet2d& a, const Packet2d& b) { return _mm_cmple_pd(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2d pcmp_lt(const Packet2d& a, const Packet2d& b) { return _mm_cmplt_pd(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2d pcmp_lt_or_nan(const Packet2d& a, const Packet2d& b) { return _mm_cmpnge_pd(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2d pcmp_eq(const Packet2d& a, const Packet2d& b) { return _mm_cmpeq_pd(a,b); }
+
+template<> EIGEN_STRONG_INLINE Packet4i pcmp_lt(const Packet4i& a, const Packet4i& b) { return _mm_cmplt_epi32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4i pcmp_eq(const Packet4i& a, const Packet4i& b) { return _mm_cmpeq_epi32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16b pcmp_eq(const Packet16b& a, const Packet16b& b) { return _mm_cmpeq_epi8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4i pcmp_le(const Packet4i& a, const Packet4i& b) { return por(pcmp_lt(a,b), pcmp_eq(a,b)); }
+
+template<> EIGEN_STRONG_INLINE Packet4f pmin<Packet4f>(const Packet4f& a, const Packet4f& b) {
+#if EIGEN_COMP_GNUC && EIGEN_COMP_GNUC < 63
+ // There appears to be a bug in GCC, by which the optimizer may
+ // flip the argument order in calls to _mm_min_ps, so we have to
+ // resort to inline ASM here. This is supposed to be fixed in gcc6.3,
+ // see also: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=72867
+ #ifdef EIGEN_VECTORIZE_AVX
+ Packet4f res;
+ asm("vminps %[a], %[b], %[res]" : [res] "=x" (res) : [a] "x" (a), [b] "x" (b));
+ #else
+ Packet4f res = b;
+ asm("minps %[a], %[res]" : [res] "+x" (res) : [a] "x" (a));
+ #endif
+ return res;
+#else
+ // Arguments are reversed to match NaN propagation behavior of std::min.
+ return _mm_min_ps(b, a);
+#endif
+}
+template<> EIGEN_STRONG_INLINE Packet2d pmin<Packet2d>(const Packet2d& a, const Packet2d& b) {
+#if EIGEN_COMP_GNUC && EIGEN_COMP_GNUC < 63
+ // There appears to be a bug in GCC, by which the optimizer may
+ // flip the argument order in calls to _mm_min_pd, so we have to
+ // resort to inline ASM here. This is supposed to be fixed in gcc6.3,
+ // see also: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=72867
+ #ifdef EIGEN_VECTORIZE_AVX
+ Packet2d res;
+ asm("vminpd %[a], %[b], %[res]" : [res] "=x" (res) : [a] "x" (a), [b] "x" (b));
+ #else
+ Packet2d res = b;
+ asm("minpd %[a], %[res]" : [res] "+x" (res) : [a] "x" (a));
+ #endif
+ return res;
+#else
+ // Arguments are reversed to match NaN propagation behavior of std::min.
+ return _mm_min_pd(b, a);
+#endif
+}
template<> EIGEN_STRONG_INLINE Packet4i pmin<Packet4i>(const Packet4i& a, const Packet4i& b)
{
#ifdef EIGEN_VECTORIZE_SSE4_1
@@ -263,8 +492,45 @@
#endif
}
-template<> EIGEN_STRONG_INLINE Packet4f pmax<Packet4f>(const Packet4f& a, const Packet4f& b) { return _mm_max_ps(a,b); }
-template<> EIGEN_STRONG_INLINE Packet2d pmax<Packet2d>(const Packet2d& a, const Packet2d& b) { return _mm_max_pd(a,b); }
+
+template<> EIGEN_STRONG_INLINE Packet4f pmax<Packet4f>(const Packet4f& a, const Packet4f& b) {
+#if EIGEN_COMP_GNUC && EIGEN_COMP_GNUC < 63
+ // There appears to be a bug in GCC, by which the optimizer may
+ // flip the argument order in calls to _mm_max_ps, so we have to
+ // resort to inline ASM here. This is supposed to be fixed in gcc6.3,
+ // see also: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=72867
+ #ifdef EIGEN_VECTORIZE_AVX
+ Packet4f res;
+ asm("vmaxps %[a], %[b], %[res]" : [res] "=x" (res) : [a] "x" (a), [b] "x" (b));
+ #else
+ Packet4f res = b;
+ asm("maxps %[a], %[res]" : [res] "+x" (res) : [a] "x" (a));
+ #endif
+ return res;
+#else
+ // Arguments are reversed to match NaN propagation behavior of std::max.
+ return _mm_max_ps(b, a);
+#endif
+}
+template<> EIGEN_STRONG_INLINE Packet2d pmax<Packet2d>(const Packet2d& a, const Packet2d& b) {
+#if EIGEN_COMP_GNUC && EIGEN_COMP_GNUC < 63
+ // There appears to be a bug in GCC, by which the optimizer may
+ // flip the argument order in calls to _mm_max_pd, so we have to
+ // resort to inline ASM here. This is supposed to be fixed in gcc6.3,
+ // see also: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=72867
+ #ifdef EIGEN_VECTORIZE_AVX
+ Packet2d res;
+ asm("vmaxpd %[a], %[b], %[res]" : [res] "=x" (res) : [a] "x" (a), [b] "x" (b));
+ #else
+ Packet2d res = b;
+ asm("maxpd %[a], %[res]" : [res] "+x" (res) : [a] "x" (a));
+ #endif
+ return res;
+#else
+ // Arguments are reversed to match NaN propagation behavior of std::max.
+ return _mm_max_pd(b, a);
+#endif
+}
template<> EIGEN_STRONG_INLINE Packet4i pmax<Packet4i>(const Packet4i& a, const Packet4i& b)
{
#ifdef EIGEN_VECTORIZE_SSE4_1
@@ -276,36 +542,180 @@
#endif
}
+template <typename Packet, typename Op>
+EIGEN_STRONG_INLINE Packet pminmax_propagate_numbers(const Packet& a, const Packet& b, Op op) {
+ // In this implementation, we take advantage of the fact that pmin/pmax for SSE
+ // always return a if either a or b is NaN.
+ Packet not_nan_mask_a = pcmp_eq(a, a);
+ Packet m = op(a, b);
+ return pselect<Packet>(not_nan_mask_a, m, b);
+}
+
+template <typename Packet, typename Op>
+EIGEN_STRONG_INLINE Packet pminmax_propagate_nan(const Packet& a, const Packet& b, Op op) {
+ // In this implementation, we take advantage of the fact that pmin/pmax for SSE
+ // always return a if either a or b is NaN.
+ Packet not_nan_mask_a = pcmp_eq(a, a);
+ Packet m = op(b, a);
+ return pselect<Packet>(not_nan_mask_a, m, a);
+}
+
+// Add specializations for min/max with prescribed NaN progation.
+template<>
+EIGEN_STRONG_INLINE Packet4f pmin<PropagateNumbers, Packet4f>(const Packet4f& a, const Packet4f& b) {
+ return pminmax_propagate_numbers(a, b, pmin<Packet4f>);
+}
+template<>
+EIGEN_STRONG_INLINE Packet2d pmin<PropagateNumbers, Packet2d>(const Packet2d& a, const Packet2d& b) {
+ return pminmax_propagate_numbers(a, b, pmin<Packet2d>);
+}
+template<>
+EIGEN_STRONG_INLINE Packet4f pmax<PropagateNumbers, Packet4f>(const Packet4f& a, const Packet4f& b) {
+ return pminmax_propagate_numbers(a, b, pmax<Packet4f>);
+}
+template<>
+EIGEN_STRONG_INLINE Packet2d pmax<PropagateNumbers, Packet2d>(const Packet2d& a, const Packet2d& b) {
+ return pminmax_propagate_numbers(a, b, pmax<Packet2d>);
+}
+template<>
+EIGEN_STRONG_INLINE Packet4f pmin<PropagateNaN, Packet4f>(const Packet4f& a, const Packet4f& b) {
+ return pminmax_propagate_nan(a, b, pmin<Packet4f>);
+}
+template<>
+EIGEN_STRONG_INLINE Packet2d pmin<PropagateNaN, Packet2d>(const Packet2d& a, const Packet2d& b) {
+ return pminmax_propagate_nan(a, b, pmin<Packet2d>);
+}
+template<>
+EIGEN_STRONG_INLINE Packet4f pmax<PropagateNaN, Packet4f>(const Packet4f& a, const Packet4f& b) {
+ return pminmax_propagate_nan(a, b, pmax<Packet4f>);
+}
+template<>
+EIGEN_STRONG_INLINE Packet2d pmax<PropagateNaN, Packet2d>(const Packet2d& a, const Packet2d& b) {
+ return pminmax_propagate_nan(a, b, pmax<Packet2d>);
+}
+
+template<int N> EIGEN_STRONG_INLINE Packet4i parithmetic_shift_right(const Packet4i& a) { return _mm_srai_epi32(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet4i plogical_shift_right (const Packet4i& a) { return _mm_srli_epi32(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet4i plogical_shift_left (const Packet4i& a) { return _mm_slli_epi32(a,N); }
+
+template<> EIGEN_STRONG_INLINE Packet4f pabs(const Packet4f& a)
+{
+ const Packet4f mask = _mm_castsi128_ps(_mm_setr_epi32(0x7FFFFFFF,0x7FFFFFFF,0x7FFFFFFF,0x7FFFFFFF));
+ return _mm_and_ps(a,mask);
+}
+template<> EIGEN_STRONG_INLINE Packet2d pabs(const Packet2d& a)
+{
+ const Packet2d mask = _mm_castsi128_pd(_mm_setr_epi32(0xFFFFFFFF,0x7FFFFFFF,0xFFFFFFFF,0x7FFFFFFF));
+ return _mm_and_pd(a,mask);
+}
+template<> EIGEN_STRONG_INLINE Packet4i pabs(const Packet4i& a)
+{
+ #ifdef EIGEN_VECTORIZE_SSSE3
+ return _mm_abs_epi32(a);
+ #else
+ Packet4i aux = _mm_srai_epi32(a,31);
+ return _mm_sub_epi32(_mm_xor_si128(a,aux),aux);
+ #endif
+}
+
#ifdef EIGEN_VECTORIZE_SSE4_1
-template<> EIGEN_STRONG_INLINE Packet4f pround<Packet4f>(const Packet4f& a) { return _mm_round_ps(a, 0); }
-template<> EIGEN_STRONG_INLINE Packet2d pround<Packet2d>(const Packet2d& a) { return _mm_round_pd(a, 0); }
+template<> EIGEN_STRONG_INLINE Packet4f pround<Packet4f>(const Packet4f& a)
+{
+ // Unfortunatly _mm_round_ps doesn't have a rounding mode to implement numext::round.
+ const Packet4f mask = pset1frombits<Packet4f>(0x80000000u);
+ const Packet4f prev0dot5 = pset1frombits<Packet4f>(0x3EFFFFFFu);
+ return _mm_round_ps(padd(por(pand(a, mask), prev0dot5), a), _MM_FROUND_TO_ZERO);
+}
+
+template<> EIGEN_STRONG_INLINE Packet2d pround<Packet2d>(const Packet2d& a)
+{
+ const Packet2d mask = _mm_castsi128_pd(_mm_set_epi64x(0x8000000000000000ull, 0x8000000000000000ull));
+ const Packet2d prev0dot5 = _mm_castsi128_pd(_mm_set_epi64x(0x3FDFFFFFFFFFFFFFull, 0x3FDFFFFFFFFFFFFFull));
+ return _mm_round_pd(padd(por(pand(a, mask), prev0dot5), a), _MM_FROUND_TO_ZERO);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f print<Packet4f>(const Packet4f& a) { return _mm_round_ps(a, _MM_FROUND_CUR_DIRECTION); }
+template<> EIGEN_STRONG_INLINE Packet2d print<Packet2d>(const Packet2d& a) { return _mm_round_pd(a, _MM_FROUND_CUR_DIRECTION); }
template<> EIGEN_STRONG_INLINE Packet4f pceil<Packet4f>(const Packet4f& a) { return _mm_ceil_ps(a); }
template<> EIGEN_STRONG_INLINE Packet2d pceil<Packet2d>(const Packet2d& a) { return _mm_ceil_pd(a); }
template<> EIGEN_STRONG_INLINE Packet4f pfloor<Packet4f>(const Packet4f& a) { return _mm_floor_ps(a); }
template<> EIGEN_STRONG_INLINE Packet2d pfloor<Packet2d>(const Packet2d& a) { return _mm_floor_pd(a); }
+#else
+template<> EIGEN_STRONG_INLINE Packet4f print(const Packet4f& a) {
+ // Adds and subtracts signum(a) * 2^23 to force rounding.
+ const Packet4f limit = pset1<Packet4f>(static_cast<float>(1<<23));
+ const Packet4f abs_a = pabs(a);
+ Packet4f r = padd(abs_a, limit);
+ // Don't compile-away addition and subtraction.
+ EIGEN_OPTIMIZATION_BARRIER(r);
+ r = psub(r, limit);
+ // If greater than limit, simply return a. Otherwise, account for sign.
+ r = pselect(pcmp_lt(abs_a, limit),
+ pselect(pcmp_lt(a, pzero(a)), pnegate(r), r), a);
+ return r;
+}
+
+template<> EIGEN_STRONG_INLINE Packet2d print(const Packet2d& a) {
+ // Adds and subtracts signum(a) * 2^52 to force rounding.
+ const Packet2d limit = pset1<Packet2d>(static_cast<double>(1ull<<52));
+ const Packet2d abs_a = pabs(a);
+ Packet2d r = padd(abs_a, limit);
+ // Don't compile-away addition and subtraction.
+ EIGEN_OPTIMIZATION_BARRIER(r);
+ r = psub(r, limit);
+ // If greater than limit, simply return a. Otherwise, account for sign.
+ r = pselect(pcmp_lt(abs_a, limit),
+ pselect(pcmp_lt(a, pzero(a)), pnegate(r), r), a);
+ return r;
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f pfloor<Packet4f>(const Packet4f& a)
+{
+ const Packet4f cst_1 = pset1<Packet4f>(1.0f);
+ Packet4f tmp = print<Packet4f>(a);
+ // If greater, subtract one.
+ Packet4f mask = _mm_cmpgt_ps(tmp, a);
+ mask = pand(mask, cst_1);
+ return psub(tmp, mask);
+}
+
+template<> EIGEN_STRONG_INLINE Packet2d pfloor<Packet2d>(const Packet2d& a)
+{
+ const Packet2d cst_1 = pset1<Packet2d>(1.0);
+ Packet2d tmp = print<Packet2d>(a);
+ // If greater, subtract one.
+ Packet2d mask = _mm_cmpgt_pd(tmp, a);
+ mask = pand(mask, cst_1);
+ return psub(tmp, mask);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f pceil<Packet4f>(const Packet4f& a)
+{
+ const Packet4f cst_1 = pset1<Packet4f>(1.0f);
+ Packet4f tmp = print<Packet4f>(a);
+ // If smaller, add one.
+ Packet4f mask = _mm_cmplt_ps(tmp, a);
+ mask = pand(mask, cst_1);
+ return padd(tmp, mask);
+}
+
+template<> EIGEN_STRONG_INLINE Packet2d pceil<Packet2d>(const Packet2d& a)
+{
+ const Packet2d cst_1 = pset1<Packet2d>(1.0);
+ Packet2d tmp = print<Packet2d>(a);
+ // If smaller, add one.
+ Packet2d mask = _mm_cmplt_pd(tmp, a);
+ mask = pand(mask, cst_1);
+ return padd(tmp, mask);
+}
#endif
-template<> EIGEN_STRONG_INLINE Packet4f pand<Packet4f>(const Packet4f& a, const Packet4f& b) { return _mm_and_ps(a,b); }
-template<> EIGEN_STRONG_INLINE Packet2d pand<Packet2d>(const Packet2d& a, const Packet2d& b) { return _mm_and_pd(a,b); }
-template<> EIGEN_STRONG_INLINE Packet4i pand<Packet4i>(const Packet4i& a, const Packet4i& b) { return _mm_and_si128(a,b); }
-
-template<> EIGEN_STRONG_INLINE Packet4f por<Packet4f>(const Packet4f& a, const Packet4f& b) { return _mm_or_ps(a,b); }
-template<> EIGEN_STRONG_INLINE Packet2d por<Packet2d>(const Packet2d& a, const Packet2d& b) { return _mm_or_pd(a,b); }
-template<> EIGEN_STRONG_INLINE Packet4i por<Packet4i>(const Packet4i& a, const Packet4i& b) { return _mm_or_si128(a,b); }
-
-template<> EIGEN_STRONG_INLINE Packet4f pxor<Packet4f>(const Packet4f& a, const Packet4f& b) { return _mm_xor_ps(a,b); }
-template<> EIGEN_STRONG_INLINE Packet2d pxor<Packet2d>(const Packet2d& a, const Packet2d& b) { return _mm_xor_pd(a,b); }
-template<> EIGEN_STRONG_INLINE Packet4i pxor<Packet4i>(const Packet4i& a, const Packet4i& b) { return _mm_xor_si128(a,b); }
-
-template<> EIGEN_STRONG_INLINE Packet4f pandnot<Packet4f>(const Packet4f& a, const Packet4f& b) { return _mm_andnot_ps(a,b); }
-template<> EIGEN_STRONG_INLINE Packet2d pandnot<Packet2d>(const Packet2d& a, const Packet2d& b) { return _mm_andnot_pd(a,b); }
-template<> EIGEN_STRONG_INLINE Packet4i pandnot<Packet4i>(const Packet4i& a, const Packet4i& b) { return _mm_andnot_si128(a,b); }
-
template<> EIGEN_STRONG_INLINE Packet4f pload<Packet4f>(const float* from) { EIGEN_DEBUG_ALIGNED_LOAD return _mm_load_ps(from); }
template<> EIGEN_STRONG_INLINE Packet2d pload<Packet2d>(const double* from) { EIGEN_DEBUG_ALIGNED_LOAD return _mm_load_pd(from); }
template<> EIGEN_STRONG_INLINE Packet4i pload<Packet4i>(const int* from) { EIGEN_DEBUG_ALIGNED_LOAD return _mm_load_si128(reinterpret_cast<const __m128i*>(from)); }
+template<> EIGEN_STRONG_INLINE Packet16b pload<Packet16b>(const bool* from) { EIGEN_DEBUG_ALIGNED_LOAD return _mm_load_si128(reinterpret_cast<const __m128i*>(from)); }
#if EIGEN_COMP_MSVC
template<> EIGEN_STRONG_INLINE Packet4f ploadu<Packet4f>(const float* from) {
@@ -340,6 +750,10 @@
EIGEN_DEBUG_UNALIGNED_LOAD
return _mm_loadu_si128(reinterpret_cast<const __m128i*>(from));
}
+template<> EIGEN_STRONG_INLINE Packet16b ploadu<Packet16b>(const bool* from) {
+ EIGEN_DEBUG_UNALIGNED_LOAD
+ return _mm_loadu_si128(reinterpret_cast<const __m128i*>(from));
+}
template<> EIGEN_STRONG_INLINE Packet4f ploaddup<Packet4f>(const float* from)
@@ -355,13 +769,32 @@
return vec4i_swizzle1(tmp, 0, 0, 1, 1);
}
+// Loads 8 bools from memory and returns the packet
+// {b0, b0, b1, b1, b2, b2, b3, b3, b4, b4, b5, b5, b6, b6, b7, b7}
+template<> EIGEN_STRONG_INLINE Packet16b ploaddup<Packet16b>(const bool* from)
+{
+ __m128i tmp = _mm_castpd_si128(pload1<Packet2d>(reinterpret_cast<const double*>(from)));
+ return _mm_unpacklo_epi8(tmp, tmp);
+}
+
+// Loads 4 bools from memory and returns the packet
+// {b0, b0 b0, b0, b1, b1, b1, b1, b2, b2, b2, b2, b3, b3, b3, b3}
+template<> EIGEN_STRONG_INLINE Packet16b
+ploadquad<Packet16b>(const bool* from) {
+ __m128i tmp = _mm_castps_si128(pload1<Packet4f>(reinterpret_cast<const float*>(from)));
+ tmp = _mm_unpacklo_epi8(tmp, tmp);
+ return _mm_unpacklo_epi16(tmp, tmp);
+}
+
template<> EIGEN_STRONG_INLINE void pstore<float>(float* to, const Packet4f& from) { EIGEN_DEBUG_ALIGNED_STORE _mm_store_ps(to, from); }
template<> EIGEN_STRONG_INLINE void pstore<double>(double* to, const Packet2d& from) { EIGEN_DEBUG_ALIGNED_STORE _mm_store_pd(to, from); }
template<> EIGEN_STRONG_INLINE void pstore<int>(int* to, const Packet4i& from) { EIGEN_DEBUG_ALIGNED_STORE _mm_store_si128(reinterpret_cast<__m128i*>(to), from); }
+template<> EIGEN_STRONG_INLINE void pstore<bool>(bool* to, const Packet16b& from) { EIGEN_DEBUG_ALIGNED_STORE _mm_store_si128(reinterpret_cast<__m128i*>(to), from); }
template<> EIGEN_STRONG_INLINE void pstoreu<double>(double* to, const Packet2d& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm_storeu_pd(to, from); }
template<> EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const Packet4f& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm_storeu_ps(to, from); }
template<> EIGEN_STRONG_INLINE void pstoreu<int>(int* to, const Packet4i& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm_storeu_si128(reinterpret_cast<__m128i*>(to), from); }
+template<> EIGEN_STRONG_INLINE void pstoreu<bool>(bool* to, const Packet16b& from) { EIGEN_DEBUG_ALIGNED_STORE _mm_storeu_si128(reinterpret_cast<__m128i*>(to), from); }
template<> EIGEN_DEVICE_FUNC inline Packet4f pgather<float, Packet4f>(const float* from, Index stride)
{
@@ -374,7 +807,15 @@
template<> EIGEN_DEVICE_FUNC inline Packet4i pgather<int, Packet4i>(const int* from, Index stride)
{
return _mm_set_epi32(from[3*stride], from[2*stride], from[1*stride], from[0*stride]);
- }
+}
+
+template<> EIGEN_DEVICE_FUNC inline Packet16b pgather<bool, Packet16b>(const bool* from, Index stride)
+{
+ return _mm_set_epi8(from[15*stride], from[14*stride], from[13*stride], from[12*stride],
+ from[11*stride], from[10*stride], from[9*stride], from[8*stride],
+ from[7*stride], from[6*stride], from[5*stride], from[4*stride],
+ from[3*stride], from[2*stride], from[1*stride], from[0*stride]);
+}
template<> EIGEN_DEVICE_FUNC inline void pscatter<float, Packet4f>(float* to, const Packet4f& from, Index stride)
{
@@ -395,6 +836,14 @@
to[stride*2] = _mm_cvtsi128_si32(_mm_shuffle_epi32(from, 2));
to[stride*3] = _mm_cvtsi128_si32(_mm_shuffle_epi32(from, 3));
}
+template<> EIGEN_DEVICE_FUNC inline void pscatter<bool, Packet16b>(bool* to, const Packet16b& from, Index stride)
+{
+ to[4*stride*0] = _mm_cvtsi128_si32(from);
+ to[4*stride*1] = _mm_cvtsi128_si32(_mm_shuffle_epi32(from, 1));
+ to[4*stride*2] = _mm_cvtsi128_si32(_mm_shuffle_epi32(from, 2));
+ to[4*stride*3] = _mm_cvtsi128_si32(_mm_shuffle_epi32(from, 3));
+}
+
// some compilers might be tempted to perform multiple moves instead of using a vector path.
template<> EIGEN_STRONG_INLINE void pstore1<Packet4f>(float* to, const float& a)
@@ -409,7 +858,7 @@
pstore(to, Packet2d(vec2d_swizzle1(pa,0,0)));
}
-#if EIGEN_COMP_PGI
+#if EIGEN_COMP_PGI && EIGEN_COMP_PGI < 1900
typedef const void * SsePrefetchPtrType;
#else
typedef const char * SsePrefetchPtrType;
@@ -437,32 +886,62 @@
template<> EIGEN_STRONG_INLINE double pfirst<Packet2d>(const Packet2d& a) { return _mm_cvtsd_f64(a); }
template<> EIGEN_STRONG_INLINE int pfirst<Packet4i>(const Packet4i& a) { return _mm_cvtsi128_si32(a); }
#endif
+template<> EIGEN_STRONG_INLINE bool pfirst<Packet16b>(const Packet16b& a) { int x = _mm_cvtsi128_si32(a); return static_cast<bool>(x & 1); }
-template<> EIGEN_STRONG_INLINE Packet4f preverse(const Packet4f& a)
-{ return _mm_shuffle_ps(a,a,0x1B); }
-template<> EIGEN_STRONG_INLINE Packet2d preverse(const Packet2d& a)
-{ return _mm_shuffle_pd(a,a,0x1); }
-template<> EIGEN_STRONG_INLINE Packet4i preverse(const Packet4i& a)
-{ return _mm_shuffle_epi32(a,0x1B); }
+template<> EIGEN_STRONG_INLINE Packet4f preverse(const Packet4f& a) { return _mm_shuffle_ps(a,a,0x1B); }
+template<> EIGEN_STRONG_INLINE Packet2d preverse(const Packet2d& a) { return _mm_shuffle_pd(a,a,0x1); }
+template<> EIGEN_STRONG_INLINE Packet4i preverse(const Packet4i& a) { return _mm_shuffle_epi32(a,0x1B); }
+template<> EIGEN_STRONG_INLINE Packet16b preverse(const Packet16b& a) {
+#ifdef EIGEN_VECTORIZE_SSSE3
+ __m128i mask = _mm_set_epi8(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
+ return _mm_shuffle_epi8(a, mask);
+#else
+ Packet16b tmp = _mm_shuffle_epi32(a, _MM_SHUFFLE(0, 1, 2, 3));
+ tmp = _mm_shufflehi_epi16(_mm_shufflelo_epi16(tmp, _MM_SHUFFLE(2, 3, 0, 1)), _MM_SHUFFLE(2, 3, 0, 1));
+ return _mm_or_si128(_mm_slli_epi16(tmp, 8), _mm_srli_epi16(tmp, 8));
+#endif
+}
-template<> EIGEN_STRONG_INLINE Packet4f pabs(const Packet4f& a)
-{
- const Packet4f mask = _mm_castsi128_ps(_mm_setr_epi32(0x7FFFFFFF,0x7FFFFFFF,0x7FFFFFFF,0x7FFFFFFF));
- return _mm_and_ps(a,mask);
+template<> EIGEN_STRONG_INLINE Packet4f pfrexp<Packet4f>(const Packet4f& a, Packet4f& exponent) {
+ return pfrexp_generic(a,exponent);
}
-template<> EIGEN_STRONG_INLINE Packet2d pabs(const Packet2d& a)
-{
- const Packet2d mask = _mm_castsi128_pd(_mm_setr_epi32(0xFFFFFFFF,0x7FFFFFFF,0xFFFFFFFF,0x7FFFFFFF));
- return _mm_and_pd(a,mask);
+
+// Extract exponent without existence of Packet2l.
+template<>
+EIGEN_STRONG_INLINE
+Packet2d pfrexp_generic_get_biased_exponent(const Packet2d& a) {
+ const Packet2d cst_exp_mask = pset1frombits<Packet2d>(static_cast<uint64_t>(0x7ff0000000000000ull));
+ __m128i a_expo = _mm_srli_epi64(_mm_castpd_si128(pand(a, cst_exp_mask)), 52);
+ return _mm_cvtepi32_pd(vec4i_swizzle1(a_expo, 0, 2, 1, 3));
}
-template<> EIGEN_STRONG_INLINE Packet4i pabs(const Packet4i& a)
-{
- #ifdef EIGEN_VECTORIZE_SSSE3
- return _mm_abs_epi32(a);
- #else
- Packet4i aux = _mm_srai_epi32(a,31);
- return _mm_sub_epi32(_mm_xor_si128(a,aux),aux);
- #endif
+
+template<> EIGEN_STRONG_INLINE Packet2d pfrexp<Packet2d>(const Packet2d& a, Packet2d& exponent) {
+ return pfrexp_generic(a, exponent);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f pldexp<Packet4f>(const Packet4f& a, const Packet4f& exponent) {
+ return pldexp_generic(a,exponent);
+}
+
+// We specialize pldexp here, since the generic implementation uses Packet2l, which is not well
+// supported by SSE, and has more range than is needed for exponents.
+template<> EIGEN_STRONG_INLINE Packet2d pldexp<Packet2d>(const Packet2d& a, const Packet2d& exponent) {
+ // Clamp exponent to [-2099, 2099]
+ const Packet2d max_exponent = pset1<Packet2d>(2099.0);
+ const Packet2d e = pmin(pmax(exponent, pnegate(max_exponent)), max_exponent);
+
+ // Convert e to integer and swizzle to low-order bits.
+ const Packet4i ei = vec4i_swizzle1(_mm_cvtpd_epi32(e), 0, 3, 1, 3);
+
+ // Split 2^e into four factors and multiply:
+ const Packet4i bias = _mm_set_epi32(0, 1023, 0, 1023);
+ Packet4i b = parithmetic_shift_right<2>(ei); // floor(e/4)
+ Packet2d c = _mm_castsi128_pd(_mm_slli_epi64(padd(b, bias), 52)); // 2^b
+ Packet2d out = pmul(pmul(pmul(a, c), c), c); // a * 2^(3b)
+ b = psub(psub(psub(ei, b), b), b); // e - 3b
+ c = _mm_castsi128_pd(_mm_slli_epi64(padd(b, bias), 52)); // 2^(e - 3b)
+ out = pmul(out, c); // a * 2^e
+ return out;
}
// with AVX, the default implementations based on pload1 are faster
@@ -505,38 +984,6 @@
vecs[0] = _mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(vecs[0]), 0x00));
}
-#ifdef EIGEN_VECTORIZE_SSE3
-template<> EIGEN_STRONG_INLINE Packet4f preduxp<Packet4f>(const Packet4f* vecs)
-{
- return _mm_hadd_ps(_mm_hadd_ps(vecs[0], vecs[1]),_mm_hadd_ps(vecs[2], vecs[3]));
-}
-
-template<> EIGEN_STRONG_INLINE Packet2d preduxp<Packet2d>(const Packet2d* vecs)
-{
- return _mm_hadd_pd(vecs[0], vecs[1]);
-}
-
-#else
-template<> EIGEN_STRONG_INLINE Packet4f preduxp<Packet4f>(const Packet4f* vecs)
-{
- Packet4f tmp0, tmp1, tmp2;
- tmp0 = _mm_unpacklo_ps(vecs[0], vecs[1]);
- tmp1 = _mm_unpackhi_ps(vecs[0], vecs[1]);
- tmp2 = _mm_unpackhi_ps(vecs[2], vecs[3]);
- tmp0 = _mm_add_ps(tmp0, tmp1);
- tmp1 = _mm_unpacklo_ps(vecs[2], vecs[3]);
- tmp1 = _mm_add_ps(tmp1, tmp2);
- tmp2 = _mm_movehl_ps(tmp1, tmp0);
- tmp0 = _mm_movelh_ps(tmp0, tmp1);
- return _mm_add_ps(tmp0, tmp2);
-}
-
-template<> EIGEN_STRONG_INLINE Packet2d preduxp<Packet2d>(const Packet2d* vecs)
-{
- return _mm_add_pd(_mm_unpacklo_pd(vecs[0], vecs[1]), _mm_unpackhi_pd(vecs[0], vecs[1]));
-}
-#endif // SSE3
-
template<> EIGEN_STRONG_INLINE float predux<Packet4f>(const Packet4f& a)
{
// Disable SSE3 _mm_hadd_pd that is extremely slow on all existing Intel's architectures
@@ -562,38 +1009,28 @@
}
#ifdef EIGEN_VECTORIZE_SSSE3
-template<> EIGEN_STRONG_INLINE Packet4i preduxp<Packet4i>(const Packet4i* vecs)
-{
- return _mm_hadd_epi32(_mm_hadd_epi32(vecs[0], vecs[1]),_mm_hadd_epi32(vecs[2], vecs[3]));
-}
template<> EIGEN_STRONG_INLINE int predux<Packet4i>(const Packet4i& a)
{
Packet4i tmp0 = _mm_hadd_epi32(a,a);
return pfirst<Packet4i>(_mm_hadd_epi32(tmp0,tmp0));
}
+
#else
template<> EIGEN_STRONG_INLINE int predux<Packet4i>(const Packet4i& a)
{
Packet4i tmp = _mm_add_epi32(a, _mm_unpackhi_epi64(a,a));
return pfirst(tmp) + pfirst<Packet4i>(_mm_shuffle_epi32(tmp, 1));
}
-
-template<> EIGEN_STRONG_INLINE Packet4i preduxp<Packet4i>(const Packet4i* vecs)
-{
- Packet4i tmp0, tmp1, tmp2;
- tmp0 = _mm_unpacklo_epi32(vecs[0], vecs[1]);
- tmp1 = _mm_unpackhi_epi32(vecs[0], vecs[1]);
- tmp2 = _mm_unpackhi_epi32(vecs[2], vecs[3]);
- tmp0 = _mm_add_epi32(tmp0, tmp1);
- tmp1 = _mm_unpacklo_epi32(vecs[2], vecs[3]);
- tmp1 = _mm_add_epi32(tmp1, tmp2);
- tmp2 = _mm_unpacklo_epi64(tmp0, tmp1);
- tmp0 = _mm_unpackhi_epi64(tmp0, tmp1);
- return _mm_add_epi32(tmp0, tmp2);
-}
#endif
+
+template<> EIGEN_STRONG_INLINE bool predux<Packet16b>(const Packet16b& a) {
+ Packet4i tmp = _mm_or_si128(a, _mm_unpackhi_epi64(a,a));
+ return (pfirst(tmp) != 0) || (pfirst<Packet4i>(_mm_shuffle_epi32(tmp, 1)) != 0);
+}
+
// Other reduction functions:
+
// mul
template<> EIGEN_STRONG_INLINE float predux_mul<Packet4f>(const Packet4f& a)
{
@@ -611,7 +1048,13 @@
// TODO try to call _mm_mul_epu32 directly
EIGEN_ALIGN16 int aux[4];
pstore(aux, a);
- return (aux[0] * aux[1]) * (aux[2] * aux[3]);;
+ return (aux[0] * aux[1]) * (aux[2] * aux[3]);
+}
+
+template<> EIGEN_STRONG_INLINE bool predux_mul<Packet16b>(const Packet16b& a) {
+ Packet4i tmp = _mm_and_si128(a, _mm_unpackhi_epi64(a,a));
+ return ((pfirst<Packet4i>(tmp) == 0x01010101) &&
+ (pfirst<Packet4i>(_mm_shuffle_epi32(tmp, 1)) == 0x01010101));
}
// min
@@ -666,113 +1109,16 @@
#endif // EIGEN_VECTORIZE_SSE4_1
}
-#if EIGEN_COMP_GNUC
-// template <> EIGEN_STRONG_INLINE Packet4f pmadd(const Packet4f& a, const Packet4f& b, const Packet4f& c)
+// not needed yet
+// template<> EIGEN_STRONG_INLINE bool predux_all(const Packet4f& x)
// {
-// Packet4f res = b;
-// asm("mulps %[a], %[b] \n\taddps %[c], %[b]" : [b] "+x" (res) : [a] "x" (a), [c] "x" (c));
-// return res;
+// return _mm_movemask_ps(x) == 0xF;
// }
-// EIGEN_STRONG_INLINE Packet4i _mm_alignr_epi8(const Packet4i& a, const Packet4i& b, const int i)
-// {
-// Packet4i res = a;
-// asm("palignr %[i], %[a], %[b] " : [b] "+x" (res) : [a] "x" (a), [i] "i" (i));
-// return res;
-// }
-#endif
-#ifdef EIGEN_VECTORIZE_SSSE3
-// SSSE3 versions
-template<int Offset>
-struct palign_impl<Offset,Packet4f>
+template<> EIGEN_STRONG_INLINE bool predux_any(const Packet4f& x)
{
- static EIGEN_STRONG_INLINE void run(Packet4f& first, const Packet4f& second)
- {
- if (Offset!=0)
- first = _mm_castsi128_ps(_mm_alignr_epi8(_mm_castps_si128(second), _mm_castps_si128(first), Offset*4));
- }
-};
-
-template<int Offset>
-struct palign_impl<Offset,Packet4i>
-{
- static EIGEN_STRONG_INLINE void run(Packet4i& first, const Packet4i& second)
- {
- if (Offset!=0)
- first = _mm_alignr_epi8(second,first, Offset*4);
- }
-};
-
-template<int Offset>
-struct palign_impl<Offset,Packet2d>
-{
- static EIGEN_STRONG_INLINE void run(Packet2d& first, const Packet2d& second)
- {
- if (Offset==1)
- first = _mm_castsi128_pd(_mm_alignr_epi8(_mm_castpd_si128(second), _mm_castpd_si128(first), 8));
- }
-};
-#else
-// SSE2 versions
-template<int Offset>
-struct palign_impl<Offset,Packet4f>
-{
- static EIGEN_STRONG_INLINE void run(Packet4f& first, const Packet4f& second)
- {
- if (Offset==1)
- {
- first = _mm_move_ss(first,second);
- first = _mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(first),0x39));
- }
- else if (Offset==2)
- {
- first = _mm_movehl_ps(first,first);
- first = _mm_movelh_ps(first,second);
- }
- else if (Offset==3)
- {
- first = _mm_move_ss(first,second);
- first = _mm_shuffle_ps(first,second,0x93);
- }
- }
-};
-
-template<int Offset>
-struct palign_impl<Offset,Packet4i>
-{
- static EIGEN_STRONG_INLINE void run(Packet4i& first, const Packet4i& second)
- {
- if (Offset==1)
- {
- first = _mm_castps_si128(_mm_move_ss(_mm_castsi128_ps(first),_mm_castsi128_ps(second)));
- first = _mm_shuffle_epi32(first,0x39);
- }
- else if (Offset==2)
- {
- first = _mm_castps_si128(_mm_movehl_ps(_mm_castsi128_ps(first),_mm_castsi128_ps(first)));
- first = _mm_castps_si128(_mm_movelh_ps(_mm_castsi128_ps(first),_mm_castsi128_ps(second)));
- }
- else if (Offset==3)
- {
- first = _mm_castps_si128(_mm_move_ss(_mm_castsi128_ps(first),_mm_castsi128_ps(second)));
- first = _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(first),_mm_castsi128_ps(second),0x93));
- }
- }
-};
-
-template<int Offset>
-struct palign_impl<Offset,Packet2d>
-{
- static EIGEN_STRONG_INLINE void run(Packet2d& first, const Packet2d& second)
- {
- if (Offset==1)
- {
- first = _mm_castps_pd(_mm_movehl_ps(_mm_castpd_ps(first),_mm_castpd_ps(first)));
- first = _mm_castps_pd(_mm_movelh_ps(_mm_castpd_ps(first),_mm_castpd_ps(second)));
- }
- }
-};
-#endif
+ return _mm_movemask_ps(x) != 0x0;
+}
EIGEN_DEVICE_FUNC inline void
ptranspose(PacketBlock<Packet4f,4>& kernel) {
@@ -799,6 +1145,100 @@
kernel.packet[3] = _mm_unpackhi_epi64(T2, T3);
}
+EIGEN_DEVICE_FUNC inline void
+ptranspose(PacketBlock<Packet16b,4>& kernel) {
+ __m128i T0 = _mm_unpacklo_epi8(kernel.packet[0], kernel.packet[1]);
+ __m128i T1 = _mm_unpackhi_epi8(kernel.packet[0], kernel.packet[1]);
+ __m128i T2 = _mm_unpacklo_epi8(kernel.packet[2], kernel.packet[3]);
+ __m128i T3 = _mm_unpackhi_epi8(kernel.packet[2], kernel.packet[3]);
+ kernel.packet[0] = _mm_unpacklo_epi16(T0, T2);
+ kernel.packet[1] = _mm_unpackhi_epi16(T0, T2);
+ kernel.packet[2] = _mm_unpacklo_epi16(T1, T3);
+ kernel.packet[3] = _mm_unpackhi_epi16(T1, T3);
+}
+
+EIGEN_DEVICE_FUNC inline void
+ptranspose(PacketBlock<Packet16b,16>& kernel) {
+ // If we number the elements in the input thus:
+ // kernel.packet[ 0] = {00, 01, 02, 03, 04, 05, 06, 07, 08, 09, 0a, 0b, 0c, 0d, 0e, 0f}
+ // kernel.packet[ 1] = {10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 1a, 1b, 1c, 1d, 1e, 1f}
+ // ...
+ // kernel.packet[15] = {f0, f1, f2, f3, f4, f5, f6, f7, f8, f9, fa, fb, fc, fd, fe, ff},
+ //
+ // the desired output is:
+ // kernel.packet[ 0] = {00, 10, 20, 30, 40, 50, 60, 70, 80, 90, a0, b0, c0, d0, e0, f0}
+ // kernel.packet[ 1] = {01, 11, 21, 31, 41, 51, 61, 71, 81, 91, a1, b1, c1, d1, e1, f1}
+ // ...
+ // kernel.packet[15] = {0f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f, af, bf, cf, df, ef, ff},
+ __m128i t0 = _mm_unpacklo_epi8(kernel.packet[0], kernel.packet[1]); // 00 10 01 11 02 12 03 13 04 14 05 15 06 16 07 17
+ __m128i t1 = _mm_unpackhi_epi8(kernel.packet[0], kernel.packet[1]); // 08 18 09 19 0a 1a 0b 1b 0c 1c 0d 1d 0e 1e 0f 1f
+ __m128i t2 = _mm_unpacklo_epi8(kernel.packet[2], kernel.packet[3]); // 20 30 21 31 22 32 ... 27 37
+ __m128i t3 = _mm_unpackhi_epi8(kernel.packet[2], kernel.packet[3]); // 28 38 29 39 2a 3a ... 2f 3f
+ __m128i t4 = _mm_unpacklo_epi8(kernel.packet[4], kernel.packet[5]); // 40 50 41 51 42 52 47 57
+ __m128i t5 = _mm_unpackhi_epi8(kernel.packet[4], kernel.packet[5]); // 48 58 49 59 4a 5a
+ __m128i t6 = _mm_unpacklo_epi8(kernel.packet[6], kernel.packet[7]);
+ __m128i t7 = _mm_unpackhi_epi8(kernel.packet[6], kernel.packet[7]);
+ __m128i t8 = _mm_unpacklo_epi8(kernel.packet[8], kernel.packet[9]);
+ __m128i t9 = _mm_unpackhi_epi8(kernel.packet[8], kernel.packet[9]);
+ __m128i ta = _mm_unpacklo_epi8(kernel.packet[10], kernel.packet[11]);
+ __m128i tb = _mm_unpackhi_epi8(kernel.packet[10], kernel.packet[11]);
+ __m128i tc = _mm_unpacklo_epi8(kernel.packet[12], kernel.packet[13]);
+ __m128i td = _mm_unpackhi_epi8(kernel.packet[12], kernel.packet[13]);
+ __m128i te = _mm_unpacklo_epi8(kernel.packet[14], kernel.packet[15]);
+ __m128i tf = _mm_unpackhi_epi8(kernel.packet[14], kernel.packet[15]);
+
+ __m128i s0 = _mm_unpacklo_epi16(t0, t2); // 00 10 20 30 01 11 21 31 02 12 22 32 03 13 23 33
+ __m128i s1 = _mm_unpackhi_epi16(t0, t2); // 04 14 24 34
+ __m128i s2 = _mm_unpacklo_epi16(t1, t3); // 08 18 28 38 ...
+ __m128i s3 = _mm_unpackhi_epi16(t1, t3); // 0c 1c 2c 3c ...
+ __m128i s4 = _mm_unpacklo_epi16(t4, t6); // 40 50 60 70 41 51 61 71 42 52 62 72 43 53 63 73
+ __m128i s5 = _mm_unpackhi_epi16(t4, t6); // 44 54 64 74 ...
+ __m128i s6 = _mm_unpacklo_epi16(t5, t7);
+ __m128i s7 = _mm_unpackhi_epi16(t5, t7);
+ __m128i s8 = _mm_unpacklo_epi16(t8, ta);
+ __m128i s9 = _mm_unpackhi_epi16(t8, ta);
+ __m128i sa = _mm_unpacklo_epi16(t9, tb);
+ __m128i sb = _mm_unpackhi_epi16(t9, tb);
+ __m128i sc = _mm_unpacklo_epi16(tc, te);
+ __m128i sd = _mm_unpackhi_epi16(tc, te);
+ __m128i se = _mm_unpacklo_epi16(td, tf);
+ __m128i sf = _mm_unpackhi_epi16(td, tf);
+
+ __m128i u0 = _mm_unpacklo_epi32(s0, s4); // 00 10 20 30 40 50 60 70 01 11 21 31 41 51 61 71
+ __m128i u1 = _mm_unpackhi_epi32(s0, s4); // 02 12 22 32 42 52 62 72 03 13 23 33 43 53 63 73
+ __m128i u2 = _mm_unpacklo_epi32(s1, s5);
+ __m128i u3 = _mm_unpackhi_epi32(s1, s5);
+ __m128i u4 = _mm_unpacklo_epi32(s2, s6);
+ __m128i u5 = _mm_unpackhi_epi32(s2, s6);
+ __m128i u6 = _mm_unpacklo_epi32(s3, s7);
+ __m128i u7 = _mm_unpackhi_epi32(s3, s7);
+ __m128i u8 = _mm_unpacklo_epi32(s8, sc);
+ __m128i u9 = _mm_unpackhi_epi32(s8, sc);
+ __m128i ua = _mm_unpacklo_epi32(s9, sd);
+ __m128i ub = _mm_unpackhi_epi32(s9, sd);
+ __m128i uc = _mm_unpacklo_epi32(sa, se);
+ __m128i ud = _mm_unpackhi_epi32(sa, se);
+ __m128i ue = _mm_unpacklo_epi32(sb, sf);
+ __m128i uf = _mm_unpackhi_epi32(sb, sf);
+
+ kernel.packet[0] = _mm_unpacklo_epi64(u0, u8);
+ kernel.packet[1] = _mm_unpackhi_epi64(u0, u8);
+ kernel.packet[2] = _mm_unpacklo_epi64(u1, u9);
+ kernel.packet[3] = _mm_unpackhi_epi64(u1, u9);
+ kernel.packet[4] = _mm_unpacklo_epi64(u2, ua);
+ kernel.packet[5] = _mm_unpackhi_epi64(u2, ua);
+ kernel.packet[6] = _mm_unpacklo_epi64(u3, ub);
+ kernel.packet[7] = _mm_unpackhi_epi64(u3, ub);
+ kernel.packet[8] = _mm_unpacklo_epi64(u4, uc);
+ kernel.packet[9] = _mm_unpackhi_epi64(u4, uc);
+ kernel.packet[10] = _mm_unpacklo_epi64(u5, ud);
+ kernel.packet[11] = _mm_unpackhi_epi64(u5, ud);
+ kernel.packet[12] = _mm_unpacklo_epi64(u6, ue);
+ kernel.packet[13] = _mm_unpackhi_epi64(u6, ue);
+ kernel.packet[14] = _mm_unpacklo_epi64(u7, uf);
+ kernel.packet[15] = _mm_unpackhi_epi64(u7, uf);
+}
+
template<> EIGEN_STRONG_INLINE Packet4i pblend(const Selector<4>& ifPacket, const Packet4i& thenPacket, const Packet4i& elsePacket) {
const __m128i zero = _mm_setzero_si128();
const __m128i select = _mm_set_epi32(ifPacket.select[3], ifPacket.select[2], ifPacket.select[1], ifPacket.select[0]);
@@ -830,46 +1270,8 @@
#endif
}
-template<> EIGEN_STRONG_INLINE Packet4f pinsertfirst(const Packet4f& a, float b)
-{
-#ifdef EIGEN_VECTORIZE_SSE4_1
- return _mm_blend_ps(a,pset1<Packet4f>(b),1);
-#else
- return _mm_move_ss(a, _mm_load_ss(&b));
-#endif
-}
-
-template<> EIGEN_STRONG_INLINE Packet2d pinsertfirst(const Packet2d& a, double b)
-{
-#ifdef EIGEN_VECTORIZE_SSE4_1
- return _mm_blend_pd(a,pset1<Packet2d>(b),1);
-#else
- return _mm_move_sd(a, _mm_load_sd(&b));
-#endif
-}
-
-template<> EIGEN_STRONG_INLINE Packet4f pinsertlast(const Packet4f& a, float b)
-{
-#ifdef EIGEN_VECTORIZE_SSE4_1
- return _mm_blend_ps(a,pset1<Packet4f>(b),(1<<3));
-#else
- const Packet4f mask = _mm_castsi128_ps(_mm_setr_epi32(0x0,0x0,0x0,0xFFFFFFFF));
- return _mm_or_ps(_mm_andnot_ps(mask, a), _mm_and_ps(mask, pset1<Packet4f>(b)));
-#endif
-}
-
-template<> EIGEN_STRONG_INLINE Packet2d pinsertlast(const Packet2d& a, double b)
-{
-#ifdef EIGEN_VECTORIZE_SSE4_1
- return _mm_blend_pd(a,pset1<Packet2d>(b),(1<<1));
-#else
- const Packet2d mask = _mm_castsi128_pd(_mm_setr_epi32(0x0,0x0,0xFFFFFFFF,0xFFFFFFFF));
- return _mm_or_pd(_mm_andnot_pd(mask, a), _mm_and_pd(mask, pset1<Packet2d>(b)));
-#endif
-}
-
// Scalar path for pmadd with FMA to ensure consistency with vectorized path.
-#ifdef __FMA__
+#ifdef EIGEN_VECTORIZE_FMA
template<> EIGEN_STRONG_INLINE float pmadd(const float& a, const float& b, const float& c) {
return ::fmaf(a,b,c);
}
@@ -878,11 +1280,219 @@
}
#endif
+
+// Packet math for Eigen::half
+// Disable the following code since it's broken on too many platforms / compilers.
+//#elif defined(EIGEN_VECTORIZE_SSE) && (!EIGEN_ARCH_x86_64) && (!EIGEN_COMP_MSVC)
+#if 0
+
+typedef struct {
+ __m64 x;
+} Packet4h;
+
+
+template<> struct is_arithmetic<Packet4h> { enum { value = true }; };
+
+template <>
+struct packet_traits<Eigen::half> : default_packet_traits {
+ typedef Packet4h type;
+ // There is no half-size packet for Packet4h.
+ typedef Packet4h half;
+ enum {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = 4,
+ HasHalfPacket = 0,
+ HasAdd = 1,
+ HasSub = 1,
+ HasMul = 1,
+ HasDiv = 1,
+ HasNegate = 0,
+ HasAbs = 0,
+ HasAbs2 = 0,
+ HasMin = 0,
+ HasMax = 0,
+ HasConj = 0,
+ HasSetLinear = 0,
+ HasSqrt = 0,
+ HasRsqrt = 0,
+ HasExp = 0,
+ HasLog = 0,
+ HasBlend = 0
+ };
+};
+
+
+template<> struct unpacket_traits<Packet4h> { typedef Eigen::half type; enum {size=4, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; typedef Packet4h half; };
+
+template<> EIGEN_STRONG_INLINE Packet4h pset1<Packet4h>(const Eigen::half& from) {
+ Packet4h result;
+ result.x = _mm_set1_pi16(from.x);
+ return result;
+}
+
+template<> EIGEN_STRONG_INLINE Eigen::half pfirst<Packet4h>(const Packet4h& from) {
+ return half_impl::raw_uint16_to_half(static_cast<unsigned short>(_mm_cvtsi64_si32(from.x)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet4h pconj(const Packet4h& a) { return a; }
+
+template<> EIGEN_STRONG_INLINE Packet4h padd<Packet4h>(const Packet4h& a, const Packet4h& b) {
+ __int64_t a64 = _mm_cvtm64_si64(a.x);
+ __int64_t b64 = _mm_cvtm64_si64(b.x);
+
+ Eigen::half h[4];
+
+ Eigen::half ha = half_impl::raw_uint16_to_half(static_cast<unsigned short>(a64));
+ Eigen::half hb = half_impl::raw_uint16_to_half(static_cast<unsigned short>(b64));
+ h[0] = ha + hb;
+ ha = half_impl::raw_uint16_to_half(static_cast<unsigned short>(a64 >> 16));
+ hb = half_impl::raw_uint16_to_half(static_cast<unsigned short>(b64 >> 16));
+ h[1] = ha + hb;
+ ha = half_impl::raw_uint16_to_half(static_cast<unsigned short>(a64 >> 32));
+ hb = half_impl::raw_uint16_to_half(static_cast<unsigned short>(b64 >> 32));
+ h[2] = ha + hb;
+ ha = half_impl::raw_uint16_to_half(static_cast<unsigned short>(a64 >> 48));
+ hb = half_impl::raw_uint16_to_half(static_cast<unsigned short>(b64 >> 48));
+ h[3] = ha + hb;
+ Packet4h result;
+ result.x = _mm_set_pi16(h[3].x, h[2].x, h[1].x, h[0].x);
+ return result;
+}
+
+template<> EIGEN_STRONG_INLINE Packet4h psub<Packet4h>(const Packet4h& a, const Packet4h& b) {
+ __int64_t a64 = _mm_cvtm64_si64(a.x);
+ __int64_t b64 = _mm_cvtm64_si64(b.x);
+
+ Eigen::half h[4];
+
+ Eigen::half ha = half_impl::raw_uint16_to_half(static_cast<unsigned short>(a64));
+ Eigen::half hb = half_impl::raw_uint16_to_half(static_cast<unsigned short>(b64));
+ h[0] = ha - hb;
+ ha = half_impl::raw_uint16_to_half(static_cast<unsigned short>(a64 >> 16));
+ hb = half_impl::raw_uint16_to_half(static_cast<unsigned short>(b64 >> 16));
+ h[1] = ha - hb;
+ ha = half_impl::raw_uint16_to_half(static_cast<unsigned short>(a64 >> 32));
+ hb = half_impl::raw_uint16_to_half(static_cast<unsigned short>(b64 >> 32));
+ h[2] = ha - hb;
+ ha = half_impl::raw_uint16_to_half(static_cast<unsigned short>(a64 >> 48));
+ hb = half_impl::raw_uint16_to_half(static_cast<unsigned short>(b64 >> 48));
+ h[3] = ha - hb;
+ Packet4h result;
+ result.x = _mm_set_pi16(h[3].x, h[2].x, h[1].x, h[0].x);
+ return result;
+}
+
+template<> EIGEN_STRONG_INLINE Packet4h pmul<Packet4h>(const Packet4h& a, const Packet4h& b) {
+ __int64_t a64 = _mm_cvtm64_si64(a.x);
+ __int64_t b64 = _mm_cvtm64_si64(b.x);
+
+ Eigen::half h[4];
+
+ Eigen::half ha = half_impl::raw_uint16_to_half(static_cast<unsigned short>(a64));
+ Eigen::half hb = half_impl::raw_uint16_to_half(static_cast<unsigned short>(b64));
+ h[0] = ha * hb;
+ ha = half_impl::raw_uint16_to_half(static_cast<unsigned short>(a64 >> 16));
+ hb = half_impl::raw_uint16_to_half(static_cast<unsigned short>(b64 >> 16));
+ h[1] = ha * hb;
+ ha = half_impl::raw_uint16_to_half(static_cast<unsigned short>(a64 >> 32));
+ hb = half_impl::raw_uint16_to_half(static_cast<unsigned short>(b64 >> 32));
+ h[2] = ha * hb;
+ ha = half_impl::raw_uint16_to_half(static_cast<unsigned short>(a64 >> 48));
+ hb = half_impl::raw_uint16_to_half(static_cast<unsigned short>(b64 >> 48));
+ h[3] = ha * hb;
+ Packet4h result;
+ result.x = _mm_set_pi16(h[3].x, h[2].x, h[1].x, h[0].x);
+ return result;
+}
+
+template<> EIGEN_STRONG_INLINE Packet4h pdiv<Packet4h>(const Packet4h& a, const Packet4h& b) {
+ __int64_t a64 = _mm_cvtm64_si64(a.x);
+ __int64_t b64 = _mm_cvtm64_si64(b.x);
+
+ Eigen::half h[4];
+
+ Eigen::half ha = half_impl::raw_uint16_to_half(static_cast<unsigned short>(a64));
+ Eigen::half hb = half_impl::raw_uint16_to_half(static_cast<unsigned short>(b64));
+ h[0] = ha / hb;
+ ha = half_impl::raw_uint16_to_half(static_cast<unsigned short>(a64 >> 16));
+ hb = half_impl::raw_uint16_to_half(static_cast<unsigned short>(b64 >> 16));
+ h[1] = ha / hb;
+ ha = half_impl::raw_uint16_to_half(static_cast<unsigned short>(a64 >> 32));
+ hb = half_impl::raw_uint16_to_half(static_cast<unsigned short>(b64 >> 32));
+ h[2] = ha / hb;
+ ha = half_impl::raw_uint16_to_half(static_cast<unsigned short>(a64 >> 48));
+ hb = half_impl::raw_uint16_to_half(static_cast<unsigned short>(b64 >> 48));
+ h[3] = ha / hb;
+ Packet4h result;
+ result.x = _mm_set_pi16(h[3].x, h[2].x, h[1].x, h[0].x);
+ return result;
+}
+
+template<> EIGEN_STRONG_INLINE Packet4h pload<Packet4h>(const Eigen::half* from) {
+ Packet4h result;
+ result.x = _mm_cvtsi64_m64(*reinterpret_cast<const __int64_t*>(from));
+ return result;
+}
+
+template<> EIGEN_STRONG_INLINE Packet4h ploadu<Packet4h>(const Eigen::half* from) {
+ Packet4h result;
+ result.x = _mm_cvtsi64_m64(*reinterpret_cast<const __int64_t*>(from));
+ return result;
+}
+
+template<> EIGEN_STRONG_INLINE void pstore<Eigen::half>(Eigen::half* to, const Packet4h& from) {
+ __int64_t r = _mm_cvtm64_si64(from.x);
+ *(reinterpret_cast<__int64_t*>(to)) = r;
+}
+
+template<> EIGEN_STRONG_INLINE void pstoreu<Eigen::half>(Eigen::half* to, const Packet4h& from) {
+ __int64_t r = _mm_cvtm64_si64(from.x);
+ *(reinterpret_cast<__int64_t*>(to)) = r;
+}
+
+template<> EIGEN_STRONG_INLINE Packet4h
+ploadquad<Packet4h>(const Eigen::half* from) {
+ return pset1<Packet4h>(*from);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4h pgather<Eigen::half, Packet4h>(const Eigen::half* from, Index stride)
+{
+ Packet4h result;
+ result.x = _mm_set_pi16(from[3*stride].x, from[2*stride].x, from[1*stride].x, from[0*stride].x);
+ return result;
+}
+
+template<> EIGEN_STRONG_INLINE void pscatter<Eigen::half, Packet4h>(Eigen::half* to, const Packet4h& from, Index stride)
+{
+ __int64_t a = _mm_cvtm64_si64(from.x);
+ to[stride*0].x = static_cast<unsigned short>(a);
+ to[stride*1].x = static_cast<unsigned short>(a >> 16);
+ to[stride*2].x = static_cast<unsigned short>(a >> 32);
+ to[stride*3].x = static_cast<unsigned short>(a >> 48);
+}
+
+EIGEN_STRONG_INLINE void
+ptranspose(PacketBlock<Packet4h,4>& kernel) {
+ __m64 T0 = _mm_unpacklo_pi16(kernel.packet[0].x, kernel.packet[1].x);
+ __m64 T1 = _mm_unpacklo_pi16(kernel.packet[2].x, kernel.packet[3].x);
+ __m64 T2 = _mm_unpackhi_pi16(kernel.packet[0].x, kernel.packet[1].x);
+ __m64 T3 = _mm_unpackhi_pi16(kernel.packet[2].x, kernel.packet[3].x);
+
+ kernel.packet[0].x = _mm_unpacklo_pi32(T0, T1);
+ kernel.packet[1].x = _mm_unpackhi_pi32(T0, T1);
+ kernel.packet[2].x = _mm_unpacklo_pi32(T2, T3);
+ kernel.packet[3].x = _mm_unpackhi_pi32(T2, T3);
+}
+
+#endif
+
+
} // end namespace internal
} // end namespace Eigen
-#if EIGEN_COMP_PGI
+#if EIGEN_COMP_PGI && EIGEN_COMP_PGI < 1900
// PGI++ does not define the following intrinsics in C++ mode.
static inline __m128 _mm_castpd_ps (__m128d x) { return reinterpret_cast<__m128&>(x); }
static inline __m128i _mm_castpd_si128(__m128d x) { return reinterpret_cast<__m128i&>(x); }
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/SSE/TypeCasting.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/SSE/TypeCasting.h
index c6ca8c7..d2a0037 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/SSE/TypeCasting.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/arch/SSE/TypeCasting.h
@@ -69,6 +69,71 @@
return _mm_cvtps_pd(a);
}
+template<> EIGEN_STRONG_INLINE Packet4i preinterpret<Packet4i,Packet4f>(const Packet4f& a) {
+ return _mm_castps_si128(a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f preinterpret<Packet4f,Packet4i>(const Packet4i& a) {
+ return _mm_castsi128_ps(a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet2d preinterpret<Packet2d,Packet4i>(const Packet4i& a) {
+ return _mm_castsi128_pd(a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4i preinterpret<Packet4i,Packet2d>(const Packet2d& a) {
+ return _mm_castpd_si128(a);
+}
+
+// Disable the following code since it's broken on too many platforms / compilers.
+//#elif defined(EIGEN_VECTORIZE_SSE) && (!EIGEN_ARCH_x86_64) && (!EIGEN_COMP_MSVC)
+#if 0
+
+template <>
+struct type_casting_traits<Eigen::half, float> {
+ enum {
+ VectorizedCast = 1,
+ SrcCoeffRatio = 1,
+ TgtCoeffRatio = 1
+ };
+};
+
+template<> EIGEN_STRONG_INLINE Packet4f pcast<Packet4h, Packet4f>(const Packet4h& a) {
+ __int64_t a64 = _mm_cvtm64_si64(a.x);
+ Eigen::half h = raw_uint16_to_half(static_cast<unsigned short>(a64));
+ float f1 = static_cast<float>(h);
+ h = raw_uint16_to_half(static_cast<unsigned short>(a64 >> 16));
+ float f2 = static_cast<float>(h);
+ h = raw_uint16_to_half(static_cast<unsigned short>(a64 >> 32));
+ float f3 = static_cast<float>(h);
+ h = raw_uint16_to_half(static_cast<unsigned short>(a64 >> 48));
+ float f4 = static_cast<float>(h);
+ return _mm_set_ps(f4, f3, f2, f1);
+}
+
+template <>
+struct type_casting_traits<float, Eigen::half> {
+ enum {
+ VectorizedCast = 1,
+ SrcCoeffRatio = 1,
+ TgtCoeffRatio = 1
+ };
+};
+
+template<> EIGEN_STRONG_INLINE Packet4h pcast<Packet4f, Packet4h>(const Packet4f& a) {
+ EIGEN_ALIGN16 float aux[4];
+ pstore(aux, a);
+ Eigen::half h0(aux[0]);
+ Eigen::half h1(aux[1]);
+ Eigen::half h2(aux[2]);
+ Eigen::half h3(aux[3]);
+
+ Packet4h result;
+ result.x = _mm_set_pi16(h3.x, h2.x, h1.x, h0.x);
+ return result;
+}
+
+#endif
} // end namespace internal
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/functors/AssignmentFunctors.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/functors/AssignmentFunctors.h
index 4153b87..bf64ef4 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/functors/AssignmentFunctors.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/functors/AssignmentFunctors.h
@@ -144,7 +144,7 @@
EIGEN_EMPTY_STRUCT_CTOR(swap_assign_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void assignCoeff(Scalar& a, const Scalar& b) const
{
-#ifdef __CUDACC__
+#ifdef EIGEN_GPUCC
// FIXME is there some kind of cuda::swap?
Scalar t=b; const_cast<Scalar&>(b)=a; a=t;
#else
@@ -157,7 +157,16 @@
struct functor_traits<swap_assign_op<Scalar> > {
enum {
Cost = 3 * NumTraits<Scalar>::ReadCost,
- PacketAccess = packet_traits<Scalar>::Vectorizable
+ PacketAccess =
+ #if defined(EIGEN_VECTORIZE_AVX) && EIGEN_COMP_CLANG && (EIGEN_COMP_CLANG<800 || defined(__apple_build_version__))
+ // This is a partial workaround for a bug in clang generating bad code
+ // when mixing 256/512 bits loads and 128 bits moves.
+ // See http://eigen.tuxfamily.org/bz/show_bug.cgi?id=1684
+ // https://bugs.llvm.org/show_bug.cgi?id=40815
+ 0
+ #else
+ packet_traits<Scalar>::Vectorizable
+ #endif
};
};
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/functors/BinaryFunctors.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/functors/BinaryFunctors.h
index 3eae6b8..63f09ab 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/functors/BinaryFunctors.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/functors/BinaryFunctors.h
@@ -39,32 +39,26 @@
EIGEN_SCALAR_BINARY_OP_PLUGIN
}
#endif
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const result_type operator() (const LhsScalar& a, const RhsScalar& b) const { return a + b; }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator() (const LhsScalar& a, const RhsScalar& b) const { return a + b; }
template<typename Packet>
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a, const Packet& b) const
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, const Packet& b) const
{ return internal::padd(a,b); }
template<typename Packet>
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const result_type predux(const Packet& a) const
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type predux(const Packet& a) const
{ return internal::predux(a); }
};
template<typename LhsScalar,typename RhsScalar>
struct functor_traits<scalar_sum_op<LhsScalar,RhsScalar> > {
enum {
- Cost = (NumTraits<LhsScalar>::AddCost+NumTraits<RhsScalar>::AddCost)/2, // rough estimate!
+ Cost = (int(NumTraits<LhsScalar>::AddCost) + int(NumTraits<RhsScalar>::AddCost)) / 2, // rough estimate!
PacketAccess = is_same<LhsScalar,RhsScalar>::value && packet_traits<LhsScalar>::HasAdd && packet_traits<RhsScalar>::HasAdd
// TODO vectorize mixed sum
};
};
-/** \internal
- * \brief Template specialization to deprecate the summation of boolean expressions.
- * This is required to solve Bug 426.
- * \sa DenseBase::count(), DenseBase::any(), ArrayBase::cast(), MatrixBase::cast()
- */
-template<> struct scalar_sum_op<bool,bool> : scalar_sum_op<int,int> {
- EIGEN_DEPRECATED
- scalar_sum_op() {}
-};
+
+template<>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool scalar_sum_op<bool,bool>::operator() (const bool& a, const bool& b) const { return a || b; }
/** \internal
@@ -83,23 +77,27 @@
EIGEN_SCALAR_BINARY_OP_PLUGIN
}
#endif
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const result_type operator() (const LhsScalar& a, const RhsScalar& b) const { return a * b; }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator() (const LhsScalar& a, const RhsScalar& b) const { return a * b; }
template<typename Packet>
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a, const Packet& b) const
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, const Packet& b) const
{ return internal::pmul(a,b); }
template<typename Packet>
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const result_type predux(const Packet& a) const
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type predux(const Packet& a) const
{ return internal::predux_mul(a); }
};
template<typename LhsScalar,typename RhsScalar>
struct functor_traits<scalar_product_op<LhsScalar,RhsScalar> > {
enum {
- Cost = (NumTraits<LhsScalar>::MulCost + NumTraits<RhsScalar>::MulCost)/2, // rough estimate!
+ Cost = (int(NumTraits<LhsScalar>::MulCost) + int(NumTraits<RhsScalar>::MulCost))/2, // rough estimate!
PacketAccess = is_same<LhsScalar,RhsScalar>::value && packet_traits<LhsScalar>::HasMul && packet_traits<RhsScalar>::HasMul
// TODO vectorize mixed product
};
};
+template<>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool scalar_product_op<bool,bool>::operator() (const bool& a, const bool& b) const { return a && b; }
+
+
/** \internal
* \brief Template functor to compute the conjugate product of two scalars
*
@@ -116,11 +114,11 @@
typedef typename ScalarBinaryOpTraits<LhsScalar,RhsScalar,scalar_conj_product_op>::ReturnType result_type;
EIGEN_EMPTY_STRUCT_CTOR(scalar_conj_product_op)
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const result_type operator() (const LhsScalar& a, const RhsScalar& b) const
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator() (const LhsScalar& a, const RhsScalar& b) const
{ return conj_helper<LhsScalar,RhsScalar,Conj,false>().pmul(a,b); }
template<typename Packet>
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a, const Packet& b) const
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, const Packet& b) const
{ return conj_helper<Packet,Packet,Conj,false>().pmul(a,b); }
};
template<typename LhsScalar,typename RhsScalar>
@@ -136,21 +134,28 @@
*
* \sa class CwiseBinaryOp, MatrixBase::cwiseMin, class VectorwiseOp, MatrixBase::minCoeff()
*/
-template<typename LhsScalar,typename RhsScalar>
+template<typename LhsScalar,typename RhsScalar, int NaNPropagation>
struct scalar_min_op : binary_op_base<LhsScalar,RhsScalar>
{
typedef typename ScalarBinaryOpTraits<LhsScalar,RhsScalar,scalar_min_op>::ReturnType result_type;
EIGEN_EMPTY_STRUCT_CTOR(scalar_min_op)
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const result_type operator() (const LhsScalar& a, const RhsScalar& b) const { return numext::mini(a, b); }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator() (const LhsScalar& a, const RhsScalar& b) const {
+ return internal::pmin<NaNPropagation>(a, b);
+ }
template<typename Packet>
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a, const Packet& b) const
- { return internal::pmin(a,b); }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, const Packet& b) const
+ {
+ return internal::pmin<NaNPropagation>(a,b);
+ }
template<typename Packet>
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const result_type predux(const Packet& a) const
- { return internal::predux_min(a); }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type predux(const Packet& a) const
+ {
+ return internal::predux_min<NaNPropagation>(a);
+ }
};
-template<typename LhsScalar,typename RhsScalar>
-struct functor_traits<scalar_min_op<LhsScalar,RhsScalar> > {
+
+template<typename LhsScalar,typename RhsScalar, int NaNPropagation>
+struct functor_traits<scalar_min_op<LhsScalar,RhsScalar, NaNPropagation> > {
enum {
Cost = (NumTraits<LhsScalar>::AddCost+NumTraits<RhsScalar>::AddCost)/2,
PacketAccess = internal::is_same<LhsScalar, RhsScalar>::value && packet_traits<LhsScalar>::HasMin
@@ -162,21 +167,28 @@
*
* \sa class CwiseBinaryOp, MatrixBase::cwiseMax, class VectorwiseOp, MatrixBase::maxCoeff()
*/
-template<typename LhsScalar,typename RhsScalar>
-struct scalar_max_op : binary_op_base<LhsScalar,RhsScalar>
+template<typename LhsScalar,typename RhsScalar, int NaNPropagation>
+struct scalar_max_op : binary_op_base<LhsScalar,RhsScalar>
{
typedef typename ScalarBinaryOpTraits<LhsScalar,RhsScalar,scalar_max_op>::ReturnType result_type;
EIGEN_EMPTY_STRUCT_CTOR(scalar_max_op)
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const result_type operator() (const LhsScalar& a, const RhsScalar& b) const { return numext::maxi(a, b); }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator() (const LhsScalar& a, const RhsScalar& b) const {
+ return internal::pmax<NaNPropagation>(a,b);
+ }
template<typename Packet>
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a, const Packet& b) const
- { return internal::pmax(a,b); }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, const Packet& b) const
+ {
+ return internal::pmax<NaNPropagation>(a,b);
+ }
template<typename Packet>
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const result_type predux(const Packet& a) const
- { return internal::predux_max(a); }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type predux(const Packet& a) const
+ {
+ return internal::predux_max<NaNPropagation>(a);
+ }
};
-template<typename LhsScalar,typename RhsScalar>
-struct functor_traits<scalar_max_op<LhsScalar,RhsScalar> > {
+
+template<typename LhsScalar,typename RhsScalar, int NaNPropagation>
+struct functor_traits<scalar_max_op<LhsScalar,RhsScalar, NaNPropagation> > {
enum {
Cost = (NumTraits<LhsScalar>::AddCost+NumTraits<RhsScalar>::AddCost)/2,
PacketAccess = internal::is_same<LhsScalar, RhsScalar>::value && packet_traits<LhsScalar>::HasMax
@@ -253,7 +265,6 @@
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const LhsScalar& a, const RhsScalar& b) const {return a!=b;}
};
-
/** \internal
* \brief Template functor to compute the hypot of two \b positive \b and \b real scalars
*
@@ -287,6 +298,7 @@
/** \internal
* \brief Template functor to compute the pow of two scalars
+ * See the specification of pow in https://en.cppreference.com/w/cpp/numeric/math/pow
*/
template<typename Scalar, typename Exponent>
struct scalar_pow_op : binary_op_base<Scalar,Exponent>
@@ -301,16 +313,31 @@
EIGEN_SCALAR_BINARY_OP_PLUGIN
}
#endif
+
EIGEN_DEVICE_FUNC
inline result_type operator() (const Scalar& a, const Exponent& b) const { return numext::pow(a, b); }
+
+ template<typename Packet>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a, const Packet& b) const
+ {
+ return generic_pow(a,b);
+ }
};
+
template<typename Scalar, typename Exponent>
struct functor_traits<scalar_pow_op<Scalar,Exponent> > {
- enum { Cost = 5 * NumTraits<Scalar>::MulCost, PacketAccess = false };
+ enum {
+ Cost = 5 * NumTraits<Scalar>::MulCost,
+ PacketAccess = (!NumTraits<Scalar>::IsComplex && !NumTraits<Scalar>::IsInteger &&
+ packet_traits<Scalar>::HasExp && packet_traits<Scalar>::HasLog &&
+ packet_traits<Scalar>::HasRound && packet_traits<Scalar>::HasCmp &&
+ // Temporarly disable packet access for half/bfloat16 until
+ // accuracy is improved.
+ !is_same<Scalar, half>::value && !is_same<Scalar, bfloat16>::value
+ )
+ };
};
-
-
//---------- non associative binary functors ----------
/** \internal
@@ -337,7 +364,7 @@
template<typename LhsScalar,typename RhsScalar>
struct functor_traits<scalar_difference_op<LhsScalar,RhsScalar> > {
enum {
- Cost = (NumTraits<LhsScalar>::AddCost+NumTraits<RhsScalar>::AddCost)/2,
+ Cost = (int(NumTraits<LhsScalar>::AddCost) + int(NumTraits<RhsScalar>::AddCost)) / 2,
PacketAccess = is_same<LhsScalar,RhsScalar>::value && packet_traits<LhsScalar>::HasSub && packet_traits<RhsScalar>::HasSub
};
};
@@ -382,11 +409,14 @@
struct scalar_boolean_and_op {
EIGEN_EMPTY_STRUCT_CTOR(scalar_boolean_and_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator() (const bool& a, const bool& b) const { return a && b; }
+ template<typename Packet>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a, const Packet& b) const
+ { return internal::pand(a,b); }
};
template<> struct functor_traits<scalar_boolean_and_op> {
enum {
Cost = NumTraits<bool>::AddCost,
- PacketAccess = false
+ PacketAccess = true
};
};
@@ -398,11 +428,14 @@
struct scalar_boolean_or_op {
EIGEN_EMPTY_STRUCT_CTOR(scalar_boolean_or_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator() (const bool& a, const bool& b) const { return a || b; }
+ template<typename Packet>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a, const Packet& b) const
+ { return internal::por(a,b); }
};
template<> struct functor_traits<scalar_boolean_or_op> {
enum {
Cost = NumTraits<bool>::AddCost,
- PacketAccess = false
+ PacketAccess = true
};
};
@@ -414,11 +447,44 @@
struct scalar_boolean_xor_op {
EIGEN_EMPTY_STRUCT_CTOR(scalar_boolean_xor_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator() (const bool& a, const bool& b) const { return a ^ b; }
+ template<typename Packet>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a, const Packet& b) const
+ { return internal::pxor(a,b); }
};
template<> struct functor_traits<scalar_boolean_xor_op> {
enum {
Cost = NumTraits<bool>::AddCost,
- PacketAccess = false
+ PacketAccess = true
+ };
+};
+
+/** \internal
+ * \brief Template functor to compute the absolute difference of two scalars
+ *
+ * \sa class CwiseBinaryOp, MatrixBase::absolute_difference
+ */
+template<typename LhsScalar,typename RhsScalar>
+struct scalar_absolute_difference_op : binary_op_base<LhsScalar,RhsScalar>
+{
+ typedef typename ScalarBinaryOpTraits<LhsScalar,RhsScalar,scalar_absolute_difference_op>::ReturnType result_type;
+#ifndef EIGEN_SCALAR_BINARY_OP_PLUGIN
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_absolute_difference_op)
+#else
+ scalar_absolute_difference_op() {
+ EIGEN_SCALAR_BINARY_OP_PLUGIN
+ }
+#endif
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const result_type operator() (const LhsScalar& a, const RhsScalar& b) const
+ { return numext::absdiff(a,b); }
+ template<typename Packet>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a, const Packet& b) const
+ { return internal::pabsdiff(a,b); }
+};
+template<typename LhsScalar,typename RhsScalar>
+struct functor_traits<scalar_absolute_difference_op<LhsScalar,RhsScalar> > {
+ enum {
+ Cost = (NumTraits<LhsScalar>::AddCost+NumTraits<RhsScalar>::AddCost)/2,
+ PacketAccess = is_same<LhsScalar,RhsScalar>::value && packet_traits<LhsScalar>::HasAbsDiff
};
};
@@ -436,7 +502,7 @@
typedef typename BinaryOp::second_argument_type second_argument_type;
typedef typename BinaryOp::result_type result_type;
- bind1st_op(const first_argument_type &val) : m_value(val) {}
+ EIGEN_DEVICE_FUNC explicit bind1st_op(const first_argument_type &val) : m_value(val) {}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const result_type operator() (const second_argument_type& b) const { return BinaryOp::operator()(m_value,b); }
@@ -455,7 +521,7 @@
typedef typename BinaryOp::second_argument_type second_argument_type;
typedef typename BinaryOp::result_type result_type;
- bind2nd_op(const second_argument_type &val) : m_value(val) {}
+ EIGEN_DEVICE_FUNC explicit bind2nd_op(const second_argument_type &val) : m_value(val) {}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const result_type operator() (const first_argument_type& a) const { return BinaryOp::operator()(a,m_value); }
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/functors/NullaryFunctors.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/functors/NullaryFunctors.h
index b03be02..192f225 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/functors/NullaryFunctors.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/functors/NullaryFunctors.h
@@ -37,26 +37,27 @@
struct functor_traits<scalar_identity_op<Scalar> >
{ enum { Cost = NumTraits<Scalar>::AddCost, PacketAccess = false, IsRepeatable = true }; };
-template <typename Scalar, typename Packet, bool IsInteger> struct linspaced_op_impl;
+template <typename Scalar, bool IsInteger> struct linspaced_op_impl;
-template <typename Scalar, typename Packet>
-struct linspaced_op_impl<Scalar,Packet,/*IsInteger*/false>
+template <typename Scalar>
+struct linspaced_op_impl<Scalar,/*IsInteger*/false>
{
- linspaced_op_impl(const Scalar& low, const Scalar& high, Index num_steps) :
- m_low(low), m_high(high), m_size1(num_steps==1 ? 1 : num_steps-1), m_step(num_steps==1 ? Scalar() : (high-low)/Scalar(num_steps-1)),
+ typedef typename NumTraits<Scalar>::Real RealScalar;
+
+ EIGEN_DEVICE_FUNC linspaced_op_impl(const Scalar& low, const Scalar& high, Index num_steps) :
+ m_low(low), m_high(high), m_size1(num_steps==1 ? 1 : num_steps-1), m_step(num_steps==1 ? Scalar() : Scalar((high-low)/RealScalar(num_steps-1))),
m_flip(numext::abs(high)<numext::abs(low))
{}
template<typename IndexType>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator() (IndexType i) const {
- typedef typename NumTraits<Scalar>::Real RealScalar;
if(m_flip)
- return (i==0)? m_low : (m_high - RealScalar(m_size1-i)*m_step);
+ return (i==0)? m_low : Scalar(m_high - RealScalar(m_size1-i)*m_step);
else
- return (i==m_size1)? m_high : (m_low + RealScalar(i)*m_step);
+ return (i==m_size1)? m_high : Scalar(m_low + RealScalar(i)*m_step);
}
- template<typename IndexType>
+ template<typename Packet, typename IndexType>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(IndexType i) const
{
// Principle:
@@ -65,17 +66,17 @@
{
Packet pi = plset<Packet>(Scalar(i-m_size1));
Packet res = padd(pset1<Packet>(m_high), pmul(pset1<Packet>(m_step), pi));
- if(i==0)
- res = pinsertfirst(res, m_low);
- return res;
+ if (EIGEN_PREDICT_TRUE(i != 0)) return res;
+ Packet mask = pcmp_lt(pset1<Packet>(0), plset<Packet>(0));
+ return pselect<Packet>(mask, res, pset1<Packet>(m_low));
}
else
{
Packet pi = plset<Packet>(Scalar(i));
Packet res = padd(pset1<Packet>(m_low), pmul(pset1<Packet>(m_step), pi));
- if(i==m_size1-unpacket_traits<Packet>::size+1)
- res = pinsertlast(res, m_high);
- return res;
+ if(EIGEN_PREDICT_TRUE(i != m_size1-unpacket_traits<Packet>::size+1)) return res;
+ Packet mask = pcmp_lt(plset<Packet>(0), pset1<Packet>(unpacket_traits<Packet>::size-1));
+ return pselect<Packet>(mask, res, pset1<Packet>(m_high));
}
}
@@ -86,10 +87,10 @@
const bool m_flip;
};
-template <typename Scalar, typename Packet>
-struct linspaced_op_impl<Scalar,Packet,/*IsInteger*/true>
+template <typename Scalar>
+struct linspaced_op_impl<Scalar,/*IsInteger*/true>
{
- linspaced_op_impl(const Scalar& low, const Scalar& high, Index num_steps) :
+ EIGEN_DEVICE_FUNC linspaced_op_impl(const Scalar& low, const Scalar& high, Index num_steps) :
m_low(low),
m_multiplier((high-low)/convert_index<Scalar>(num_steps<=1 ? 1 : num_steps-1)),
m_divisor(convert_index<Scalar>((high>=low?num_steps:-num_steps)+(high-low))/((numext::abs(high-low)+1)==0?1:(numext::abs(high-low)+1))),
@@ -115,8 +116,8 @@
// Forward declaration (we default to random access which does not really give
// us a speed gain when using packet access but it allows to use the functor in
// nested expressions).
-template <typename Scalar, typename PacketType> struct linspaced_op;
-template <typename Scalar, typename PacketType> struct functor_traits< linspaced_op<Scalar,PacketType> >
+template <typename Scalar> struct linspaced_op;
+template <typename Scalar> struct functor_traits< linspaced_op<Scalar> >
{
enum
{
@@ -126,9 +127,9 @@
IsRepeatable = true
};
};
-template <typename Scalar, typename PacketType> struct linspaced_op
+template <typename Scalar> struct linspaced_op
{
- linspaced_op(const Scalar& low, const Scalar& high, Index num_steps)
+ EIGEN_DEVICE_FUNC linspaced_op(const Scalar& low, const Scalar& high, Index num_steps)
: impl((num_steps==1 ? high : low),high,num_steps)
{}
@@ -136,11 +137,11 @@
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator() (IndexType i) const { return impl(i); }
template<typename Packet,typename IndexType>
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(IndexType i) const { return impl.packetOp(i); }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(IndexType i) const { return impl.template packetOp<Packet>(i); }
// This proxy object handles the actual required temporaries and the different
// implementations (integer vs. floating point).
- const linspaced_op_impl<Scalar,PacketType,NumTraits<Scalar>::IsInteger> impl;
+ const linspaced_op_impl<Scalar,NumTraits<Scalar>::IsInteger> impl;
};
// Linear access is automatically determined from the operator() prototypes available for the given functor.
@@ -166,12 +167,12 @@
template<typename Scalar,typename IndexType>
struct has_binary_operator<scalar_identity_op<Scalar>,IndexType> { enum { value = 1}; };
-template<typename Scalar, typename PacketType,typename IndexType>
-struct has_nullary_operator<linspaced_op<Scalar,PacketType>,IndexType> { enum { value = 0}; };
-template<typename Scalar, typename PacketType,typename IndexType>
-struct has_unary_operator<linspaced_op<Scalar,PacketType>,IndexType> { enum { value = 1}; };
-template<typename Scalar, typename PacketType,typename IndexType>
-struct has_binary_operator<linspaced_op<Scalar,PacketType>,IndexType> { enum { value = 0}; };
+template<typename Scalar,typename IndexType>
+struct has_nullary_operator<linspaced_op<Scalar>,IndexType> { enum { value = 0}; };
+template<typename Scalar,typename IndexType>
+struct has_unary_operator<linspaced_op<Scalar>,IndexType> { enum { value = 1}; };
+template<typename Scalar,typename IndexType>
+struct has_binary_operator<linspaced_op<Scalar>,IndexType> { enum { value = 0}; };
template<typename Scalar,typename IndexType>
struct has_nullary_operator<scalar_random_op<Scalar>,IndexType> { enum { value = 1}; };
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/functors/StlFunctors.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/functors/StlFunctors.h
index 9c1d758..4570c9b 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/functors/StlFunctors.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/functors/StlFunctors.h
@@ -12,6 +12,28 @@
namespace Eigen {
+// Portable replacements for certain functors.
+namespace numext {
+
+template<typename T = void>
+struct equal_to {
+ typedef bool result_type;
+ EIGEN_DEVICE_FUNC bool operator()(const T& lhs, const T& rhs) const {
+ return lhs == rhs;
+ }
+};
+
+template<typename T = void>
+struct not_equal_to {
+ typedef bool result_type;
+ EIGEN_DEVICE_FUNC bool operator()(const T& lhs, const T& rhs) const {
+ return lhs != rhs;
+ }
+};
+
+}
+
+
namespace internal {
// default functor traits for STL functors:
@@ -69,10 +91,18 @@
{ enum { Cost = 1, PacketAccess = false }; };
template<typename T>
+struct functor_traits<numext::equal_to<T> >
+ : functor_traits<std::equal_to<T> > {};
+
+template<typename T>
struct functor_traits<std::not_equal_to<T> >
{ enum { Cost = 1, PacketAccess = false }; };
-#if (__cplusplus < 201103L) && (EIGEN_COMP_MSVC <= 1900)
+template<typename T>
+struct functor_traits<numext::not_equal_to<T> >
+ : functor_traits<std::not_equal_to<T> > {};
+
+#if (EIGEN_COMP_CXXVER < 11)
// std::binder* are deprecated since c++11 and will be removed in c++17
template<typename T>
struct functor_traits<std::binder2nd<T> >
@@ -83,7 +113,7 @@
{ enum { Cost = functor_traits<T>::Cost, PacketAccess = false }; };
#endif
-#if (__cplusplus < 201703L) && (EIGEN_COMP_MSVC < 1910)
+#if (EIGEN_COMP_CXXVER < 17)
// std::unary_negate is deprecated since c++17 and will be removed in c++20
template<typename T>
struct functor_traits<std::unary_negate<T> >
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/functors/UnaryFunctors.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/functors/UnaryFunctors.h
index 2e6a00f..16136d1 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/functors/UnaryFunctors.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/functors/UnaryFunctors.h
@@ -109,7 +109,7 @@
template<typename Scalar> struct scalar_conjugate_op {
EIGEN_EMPTY_STRUCT_CTOR(scalar_conjugate_op)
EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE const Scalar operator() (const Scalar& a) const { using numext::conj; return conj(a); }
+ EIGEN_STRONG_INLINE const Scalar operator() (const Scalar& a) const { return numext::conj(a); }
template<typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a) const { return internal::pconj(a); }
};
@@ -117,7 +117,15 @@
struct functor_traits<scalar_conjugate_op<Scalar> >
{
enum {
- Cost = NumTraits<Scalar>::IsComplex ? NumTraits<Scalar>::AddCost : 0,
+ Cost = 0,
+ // Yes the cost is zero even for complexes because in most cases for which
+ // the cost is used, conjugation turns out to be a no-op. Some examples:
+ // cost(a*conj(b)) == cost(a*b)
+ // cost(a+conj(b)) == cost(a+b)
+ // etc.
+ // If we don't set it to zero, then:
+ // A.conjugate().lazyProduct(B.conjugate())
+ // will bake its operands. We definitely don't want that!
PacketAccess = packet_traits<Scalar>::HasConj
};
};
@@ -130,7 +138,7 @@
template<typename Scalar> struct scalar_arg_op {
EIGEN_EMPTY_STRUCT_CTOR(scalar_arg_op)
typedef typename NumTraits<Scalar>::Real result_type;
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const result_type operator() (const Scalar& a) const { using numext::arg; return arg(a); }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const result_type operator() (const Scalar& a) const { return numext::arg(a); }
template<typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a) const
{ return internal::parg(a); }
@@ -159,6 +167,44 @@
{ enum { Cost = is_same<Scalar, NewType>::value ? 0 : NumTraits<NewType>::AddCost, PacketAccess = false }; };
/** \internal
+ * \brief Template functor to arithmetically shift a scalar right by a number of bits
+ *
+ * \sa class CwiseUnaryOp, MatrixBase::shift_right()
+ */
+template<typename Scalar, int N>
+struct scalar_shift_right_op {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_shift_right_op)
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator() (const Scalar& a) const
+ { return a >> N; }
+ template<typename Packet>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a) const
+ { return internal::parithmetic_shift_right<N>(a); }
+};
+template<typename Scalar, int N>
+struct functor_traits<scalar_shift_right_op<Scalar,N> >
+{ enum { Cost = NumTraits<Scalar>::AddCost, PacketAccess = packet_traits<Scalar>::HasShift }; };
+
+/** \internal
+ * \brief Template functor to logically shift a scalar left by a number of bits
+ *
+ * \sa class CwiseUnaryOp, MatrixBase::shift_left()
+ */
+template<typename Scalar, int N>
+struct scalar_shift_left_op {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_shift_left_op)
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator() (const Scalar& a) const
+ { return a << N; }
+ template<typename Packet>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a) const
+ { return internal::plogical_shift_left<N>(a); }
+};
+template<typename Scalar, int N>
+struct functor_traits<scalar_shift_left_op<Scalar,N> >
+{ enum { Cost = NumTraits<Scalar>::AddCost, PacketAccess = packet_traits<Scalar>::HasShift }; };
+
+/** \internal
* \brief Template functor to extract the real part of a complex
*
* \sa class CwiseUnaryOp, MatrixBase::real()
@@ -264,6 +310,26 @@
/** \internal
*
+ * \brief Template functor to compute the exponential of a scalar - 1.
+ *
+ * \sa class CwiseUnaryOp, ArrayBase::expm1()
+ */
+template<typename Scalar> struct scalar_expm1_op {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_expm1_op)
+ EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const { return numext::expm1(a); }
+ template <typename Packet>
+ EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::pexpm1(a); }
+};
+template <typename Scalar>
+struct functor_traits<scalar_expm1_op<Scalar> > {
+ enum {
+ PacketAccess = packet_traits<Scalar>::HasExpm1,
+ Cost = functor_traits<scalar_exp_op<Scalar> >::Cost // TODO measure cost of expm1
+ };
+};
+
+/** \internal
+ *
* \brief Template functor to compute the logarithm of a scalar
*
* \sa class CwiseUnaryOp, ArrayBase::log()
@@ -321,7 +387,7 @@
*/
template<typename Scalar> struct scalar_log10_op {
EIGEN_EMPTY_STRUCT_CTOR(scalar_log10_op)
- EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const { EIGEN_USING_STD_MATH(log10) return log10(a); }
+ EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const { EIGEN_USING_STD(log10) return log10(a); }
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::plog10(a); }
};
@@ -330,6 +396,22 @@
{ enum { Cost = 5 * NumTraits<Scalar>::MulCost, PacketAccess = packet_traits<Scalar>::HasLog10 }; };
/** \internal
+ *
+ * \brief Template functor to compute the base-2 logarithm of a scalar
+ *
+ * \sa class CwiseUnaryOp, Cwise::log2()
+ */
+template<typename Scalar> struct scalar_log2_op {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_log2_op)
+ EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const { return Scalar(EIGEN_LOG2E) * numext::log(a); }
+ template <typename Packet>
+ EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::plog2(a); }
+};
+template<typename Scalar>
+struct functor_traits<scalar_log2_op<Scalar> >
+{ enum { Cost = 5 * NumTraits<Scalar>::MulCost, PacketAccess = packet_traits<Scalar>::HasLog }; };
+
+/** \internal
* \brief Template functor to compute the square root of a scalar
* \sa class CwiseUnaryOp, Cwise::sqrt()
*/
@@ -356,13 +438,25 @@
};
};
+// Boolean specialization to eliminate -Wimplicit-conversion-floating-point-to-bool warnings.
+template<> struct scalar_sqrt_op<bool> {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_sqrt_op)
+ EIGEN_DEPRECATED EIGEN_DEVICE_FUNC inline bool operator() (const bool& a) const { return a; }
+ template <typename Packet>
+ EIGEN_DEPRECATED EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return a; }
+};
+template <>
+struct functor_traits<scalar_sqrt_op<bool> > {
+ enum { Cost = 1, PacketAccess = packet_traits<bool>::Vectorizable };
+};
+
/** \internal
* \brief Template functor to compute the reciprocal square root of a scalar
* \sa class CwiseUnaryOp, Cwise::rsqrt()
*/
template<typename Scalar> struct scalar_rsqrt_op {
EIGEN_EMPTY_STRUCT_CTOR(scalar_rsqrt_op)
- EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const { return Scalar(1)/numext::sqrt(a); }
+ EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const { return numext::rsqrt(a); }
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::prsqrt(a); }
};
@@ -528,6 +622,23 @@
};
};
+#if EIGEN_HAS_CXX11_MATH
+/** \internal
+ * \brief Template functor to compute the atanh of a scalar
+ * \sa class CwiseUnaryOp, ArrayBase::atanh()
+ */
+template <typename Scalar>
+struct scalar_atanh_op {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_atanh_op)
+ EIGEN_DEVICE_FUNC inline const Scalar operator()(const Scalar& a) const { return numext::atanh(a); }
+};
+
+template <typename Scalar>
+struct functor_traits<scalar_atanh_op<Scalar> > {
+ enum { Cost = 5 * NumTraits<Scalar>::MulCost, PacketAccess = false };
+};
+#endif
+
/** \internal
* \brief Template functor to compute the sinh of a scalar
* \sa class CwiseUnaryOp, ArrayBase::sinh()
@@ -547,6 +658,23 @@
};
};
+#if EIGEN_HAS_CXX11_MATH
+/** \internal
+ * \brief Template functor to compute the asinh of a scalar
+ * \sa class CwiseUnaryOp, ArrayBase::asinh()
+ */
+template <typename Scalar>
+struct scalar_asinh_op {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_asinh_op)
+ EIGEN_DEVICE_FUNC inline const Scalar operator()(const Scalar& a) const { return numext::asinh(a); }
+};
+
+template <typename Scalar>
+struct functor_traits<scalar_asinh_op<Scalar> > {
+ enum { Cost = 5 * NumTraits<Scalar>::MulCost, PacketAccess = false };
+};
+#endif
+
/** \internal
* \brief Template functor to compute the cosh of a scalar
* \sa class CwiseUnaryOp, ArrayBase::cosh()
@@ -566,6 +694,23 @@
};
};
+#if EIGEN_HAS_CXX11_MATH
+/** \internal
+ * \brief Template functor to compute the acosh of a scalar
+ * \sa class CwiseUnaryOp, ArrayBase::acosh()
+ */
+template <typename Scalar>
+struct scalar_acosh_op {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_acosh_op)
+ EIGEN_DEVICE_FUNC inline const Scalar operator()(const Scalar& a) const { return numext::acosh(a); }
+};
+
+template <typename Scalar>
+struct functor_traits<scalar_acosh_op<Scalar> > {
+ enum { Cost = 5 * NumTraits<Scalar>::MulCost, PacketAccess = false };
+};
+#endif
+
/** \internal
* \brief Template functor to compute the inverse of a scalar
* \sa class CwiseUnaryOp, Cwise::inverse()
@@ -578,9 +723,13 @@
EIGEN_DEVICE_FUNC inline const Packet packetOp(const Packet& a) const
{ return internal::pdiv(pset1<Packet>(Scalar(1)),a); }
};
-template<typename Scalar>
-struct functor_traits<scalar_inverse_op<Scalar> >
-{ enum { Cost = NumTraits<Scalar>::MulCost, PacketAccess = packet_traits<Scalar>::HasDiv }; };
+template <typename Scalar>
+struct functor_traits<scalar_inverse_op<Scalar> > {
+ enum {
+ PacketAccess = packet_traits<Scalar>::HasDiv,
+ Cost = scalar_div_cost<Scalar, PacketAccess>::value
+ };
+};
/** \internal
* \brief Template functor to compute the square of a scalar
@@ -598,6 +747,19 @@
struct functor_traits<scalar_square_op<Scalar> >
{ enum { Cost = NumTraits<Scalar>::MulCost, PacketAccess = packet_traits<Scalar>::HasMul }; };
+// Boolean specialization to avoid -Wint-in-bool-context warnings on GCC.
+template<>
+struct scalar_square_op<bool> {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_square_op)
+ EIGEN_DEPRECATED EIGEN_DEVICE_FUNC inline bool operator() (const bool& a) const { return a; }
+ template<typename Packet>
+ EIGEN_DEPRECATED EIGEN_DEVICE_FUNC inline const Packet packetOp(const Packet& a) const
+ { return a; }
+};
+template<>
+struct functor_traits<scalar_square_op<bool> >
+{ enum { Cost = 0, PacketAccess = packet_traits<bool>::Vectorizable }; };
+
/** \internal
* \brief Template functor to compute the cube of a scalar
* \sa class CwiseUnaryOp, Cwise::cube()
@@ -614,6 +776,19 @@
struct functor_traits<scalar_cube_op<Scalar> >
{ enum { Cost = 2*NumTraits<Scalar>::MulCost, PacketAccess = packet_traits<Scalar>::HasMul }; };
+// Boolean specialization to avoid -Wint-in-bool-context warnings on GCC.
+template<>
+struct scalar_cube_op<bool> {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_cube_op)
+ EIGEN_DEPRECATED EIGEN_DEVICE_FUNC inline bool operator() (const bool& a) const { return a; }
+ template<typename Packet>
+ EIGEN_DEPRECATED EIGEN_DEVICE_FUNC inline const Packet packetOp(const Packet& a) const
+ { return a; }
+};
+template<>
+struct functor_traits<scalar_cube_op<bool> >
+{ enum { Cost = 0, PacketAccess = packet_traits<bool>::Vectorizable }; };
+
/** \internal
* \brief Template functor to compute the rounded value of a scalar
* \sa class CwiseUnaryOp, ArrayBase::round()
@@ -653,6 +828,25 @@
};
/** \internal
+ * \brief Template functor to compute the rounded (with current rounding mode) value of a scalar
+ * \sa class CwiseUnaryOp, ArrayBase::rint()
+ */
+template<typename Scalar> struct scalar_rint_op {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_rint_op)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator() (const Scalar& a) const { return numext::rint(a); }
+ template <typename Packet>
+ EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::print(a); }
+};
+template<typename Scalar>
+struct functor_traits<scalar_rint_op<Scalar> >
+{
+ enum {
+ Cost = NumTraits<Scalar>::MulCost,
+ PacketAccess = packet_traits<Scalar>::HasRint
+ };
+};
+
+/** \internal
* \brief Template functor to compute the ceil of a scalar
* \sa class CwiseUnaryOp, ArrayBase::ceil()
*/
@@ -678,7 +872,13 @@
template<typename Scalar> struct scalar_isnan_op {
EIGEN_EMPTY_STRUCT_CTOR(scalar_isnan_op)
typedef bool result_type;
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator() (const Scalar& a) const { return (numext::isnan)(a); }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator() (const Scalar& a) const {
+#if defined(SYCL_DEVICE_ONLY)
+ return numext::isnan(a);
+#else
+ return (numext::isnan)(a);
+#endif
+ }
};
template<typename Scalar>
struct functor_traits<scalar_isnan_op<Scalar> >
@@ -696,7 +896,13 @@
template<typename Scalar> struct scalar_isinf_op {
EIGEN_EMPTY_STRUCT_CTOR(scalar_isinf_op)
typedef bool result_type;
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator() (const Scalar& a) const { return (numext::isinf)(a); }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator() (const Scalar& a) const {
+#if defined(SYCL_DEVICE_ONLY)
+ return numext::isinf(a);
+#else
+ return (numext::isinf)(a);
+#endif
+ }
};
template<typename Scalar>
struct functor_traits<scalar_isinf_op<Scalar> >
@@ -714,7 +920,13 @@
template<typename Scalar> struct scalar_isfinite_op {
EIGEN_EMPTY_STRUCT_CTOR(scalar_isfinite_op)
typedef bool result_type;
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator() (const Scalar& a) const { return (numext::isfinite)(a); }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator() (const Scalar& a) const {
+#if defined(SYCL_DEVICE_ONLY)
+ return numext::isfinite(a);
+#else
+ return (numext::isfinite)(a);
+#endif
+ }
};
template<typename Scalar>
struct functor_traits<scalar_isfinite_op<Scalar> >
@@ -746,9 +958,9 @@
* \brief Template functor to compute the signum of a scalar
* \sa class CwiseUnaryOp, Cwise::sign()
*/
-template<typename Scalar,bool iscpx=(NumTraits<Scalar>::IsComplex!=0) > struct scalar_sign_op;
+template<typename Scalar,bool is_complex=(NumTraits<Scalar>::IsComplex!=0), bool is_integer=(NumTraits<Scalar>::IsInteger!=0) > struct scalar_sign_op;
template<typename Scalar>
-struct scalar_sign_op<Scalar,false> {
+struct scalar_sign_op<Scalar, false, true> {
EIGEN_EMPTY_STRUCT_CTOR(scalar_sign_op)
EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const
{
@@ -758,8 +970,21 @@
//template <typename Packet>
//EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::psign(a); }
};
+
template<typename Scalar>
-struct scalar_sign_op<Scalar,true> {
+struct scalar_sign_op<Scalar, false, false> {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_sign_op)
+ EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const
+ {
+ return (numext::isnan)(a) ? a : Scalar( (a>Scalar(0)) - (a<Scalar(0)) );
+ }
+ //TODO
+ //template <typename Packet>
+ //EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::psign(a); }
+};
+
+template<typename Scalar, bool is_integer>
+struct scalar_sign_op<Scalar,true, is_integer> {
EIGEN_EMPTY_STRUCT_CTOR(scalar_sign_op)
EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const
{
@@ -768,7 +993,7 @@
if (aa==real_type(0))
return Scalar(0);
aa = real_type(1)/aa;
- return Scalar(real(a)*aa, imag(a)*aa );
+ return Scalar(a.real()*aa, a.imag()*aa );
}
//TODO
//template <typename Packet>
@@ -777,7 +1002,7 @@
template<typename Scalar>
struct functor_traits<scalar_sign_op<Scalar> >
{ enum {
- Cost =
+ Cost =
NumTraits<Scalar>::IsComplex
? ( 8*NumTraits<Scalar>::MulCost ) // roughly
: ( 3*NumTraits<Scalar>::AddCost),
@@ -785,6 +1010,120 @@
};
};
+/** \internal
+ * \brief Template functor to compute the logistic function of a scalar
+ * \sa class CwiseUnaryOp, ArrayBase::logistic()
+ */
+template <typename T>
+struct scalar_logistic_op {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_logistic_op)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x) const {
+ return packetOp(x);
+ }
+
+ template <typename Packet> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ Packet packetOp(const Packet& x) const {
+ const Packet one = pset1<Packet>(T(1));
+ return pdiv(one, padd(one, pexp(pnegate(x))));
+ }
+};
+
+#ifndef EIGEN_GPU_COMPILE_PHASE
+/** \internal
+ * \brief Template specialization of the logistic function for float.
+ *
+ * Uses just a 9/10-degree rational interpolant which
+ * interpolates 1/(1+exp(-x)) - 0.5 up to a couple of ulps in the range
+ * [-9, 18]. Below -9 we use the more accurate approximation
+ * 1/(1+exp(-x)) ~= exp(x), and above 18 the logistic function is 1 within
+ * one ulp. The shifted logistic is interpolated because it was easier to
+ * make the fit converge.
+ *
+ */
+template <>
+struct scalar_logistic_op<float> {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_logistic_op)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator()(const float& x) const {
+ return packetOp(x);
+ }
+
+ template <typename Packet> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ Packet packetOp(const Packet& _x) const {
+ const Packet cutoff_lower = pset1<Packet>(-9.f);
+ const Packet lt_mask = pcmp_lt<Packet>(_x, cutoff_lower);
+ const bool any_small = predux_any(lt_mask);
+
+ // The upper cut-off is the smallest x for which the rational approximation evaluates to 1.
+ // Choosing this value saves us a few instructions clamping the results at the end.
+#ifdef EIGEN_VECTORIZE_FMA
+ const Packet cutoff_upper = pset1<Packet>(15.7243833541870117f);
+#else
+ const Packet cutoff_upper = pset1<Packet>(15.6437711715698242f);
+#endif
+ const Packet x = pmin(_x, cutoff_upper);
+
+ // The monomial coefficients of the numerator polynomial (odd).
+ const Packet alpha_1 = pset1<Packet>(2.48287947061529e-01f);
+ const Packet alpha_3 = pset1<Packet>(8.51377133304701e-03f);
+ const Packet alpha_5 = pset1<Packet>(6.08574864600143e-05f);
+ const Packet alpha_7 = pset1<Packet>(1.15627324459942e-07f);
+ const Packet alpha_9 = pset1<Packet>(4.37031012579801e-11f);
+
+ // The monomial coefficients of the denominator polynomial (even).
+ const Packet beta_0 = pset1<Packet>(9.93151921023180e-01f);
+ const Packet beta_2 = pset1<Packet>(1.16817656904453e-01f);
+ const Packet beta_4 = pset1<Packet>(1.70198817374094e-03f);
+ const Packet beta_6 = pset1<Packet>(6.29106785017040e-06f);
+ const Packet beta_8 = pset1<Packet>(5.76102136993427e-09f);
+ const Packet beta_10 = pset1<Packet>(6.10247389755681e-13f);
+
+ // Since the polynomials are odd/even, we need x^2.
+ const Packet x2 = pmul(x, x);
+
+ // Evaluate the numerator polynomial p.
+ Packet p = pmadd(x2, alpha_9, alpha_7);
+ p = pmadd(x2, p, alpha_5);
+ p = pmadd(x2, p, alpha_3);
+ p = pmadd(x2, p, alpha_1);
+ p = pmul(x, p);
+
+ // Evaluate the denominator polynomial q.
+ Packet q = pmadd(x2, beta_10, beta_8);
+ q = pmadd(x2, q, beta_6);
+ q = pmadd(x2, q, beta_4);
+ q = pmadd(x2, q, beta_2);
+ q = pmadd(x2, q, beta_0);
+ // Divide the numerator by the denominator and shift it up.
+ const Packet logistic = padd(pdiv(p, q), pset1<Packet>(0.5f));
+ if (EIGEN_PREDICT_FALSE(any_small)) {
+ const Packet exponential = pexp(_x);
+ return pselect(lt_mask, exponential, logistic);
+ } else {
+ return logistic;
+ }
+ }
+};
+#endif // #ifndef EIGEN_GPU_COMPILE_PHASE
+
+template <typename T>
+struct functor_traits<scalar_logistic_op<T> > {
+ enum {
+ // The cost estimate for float here is for the common(?) case where
+ // all arguments are greater than -9.
+ Cost = scalar_div_cost<T, packet_traits<T>::HasDiv>::value +
+ (internal::is_same<T, float>::value
+ ? NumTraits<T>::AddCost * 15 + NumTraits<T>::MulCost * 11
+ : NumTraits<T>::AddCost * 2 +
+ functor_traits<scalar_exp_op<T> >::Cost),
+ PacketAccess =
+ packet_traits<T>::HasAdd && packet_traits<T>::HasDiv &&
+ (internal::is_same<T, float>::value
+ ? packet_traits<T>::HasMul && packet_traits<T>::HasMax &&
+ packet_traits<T>::HasMin
+ : packet_traits<T>::HasNegate && packet_traits<T>::HasExp)
+ };
+};
+
} // end namespace internal
} // end namespace Eigen
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/products/GeneralBlockPanelKernel.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/products/GeneralBlockPanelKernel.h
index e3980f6..f35b760 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/products/GeneralBlockPanelKernel.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/products/GeneralBlockPanelKernel.h
@@ -15,7 +15,13 @@
namespace internal {
-template<typename _LhsScalar, typename _RhsScalar, bool _ConjLhs=false, bool _ConjRhs=false>
+enum GEBPPacketSizeType {
+ GEBPPacketFull = 0,
+ GEBPPacketHalf,
+ GEBPPacketQuarter
+};
+
+template<typename _LhsScalar, typename _RhsScalar, bool _ConjLhs=false, bool _ConjRhs=false, int Arch=Architecture::Target, int _PacketSize=GEBPPacketFull>
class gebp_traits;
@@ -25,16 +31,42 @@
return a<=0 ? b : a;
}
-#if EIGEN_ARCH_i386_OR_x86_64
-const std::ptrdiff_t defaultL1CacheSize = 32*1024;
-const std::ptrdiff_t defaultL2CacheSize = 256*1024;
-const std::ptrdiff_t defaultL3CacheSize = 2*1024*1024;
+#if defined(EIGEN_DEFAULT_L1_CACHE_SIZE)
+#define EIGEN_SET_DEFAULT_L1_CACHE_SIZE(val) EIGEN_DEFAULT_L1_CACHE_SIZE
#else
-const std::ptrdiff_t defaultL1CacheSize = 16*1024;
-const std::ptrdiff_t defaultL2CacheSize = 512*1024;
-const std::ptrdiff_t defaultL3CacheSize = 512*1024;
+#define EIGEN_SET_DEFAULT_L1_CACHE_SIZE(val) val
+#endif // defined(EIGEN_DEFAULT_L1_CACHE_SIZE)
+
+#if defined(EIGEN_DEFAULT_L2_CACHE_SIZE)
+#define EIGEN_SET_DEFAULT_L2_CACHE_SIZE(val) EIGEN_DEFAULT_L2_CACHE_SIZE
+#else
+#define EIGEN_SET_DEFAULT_L2_CACHE_SIZE(val) val
+#endif // defined(EIGEN_DEFAULT_L2_CACHE_SIZE)
+
+#if defined(EIGEN_DEFAULT_L3_CACHE_SIZE)
+#define EIGEN_SET_DEFAULT_L3_CACHE_SIZE(val) EIGEN_DEFAULT_L3_CACHE_SIZE
+#else
+#define EIGEN_SET_DEFAULT_L3_CACHE_SIZE(val) val
+#endif // defined(EIGEN_DEFAULT_L3_CACHE_SIZE)
+
+#if EIGEN_ARCH_i386_OR_x86_64
+const std::ptrdiff_t defaultL1CacheSize = EIGEN_SET_DEFAULT_L1_CACHE_SIZE(32*1024);
+const std::ptrdiff_t defaultL2CacheSize = EIGEN_SET_DEFAULT_L2_CACHE_SIZE(256*1024);
+const std::ptrdiff_t defaultL3CacheSize = EIGEN_SET_DEFAULT_L3_CACHE_SIZE(2*1024*1024);
+#elif EIGEN_ARCH_PPC
+const std::ptrdiff_t defaultL1CacheSize = EIGEN_SET_DEFAULT_L1_CACHE_SIZE(64*1024);
+const std::ptrdiff_t defaultL2CacheSize = EIGEN_SET_DEFAULT_L2_CACHE_SIZE(512*1024);
+const std::ptrdiff_t defaultL3CacheSize = EIGEN_SET_DEFAULT_L3_CACHE_SIZE(4*1024*1024);
+#else
+const std::ptrdiff_t defaultL1CacheSize = EIGEN_SET_DEFAULT_L1_CACHE_SIZE(16*1024);
+const std::ptrdiff_t defaultL2CacheSize = EIGEN_SET_DEFAULT_L2_CACHE_SIZE(512*1024);
+const std::ptrdiff_t defaultL3CacheSize = EIGEN_SET_DEFAULT_L3_CACHE_SIZE(512*1024);
#endif
+#undef EIGEN_SET_DEFAULT_L1_CACHE_SIZE
+#undef EIGEN_SET_DEFAULT_L2_CACHE_SIZE
+#undef EIGEN_SET_DEFAULT_L3_CACHE_SIZE
+
/** \internal */
struct CacheSizes {
CacheSizes(): m_l1(-1),m_l2(-1),m_l3(-1) {
@@ -50,7 +82,6 @@
std::ptrdiff_t m_l3;
};
-
/** \internal */
inline void manage_caching_sizes(Action action, std::ptrdiff_t* l1, std::ptrdiff_t* l2, std::ptrdiff_t* l3)
{
@@ -101,6 +132,16 @@
// at the register level. This small horizontal panel has to stay within L1 cache.
std::ptrdiff_t l1, l2, l3;
manage_caching_sizes(GetAction, &l1, &l2, &l3);
+ #ifdef EIGEN_VECTORIZE_AVX512
+ // We need to find a rationale for that, but without this adjustment,
+ // performance with AVX512 is pretty bad, like -20% slower.
+ // One reason is that with increasing packet-size, the blocking size k
+ // has to become pretty small if we want that 1 lhs panel fit within L1.
+ // For instance, with the 3pX4 kernel and double, the size of the lhs+rhs panels are:
+ // k*(3*64 + 4*8) Bytes, with l1=32kBytes, and k%8=0, we have k=144.
+ // This is quite small for a good reuse of the accumulation registers.
+ l1 *= 4;
+ #endif
if (num_threads > 1) {
typedef typename Traits::ResScalar ResScalar;
@@ -115,7 +156,8 @@
// registers. However once the latency is hidden there is no point in
// increasing the value of k, so we'll cap it at 320 (value determined
// experimentally).
- const Index k_cache = (numext::mini<Index>)((l1-ksub)/kdiv, 320);
+ // To avoid that k vanishes, we make k_cache at least as big as kr
+ const Index k_cache = numext::maxi<Index>(kr, (numext::mini<Index>)((l1-ksub)/kdiv, 320));
if (k_cache < k) {
k = k_cache - (k_cache % kr);
eigen_internal_assert(k > 0);
@@ -307,35 +349,60 @@
computeProductBlockingSizes<LhsScalar,RhsScalar,1,Index>(k, m, n, num_threads);
}
-#ifdef EIGEN_HAS_SINGLE_INSTRUCTION_CJMADD
- #define CJMADD(CJ,A,B,C,T) C = CJ.pmadd(A,B,C);
-#else
+template <typename RhsPacket, typename RhsPacketx4, int registers_taken>
+struct RhsPanelHelper {
+ private:
+ static const int remaining_registers = EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS - registers_taken;
+ public:
+ typedef typename conditional<remaining_registers>=4, RhsPacketx4, RhsPacket>::type type;
+};
- // FIXME (a bit overkill maybe ?)
+template <typename Packet>
+struct QuadPacket
+{
+ Packet B_0, B1, B2, B3;
+ const Packet& get(const FixedInt<0>&) const { return B_0; }
+ const Packet& get(const FixedInt<1>&) const { return B1; }
+ const Packet& get(const FixedInt<2>&) const { return B2; }
+ const Packet& get(const FixedInt<3>&) const { return B3; }
+};
- template<typename CJ, typename A, typename B, typename C, typename T> struct gebp_madd_selector {
- EIGEN_ALWAYS_INLINE static void run(const CJ& cj, A& a, B& b, C& c, T& /*t*/)
- {
- c = cj.pmadd(a,b,c);
- }
- };
+template <int N, typename T1, typename T2, typename T3>
+struct packet_conditional { typedef T3 type; };
- template<typename CJ, typename T> struct gebp_madd_selector<CJ,T,T,T,T> {
- EIGEN_ALWAYS_INLINE static void run(const CJ& cj, T& a, T& b, T& c, T& t)
- {
- t = b; t = cj.pmul(a,t); c = padd(c,t);
- }
- };
+template <typename T1, typename T2, typename T3>
+struct packet_conditional<GEBPPacketFull, T1, T2, T3> { typedef T1 type; };
- template<typename CJ, typename A, typename B, typename C, typename T>
- EIGEN_STRONG_INLINE void gebp_madd(const CJ& cj, A& a, B& b, C& c, T& t)
- {
- gebp_madd_selector<CJ,A,B,C,T>::run(cj,a,b,c,t);
- }
+template <typename T1, typename T2, typename T3>
+struct packet_conditional<GEBPPacketHalf, T1, T2, T3> { typedef T2 type; };
- #define CJMADD(CJ,A,B,C,T) gebp_madd(CJ,A,B,C,T);
-// #define CJMADD(CJ,A,B,C,T) T = B; T = CJ.pmul(A,T); C = padd(C,T);
-#endif
+#define PACKET_DECL_COND_PREFIX(prefix, name, packet_size) \
+ typedef typename packet_conditional<packet_size, \
+ typename packet_traits<name ## Scalar>::type, \
+ typename packet_traits<name ## Scalar>::half, \
+ typename unpacket_traits<typename packet_traits<name ## Scalar>::half>::half>::type \
+ prefix ## name ## Packet
+
+#define PACKET_DECL_COND(name, packet_size) \
+ typedef typename packet_conditional<packet_size, \
+ typename packet_traits<name ## Scalar>::type, \
+ typename packet_traits<name ## Scalar>::half, \
+ typename unpacket_traits<typename packet_traits<name ## Scalar>::half>::half>::type \
+ name ## Packet
+
+#define PACKET_DECL_COND_SCALAR_PREFIX(prefix, packet_size) \
+ typedef typename packet_conditional<packet_size, \
+ typename packet_traits<Scalar>::type, \
+ typename packet_traits<Scalar>::half, \
+ typename unpacket_traits<typename packet_traits<Scalar>::half>::half>::type \
+ prefix ## ScalarPacket
+
+#define PACKET_DECL_COND_SCALAR(packet_size) \
+ typedef typename packet_conditional<packet_size, \
+ typename packet_traits<Scalar>::type, \
+ typename packet_traits<Scalar>::half, \
+ typename unpacket_traits<typename packet_traits<Scalar>::half>::half>::type \
+ ScalarPacket
/* Vectorization logic
* real*real: unpack rhs to constant packets, ...
@@ -347,7 +414,7 @@
* cplx*real : unpack rhs to constant packets, ...
* real*cplx : load lhs as (a0,a0,a1,a1), and mul as usual
*/
-template<typename _LhsScalar, typename _RhsScalar, bool _ConjLhs, bool _ConjRhs>
+template<typename _LhsScalar, typename _RhsScalar, bool _ConjLhs, bool _ConjRhs, int Arch, int _PacketSize>
class gebp_traits
{
public:
@@ -355,13 +422,17 @@
typedef _RhsScalar RhsScalar;
typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
+ PACKET_DECL_COND_PREFIX(_, Lhs, _PacketSize);
+ PACKET_DECL_COND_PREFIX(_, Rhs, _PacketSize);
+ PACKET_DECL_COND_PREFIX(_, Res, _PacketSize);
+
enum {
ConjLhs = _ConjLhs,
ConjRhs = _ConjRhs,
- Vectorizable = packet_traits<LhsScalar>::Vectorizable && packet_traits<RhsScalar>::Vectorizable,
- LhsPacketSize = Vectorizable ? packet_traits<LhsScalar>::size : 1,
- RhsPacketSize = Vectorizable ? packet_traits<RhsScalar>::size : 1,
- ResPacketSize = Vectorizable ? packet_traits<ResScalar>::size : 1,
+ Vectorizable = unpacket_traits<_LhsPacket>::vectorizable && unpacket_traits<_RhsPacket>::vectorizable,
+ LhsPacketSize = Vectorizable ? unpacket_traits<_LhsPacket>::size : 1,
+ RhsPacketSize = Vectorizable ? unpacket_traits<_RhsPacket>::size : 1,
+ ResPacketSize = Vectorizable ? unpacket_traits<_ResPacket>::size : 1,
NumberOfRegisters = EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS,
@@ -370,10 +441,12 @@
// register block size along the M direction (currently, this one cannot be modified)
default_mr = (EIGEN_PLAIN_ENUM_MIN(16,NumberOfRegisters)/2/nr)*LhsPacketSize,
-#if defined(EIGEN_HAS_SINGLE_INSTRUCTION_MADD) && !defined(EIGEN_VECTORIZE_ALTIVEC) && !defined(EIGEN_VECTORIZE_VSX)
- // we assume 16 registers
+#if defined(EIGEN_HAS_SINGLE_INSTRUCTION_MADD) && !defined(EIGEN_VECTORIZE_ALTIVEC) && !defined(EIGEN_VECTORIZE_VSX) \
+ && ((!EIGEN_COMP_MSVC) || (EIGEN_COMP_MSVC>=1914))
+ // we assume 16 registers or more
// See bug 992, if the scalar type is not vectorizable but that EIGEN_HAS_SINGLE_INSTRUCTION_MADD is defined,
// then using 3*LhsPacketSize triggers non-implemented paths in syrk.
+ // Bug 1515: MSVC prior to v19.14 yields to register spilling.
mr = Vectorizable ? 3*LhsPacketSize : default_mr,
#else
mr = default_mr,
@@ -383,37 +456,41 @@
RhsProgress = 1
};
- typedef typename packet_traits<LhsScalar>::type _LhsPacket;
- typedef typename packet_traits<RhsScalar>::type _RhsPacket;
- typedef typename packet_traits<ResScalar>::type _ResPacket;
typedef typename conditional<Vectorizable,_LhsPacket,LhsScalar>::type LhsPacket;
typedef typename conditional<Vectorizable,_RhsPacket,RhsScalar>::type RhsPacket;
typedef typename conditional<Vectorizable,_ResPacket,ResScalar>::type ResPacket;
+ typedef LhsPacket LhsPacket4Packing;
+ typedef QuadPacket<RhsPacket> RhsPacketx4;
typedef ResPacket AccPacket;
EIGEN_STRONG_INLINE void initAcc(AccPacket& p)
{
p = pset1<ResPacket>(ResScalar(0));
}
-
- EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3)
- {
- pbroadcast4(b, b0, b1, b2, b3);
- }
-
-// EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1)
-// {
-// pbroadcast2(b, b0, b1);
-// }
-
+
template<typename RhsPacketType>
EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacketType& dest) const
{
dest = pset1<RhsPacketType>(*b);
}
-
+
+ EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacketx4& dest) const
+ {
+ pbroadcast4(b, dest.B_0, dest.B1, dest.B2, dest.B3);
+ }
+
+ template<typename RhsPacketType>
+ EIGEN_STRONG_INLINE void updateRhs(const RhsScalar* b, RhsPacketType& dest) const
+ {
+ loadRhs(b, dest);
+ }
+
+ EIGEN_STRONG_INLINE void updateRhs(const RhsScalar*, RhsPacketx4&) const
+ {
+ }
+
EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, RhsPacket& dest) const
{
dest = ploadquad<RhsPacket>(b);
@@ -431,8 +508,8 @@
dest = ploadu<LhsPacketType>(a);
}
- template<typename LhsPacketType, typename RhsPacketType, typename AccPacketType>
- EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacketType& b, AccPacketType& c, AccPacketType& tmp) const
+ template<typename LhsPacketType, typename RhsPacketType, typename AccPacketType, typename LaneIdType>
+ EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacketType& b, AccPacketType& c, RhsPacketType& tmp, const LaneIdType&) const
{
conj_helper<LhsPacketType,RhsPacketType,ConjLhs,ConjRhs> cj;
// It would be a lot cleaner to call pmadd all the time. Unfortunately if we
@@ -447,6 +524,12 @@
#endif
}
+ template<typename LhsPacketType, typename AccPacketType, typename LaneIdType>
+ EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacketx4& b, AccPacketType& c, RhsPacket& tmp, const LaneIdType& lane) const
+ {
+ madd(a, b.get(lane), c, tmp, lane);
+ }
+
EIGEN_STRONG_INLINE void acc(const AccPacket& c, const ResPacket& alpha, ResPacket& r) const
{
r = pmadd(c,alpha,r);
@@ -460,21 +543,25 @@
};
-template<typename RealScalar, bool _ConjLhs>
-class gebp_traits<std::complex<RealScalar>, RealScalar, _ConjLhs, false>
+template<typename RealScalar, bool _ConjLhs, int Arch, int _PacketSize>
+class gebp_traits<std::complex<RealScalar>, RealScalar, _ConjLhs, false, Arch, _PacketSize>
{
public:
typedef std::complex<RealScalar> LhsScalar;
typedef RealScalar RhsScalar;
typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
+ PACKET_DECL_COND_PREFIX(_, Lhs, _PacketSize);
+ PACKET_DECL_COND_PREFIX(_, Rhs, _PacketSize);
+ PACKET_DECL_COND_PREFIX(_, Res, _PacketSize);
+
enum {
ConjLhs = _ConjLhs,
ConjRhs = false,
- Vectorizable = packet_traits<LhsScalar>::Vectorizable && packet_traits<RhsScalar>::Vectorizable,
- LhsPacketSize = Vectorizable ? packet_traits<LhsScalar>::size : 1,
- RhsPacketSize = Vectorizable ? packet_traits<RhsScalar>::size : 1,
- ResPacketSize = Vectorizable ? packet_traits<ResScalar>::size : 1,
+ Vectorizable = unpacket_traits<_LhsPacket>::vectorizable && unpacket_traits<_RhsPacket>::vectorizable,
+ LhsPacketSize = Vectorizable ? unpacket_traits<_LhsPacket>::size : 1,
+ RhsPacketSize = Vectorizable ? unpacket_traits<_RhsPacket>::size : 1,
+ ResPacketSize = Vectorizable ? unpacket_traits<_ResPacket>::size : 1,
NumberOfRegisters = EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS,
nr = 4,
@@ -489,13 +576,12 @@
RhsProgress = 1
};
- typedef typename packet_traits<LhsScalar>::type _LhsPacket;
- typedef typename packet_traits<RhsScalar>::type _RhsPacket;
- typedef typename packet_traits<ResScalar>::type _ResPacket;
-
typedef typename conditional<Vectorizable,_LhsPacket,LhsScalar>::type LhsPacket;
typedef typename conditional<Vectorizable,_RhsPacket,RhsScalar>::type RhsPacket;
typedef typename conditional<Vectorizable,_ResPacket,ResScalar>::type ResPacket;
+ typedef LhsPacket LhsPacket4Packing;
+
+ typedef QuadPacket<RhsPacket> RhsPacketx4;
typedef ResPacket AccPacket;
@@ -504,13 +590,42 @@
p = pset1<ResPacket>(ResScalar(0));
}
- EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacket& dest) const
+ template<typename RhsPacketType>
+ EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacketType& dest) const
{
- dest = pset1<RhsPacket>(*b);
+ dest = pset1<RhsPacketType>(*b);
}
+
+ EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacketx4& dest) const
+ {
+ pbroadcast4(b, dest.B_0, dest.B1, dest.B2, dest.B3);
+ }
+
+ template<typename RhsPacketType>
+ EIGEN_STRONG_INLINE void updateRhs(const RhsScalar* b, RhsPacketType& dest) const
+ {
+ loadRhs(b, dest);
+ }
+
+ EIGEN_STRONG_INLINE void updateRhs(const RhsScalar*, RhsPacketx4&) const
+ {}
EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, RhsPacket& dest) const
{
+ loadRhsQuad_impl(b,dest, typename conditional<RhsPacketSize==16,true_type,false_type>::type());
+ }
+
+ EIGEN_STRONG_INLINE void loadRhsQuad_impl(const RhsScalar* b, RhsPacket& dest, const true_type&) const
+ {
+ // FIXME we can do better!
+ // what we want here is a ploadheight
+ RhsScalar tmp[4] = {b[0],b[0],b[1],b[1]};
+ dest = ploadquad<RhsPacket>(tmp);
+ }
+
+ EIGEN_STRONG_INLINE void loadRhsQuad_impl(const RhsScalar* b, RhsPacket& dest, const false_type&) const
+ {
+ eigen_internal_assert(RhsPacketSize<=8);
dest = pset1<RhsPacket>(*b);
}
@@ -519,27 +634,20 @@
dest = pload<LhsPacket>(a);
}
- EIGEN_STRONG_INLINE void loadLhsUnaligned(const LhsScalar* a, LhsPacket& dest) const
+ template<typename LhsPacketType>
+ EIGEN_STRONG_INLINE void loadLhsUnaligned(const LhsScalar* a, LhsPacketType& dest) const
{
- dest = ploadu<LhsPacket>(a);
+ dest = ploadu<LhsPacketType>(a);
}
- EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3)
- {
- pbroadcast4(b, b0, b1, b2, b3);
- }
-
-// EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1)
-// {
-// pbroadcast2(b, b0, b1);
-// }
-
- EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& tmp) const
+ template <typename LhsPacketType, typename RhsPacketType, typename AccPacketType, typename LaneIdType>
+ EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacketType& b, AccPacketType& c, RhsPacketType& tmp, const LaneIdType&) const
{
madd_impl(a, b, c, tmp, typename conditional<Vectorizable,true_type,false_type>::type());
}
- EIGEN_STRONG_INLINE void madd_impl(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& tmp, const true_type&) const
+ template <typename LhsPacketType, typename RhsPacketType, typename AccPacketType>
+ EIGEN_STRONG_INLINE void madd_impl(const LhsPacketType& a, const RhsPacketType& b, AccPacketType& c, RhsPacketType& tmp, const true_type&) const
{
#ifdef EIGEN_HAS_SINGLE_INSTRUCTION_MADD
EIGEN_UNUSED_VARIABLE(tmp);
@@ -554,13 +662,20 @@
c += a * b;
}
- EIGEN_STRONG_INLINE void acc(const AccPacket& c, const ResPacket& alpha, ResPacket& r) const
+ template<typename LhsPacketType, typename AccPacketType, typename LaneIdType>
+ EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacketx4& b, AccPacketType& c, RhsPacket& tmp, const LaneIdType& lane) const
{
+ madd(a, b.get(lane), c, tmp, lane);
+ }
+
+ template <typename ResPacketType, typename AccPacketType>
+ EIGEN_STRONG_INLINE void acc(const AccPacketType& c, const ResPacketType& alpha, ResPacketType& r) const
+ {
+ conj_helper<ResPacketType,ResPacketType,ConjLhs,false> cj;
r = cj.pmadd(c,alpha,r);
}
protected:
- conj_helper<ResPacket,ResPacket,ConjLhs,false> cj;
};
template<typename Packet>
@@ -579,13 +694,57 @@
return res;
}
+// note that for DoublePacket<RealPacket> the "4" in "downto4"
+// corresponds to the number of complexes, so it means "8"
+// it terms of real coefficients.
+
template<typename Packet>
-const DoublePacket<Packet>& predux_downto4(const DoublePacket<Packet> &a)
+const DoublePacket<Packet>&
+predux_half_dowto4(const DoublePacket<Packet> &a,
+ typename enable_if<unpacket_traits<Packet>::size<=8>::type* = 0)
{
return a;
}
-template<typename Packet> struct unpacket_traits<DoublePacket<Packet> > { typedef DoublePacket<Packet> half; };
+template<typename Packet>
+DoublePacket<typename unpacket_traits<Packet>::half>
+predux_half_dowto4(const DoublePacket<Packet> &a,
+ typename enable_if<unpacket_traits<Packet>::size==16>::type* = 0)
+{
+ // yes, that's pretty hackish :(
+ DoublePacket<typename unpacket_traits<Packet>::half> res;
+ typedef std::complex<typename unpacket_traits<Packet>::type> Cplx;
+ typedef typename packet_traits<Cplx>::type CplxPacket;
+ res.first = predux_half_dowto4(CplxPacket(a.first)).v;
+ res.second = predux_half_dowto4(CplxPacket(a.second)).v;
+ return res;
+}
+
+// same here, "quad" actually means "8" in terms of real coefficients
+template<typename Scalar, typename RealPacket>
+void loadQuadToDoublePacket(const Scalar* b, DoublePacket<RealPacket>& dest,
+ typename enable_if<unpacket_traits<RealPacket>::size<=8>::type* = 0)
+{
+ dest.first = pset1<RealPacket>(numext::real(*b));
+ dest.second = pset1<RealPacket>(numext::imag(*b));
+}
+
+template<typename Scalar, typename RealPacket>
+void loadQuadToDoublePacket(const Scalar* b, DoublePacket<RealPacket>& dest,
+ typename enable_if<unpacket_traits<RealPacket>::size==16>::type* = 0)
+{
+ // yes, that's pretty hackish too :(
+ typedef typename NumTraits<Scalar>::Real RealScalar;
+ RealScalar r[4] = {numext::real(b[0]), numext::real(b[0]), numext::real(b[1]), numext::real(b[1])};
+ RealScalar i[4] = {numext::imag(b[0]), numext::imag(b[0]), numext::imag(b[1]), numext::imag(b[1])};
+ dest.first = ploadquad<RealPacket>(r);
+ dest.second = ploadquad<RealPacket>(i);
+}
+
+
+template<typename Packet> struct unpacket_traits<DoublePacket<Packet> > {
+ typedef DoublePacket<typename unpacket_traits<Packet>::half> half;
+};
// template<typename Packet>
// DoublePacket<Packet> pmadd(const DoublePacket<Packet> &a, const DoublePacket<Packet> &b)
// {
@@ -595,8 +754,8 @@
// return res;
// }
-template<typename RealScalar, bool _ConjLhs, bool _ConjRhs>
-class gebp_traits<std::complex<RealScalar>, std::complex<RealScalar>, _ConjLhs, _ConjRhs >
+template<typename RealScalar, bool _ConjLhs, bool _ConjRhs, int Arch, int _PacketSize>
+class gebp_traits<std::complex<RealScalar>, std::complex<RealScalar>, _ConjLhs, _ConjRhs, Arch, _PacketSize >
{
public:
typedef std::complex<RealScalar> Scalar;
@@ -604,15 +763,21 @@
typedef std::complex<RealScalar> RhsScalar;
typedef std::complex<RealScalar> ResScalar;
+ PACKET_DECL_COND_PREFIX(_, Lhs, _PacketSize);
+ PACKET_DECL_COND_PREFIX(_, Rhs, _PacketSize);
+ PACKET_DECL_COND_PREFIX(_, Res, _PacketSize);
+ PACKET_DECL_COND(Real, _PacketSize);
+ PACKET_DECL_COND_SCALAR(_PacketSize);
+
enum {
ConjLhs = _ConjLhs,
ConjRhs = _ConjRhs,
- Vectorizable = packet_traits<RealScalar>::Vectorizable
- && packet_traits<Scalar>::Vectorizable,
- RealPacketSize = Vectorizable ? packet_traits<RealScalar>::size : 1,
- ResPacketSize = Vectorizable ? packet_traits<ResScalar>::size : 1,
- LhsPacketSize = Vectorizable ? packet_traits<LhsScalar>::size : 1,
- RhsPacketSize = Vectorizable ? packet_traits<RhsScalar>::size : 1,
+ Vectorizable = unpacket_traits<RealPacket>::vectorizable
+ && unpacket_traits<ScalarPacket>::vectorizable,
+ ResPacketSize = Vectorizable ? unpacket_traits<_ResPacket>::size : 1,
+ LhsPacketSize = Vectorizable ? unpacket_traits<_LhsPacket>::size : 1,
+ RhsPacketSize = Vectorizable ? unpacket_traits<RhsScalar>::size : 1,
+ RealPacketSize = Vectorizable ? unpacket_traits<RealPacket>::size : 1,
// FIXME: should depend on NumberOfRegisters
nr = 4,
@@ -622,14 +787,16 @@
RhsProgress = 1
};
- typedef typename packet_traits<RealScalar>::type RealPacket;
- typedef typename packet_traits<Scalar>::type ScalarPacket;
- typedef DoublePacket<RealPacket> DoublePacketType;
+ typedef DoublePacket<RealPacket> DoublePacketType;
+ typedef typename conditional<Vectorizable,ScalarPacket,Scalar>::type LhsPacket4Packing;
typedef typename conditional<Vectorizable,RealPacket, Scalar>::type LhsPacket;
typedef typename conditional<Vectorizable,DoublePacketType,Scalar>::type RhsPacket;
typedef typename conditional<Vectorizable,ScalarPacket,Scalar>::type ResPacket;
typedef typename conditional<Vectorizable,DoublePacketType,Scalar>::type AccPacket;
+
+ // this actualy holds 8 packets!
+ typedef QuadPacket<RhsPacket> RhsPacketx4;
EIGEN_STRONG_INLINE void initAcc(Scalar& p) { p = Scalar(0); }
@@ -640,17 +807,41 @@
}
// Scalar path
- EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, ResPacket& dest) const
+ EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, ScalarPacket& dest) const
{
- dest = pset1<ResPacket>(*b);
+ dest = pset1<ScalarPacket>(*b);
}
// Vectorized path
- EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, DoublePacketType& dest) const
+ template<typename RealPacketType>
+ EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, DoublePacket<RealPacketType>& dest) const
{
- dest.first = pset1<RealPacket>(real(*b));
- dest.second = pset1<RealPacket>(imag(*b));
+ dest.first = pset1<RealPacketType>(numext::real(*b));
+ dest.second = pset1<RealPacketType>(numext::imag(*b));
}
+
+ EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacketx4& dest) const
+ {
+ loadRhs(b, dest.B_0);
+ loadRhs(b + 1, dest.B1);
+ loadRhs(b + 2, dest.B2);
+ loadRhs(b + 3, dest.B3);
+ }
+
+ // Scalar path
+ EIGEN_STRONG_INLINE void updateRhs(const RhsScalar* b, ScalarPacket& dest) const
+ {
+ loadRhs(b, dest);
+ }
+
+ // Vectorized path
+ template<typename RealPacketType>
+ EIGEN_STRONG_INLINE void updateRhs(const RhsScalar* b, DoublePacket<RealPacketType>& dest) const
+ {
+ loadRhs(b, dest);
+ }
+
+ EIGEN_STRONG_INLINE void updateRhs(const RhsScalar*, RhsPacketx4&) const {}
EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, ResPacket& dest) const
{
@@ -658,33 +849,7 @@
}
EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, DoublePacketType& dest) const
{
- eigen_internal_assert(unpacket_traits<ScalarPacket>::size<=4);
- loadRhs(b,dest);
- }
-
- EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3)
- {
- // FIXME not sure that's the best way to implement it!
- loadRhs(b+0, b0);
- loadRhs(b+1, b1);
- loadRhs(b+2, b2);
- loadRhs(b+3, b3);
- }
-
- // Vectorized path
- EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, DoublePacketType& b0, DoublePacketType& b1)
- {
- // FIXME not sure that's the best way to implement it!
- loadRhs(b+0, b0);
- loadRhs(b+1, b1);
- }
-
- // Scalar path
- EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsScalar& b0, RhsScalar& b1)
- {
- // FIXME not sure that's the best way to implement it!
- loadRhs(b+0, b0);
- loadRhs(b+1, b1);
+ loadQuadToDoublePacket(b,dest);
}
// nothing special here
@@ -693,47 +858,59 @@
dest = pload<LhsPacket>((const typename unpacket_traits<LhsPacket>::type*)(a));
}
- EIGEN_STRONG_INLINE void loadLhsUnaligned(const LhsScalar* a, LhsPacket& dest) const
+ template<typename LhsPacketType>
+ EIGEN_STRONG_INLINE void loadLhsUnaligned(const LhsScalar* a, LhsPacketType& dest) const
{
- dest = ploadu<LhsPacket>((const typename unpacket_traits<LhsPacket>::type*)(a));
+ dest = ploadu<LhsPacketType>((const typename unpacket_traits<LhsPacketType>::type*)(a));
}
- EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, DoublePacketType& c, RhsPacket& /*tmp*/) const
+ template<typename LhsPacketType, typename RhsPacketType, typename ResPacketType, typename TmpType, typename LaneIdType>
+ EIGEN_STRONG_INLINE
+ typename enable_if<!is_same<RhsPacketType,RhsPacketx4>::value>::type
+ madd(const LhsPacketType& a, const RhsPacketType& b, DoublePacket<ResPacketType>& c, TmpType& /*tmp*/, const LaneIdType&) const
{
c.first = padd(pmul(a,b.first), c.first);
c.second = padd(pmul(a,b.second),c.second);
}
- EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, ResPacket& c, RhsPacket& /*tmp*/) const
+ template<typename LaneIdType>
+ EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, ResPacket& c, RhsPacket& /*tmp*/, const LaneIdType&) const
{
c = cj.pmadd(a,b,c);
}
+
+ template<typename LhsPacketType, typename AccPacketType, typename LaneIdType>
+ EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacketx4& b, AccPacketType& c, RhsPacket& tmp, const LaneIdType& lane) const
+ {
+ madd(a, b.get(lane), c, tmp, lane);
+ }
EIGEN_STRONG_INLINE void acc(const Scalar& c, const Scalar& alpha, Scalar& r) const { r += alpha * c; }
- EIGEN_STRONG_INLINE void acc(const DoublePacketType& c, const ResPacket& alpha, ResPacket& r) const
+ template<typename RealPacketType, typename ResPacketType>
+ EIGEN_STRONG_INLINE void acc(const DoublePacket<RealPacketType>& c, const ResPacketType& alpha, ResPacketType& r) const
{
// assemble c
- ResPacket tmp;
+ ResPacketType tmp;
if((!ConjLhs)&&(!ConjRhs))
{
- tmp = pcplxflip(pconj(ResPacket(c.second)));
- tmp = padd(ResPacket(c.first),tmp);
+ tmp = pcplxflip(pconj(ResPacketType(c.second)));
+ tmp = padd(ResPacketType(c.first),tmp);
}
else if((!ConjLhs)&&(ConjRhs))
{
- tmp = pconj(pcplxflip(ResPacket(c.second)));
- tmp = padd(ResPacket(c.first),tmp);
+ tmp = pconj(pcplxflip(ResPacketType(c.second)));
+ tmp = padd(ResPacketType(c.first),tmp);
}
else if((ConjLhs)&&(!ConjRhs))
{
- tmp = pcplxflip(ResPacket(c.second));
- tmp = padd(pconj(ResPacket(c.first)),tmp);
+ tmp = pcplxflip(ResPacketType(c.second));
+ tmp = padd(pconj(ResPacketType(c.first)),tmp);
}
else if((ConjLhs)&&(ConjRhs))
{
- tmp = pcplxflip(ResPacket(c.second));
- tmp = psub(pconj(ResPacket(c.first)),tmp);
+ tmp = pcplxflip(ResPacketType(c.second));
+ tmp = psub(pconj(ResPacketType(c.first)),tmp);
}
r = pmadd(tmp,alpha,r);
@@ -743,8 +920,8 @@
conj_helper<LhsScalar,RhsScalar,ConjLhs,ConjRhs> cj;
};
-template<typename RealScalar, bool _ConjRhs>
-class gebp_traits<RealScalar, std::complex<RealScalar>, false, _ConjRhs >
+template<typename RealScalar, bool _ConjRhs, int Arch, int _PacketSize>
+class gebp_traits<RealScalar, std::complex<RealScalar>, false, _ConjRhs, Arch, _PacketSize >
{
public:
typedef std::complex<RealScalar> Scalar;
@@ -752,14 +929,25 @@
typedef Scalar RhsScalar;
typedef Scalar ResScalar;
+ PACKET_DECL_COND_PREFIX(_, Lhs, _PacketSize);
+ PACKET_DECL_COND_PREFIX(_, Rhs, _PacketSize);
+ PACKET_DECL_COND_PREFIX(_, Res, _PacketSize);
+ PACKET_DECL_COND_PREFIX(_, Real, _PacketSize);
+ PACKET_DECL_COND_SCALAR_PREFIX(_, _PacketSize);
+
+#undef PACKET_DECL_COND_SCALAR_PREFIX
+#undef PACKET_DECL_COND_PREFIX
+#undef PACKET_DECL_COND_SCALAR
+#undef PACKET_DECL_COND
+
enum {
ConjLhs = false,
ConjRhs = _ConjRhs,
- Vectorizable = packet_traits<RealScalar>::Vectorizable
- && packet_traits<Scalar>::Vectorizable,
- LhsPacketSize = Vectorizable ? packet_traits<LhsScalar>::size : 1,
- RhsPacketSize = Vectorizable ? packet_traits<RhsScalar>::size : 1,
- ResPacketSize = Vectorizable ? packet_traits<ResScalar>::size : 1,
+ Vectorizable = unpacket_traits<_RealPacket>::vectorizable
+ && unpacket_traits<_ScalarPacket>::vectorizable,
+ LhsPacketSize = Vectorizable ? unpacket_traits<_LhsPacket>::size : 1,
+ RhsPacketSize = Vectorizable ? unpacket_traits<_RhsPacket>::size : 1,
+ ResPacketSize = Vectorizable ? unpacket_traits<_ResPacket>::size : 1,
NumberOfRegisters = EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS,
// FIXME: should depend on NumberOfRegisters
@@ -770,14 +958,11 @@
RhsProgress = 1
};
- typedef typename packet_traits<LhsScalar>::type _LhsPacket;
- typedef typename packet_traits<RhsScalar>::type _RhsPacket;
- typedef typename packet_traits<ResScalar>::type _ResPacket;
-
typedef typename conditional<Vectorizable,_LhsPacket,LhsScalar>::type LhsPacket;
typedef typename conditional<Vectorizable,_RhsPacket,RhsScalar>::type RhsPacket;
typedef typename conditional<Vectorizable,_ResPacket,ResScalar>::type ResPacket;
-
+ typedef LhsPacket LhsPacket4Packing;
+ typedef QuadPacket<RhsPacket> RhsPacketx4;
typedef ResPacket AccPacket;
EIGEN_STRONG_INLINE void initAcc(AccPacket& p)
@@ -785,22 +970,25 @@
p = pset1<ResPacket>(ResScalar(0));
}
- EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacket& dest) const
+ template<typename RhsPacketType>
+ EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacketType& dest) const
{
- dest = pset1<RhsPacket>(*b);
+ dest = pset1<RhsPacketType>(*b);
}
-
- void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3)
+
+ EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacketx4& dest) const
{
- pbroadcast4(b, b0, b1, b2, b3);
+ pbroadcast4(b, dest.B_0, dest.B1, dest.B2, dest.B3);
}
-
-// EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1)
-// {
-// // FIXME not sure that's the best way to implement it!
-// b0 = pload1<RhsPacket>(b+0);
-// b1 = pload1<RhsPacket>(b+1);
-// }
+
+ template<typename RhsPacketType>
+ EIGEN_STRONG_INLINE void updateRhs(const RhsScalar* b, RhsPacketType& dest) const
+ {
+ loadRhs(b, dest);
+ }
+
+ EIGEN_STRONG_INLINE void updateRhs(const RhsScalar*, RhsPacketx4&) const
+ {}
EIGEN_STRONG_INLINE void loadLhs(const LhsScalar* a, LhsPacket& dest) const
{
@@ -809,21 +997,23 @@
EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, RhsPacket& dest) const
{
- eigen_internal_assert(unpacket_traits<RhsPacket>::size<=4);
- loadRhs(b,dest);
+ dest = ploadquad<RhsPacket>(b);
}
- EIGEN_STRONG_INLINE void loadLhsUnaligned(const LhsScalar* a, LhsPacket& dest) const
+ template<typename LhsPacketType>
+ EIGEN_STRONG_INLINE void loadLhsUnaligned(const LhsScalar* a, LhsPacketType& dest) const
{
- dest = ploaddup<LhsPacket>(a);
+ dest = ploaddup<LhsPacketType>(a);
}
- EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& tmp) const
+ template <typename LhsPacketType, typename RhsPacketType, typename AccPacketType, typename LaneIdType>
+ EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacketType& b, AccPacketType& c, RhsPacketType& tmp, const LaneIdType&) const
{
madd_impl(a, b, c, tmp, typename conditional<Vectorizable,true_type,false_type>::type());
}
- EIGEN_STRONG_INLINE void madd_impl(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& tmp, const true_type&) const
+ template <typename LhsPacketType, typename RhsPacketType, typename AccPacketType>
+ EIGEN_STRONG_INLINE void madd_impl(const LhsPacketType& a, const RhsPacketType& b, AccPacketType& c, RhsPacketType& tmp, const true_type&) const
{
#ifdef EIGEN_HAS_SINGLE_INSTRUCTION_MADD
EIGEN_UNUSED_VARIABLE(tmp);
@@ -839,16 +1029,24 @@
c += a * b;
}
- EIGEN_STRONG_INLINE void acc(const AccPacket& c, const ResPacket& alpha, ResPacket& r) const
+ template<typename LhsPacketType, typename AccPacketType, typename LaneIdType>
+ EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacketx4& b, AccPacketType& c, RhsPacket& tmp, const LaneIdType& lane) const
{
+ madd(a, b.get(lane), c, tmp, lane);
+ }
+
+ template <typename ResPacketType, typename AccPacketType>
+ EIGEN_STRONG_INLINE void acc(const AccPacketType& c, const ResPacketType& alpha, ResPacketType& r) const
+ {
+ conj_helper<ResPacketType,ResPacketType,false,ConjRhs> cj;
r = cj.pmadd(alpha,c,r);
}
protected:
- conj_helper<ResPacket,ResPacket,false,ConjRhs> cj;
+
};
-/* optimized GEneral packed Block * packed Panel product kernel
+/* optimized General packed Block * packed Panel product kernel
*
* Mixing type logic: C += A * B
* | A | B | comments
@@ -858,26 +1056,47 @@
template<typename LhsScalar, typename RhsScalar, typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel
{
- typedef gebp_traits<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs> Traits;
+ typedef gebp_traits<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs,Architecture::Target> Traits;
+ typedef gebp_traits<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs,Architecture::Target,GEBPPacketHalf> HalfTraits;
+ typedef gebp_traits<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs,Architecture::Target,GEBPPacketQuarter> QuarterTraits;
+
typedef typename Traits::ResScalar ResScalar;
typedef typename Traits::LhsPacket LhsPacket;
typedef typename Traits::RhsPacket RhsPacket;
typedef typename Traits::ResPacket ResPacket;
typedef typename Traits::AccPacket AccPacket;
+ typedef typename Traits::RhsPacketx4 RhsPacketx4;
- typedef gebp_traits<RhsScalar,LhsScalar,ConjugateRhs,ConjugateLhs> SwappedTraits;
+ typedef typename RhsPanelHelper<RhsPacket, RhsPacketx4, 15>::type RhsPanel15;
+
+ typedef gebp_traits<RhsScalar,LhsScalar,ConjugateRhs,ConjugateLhs,Architecture::Target> SwappedTraits;
+
typedef typename SwappedTraits::ResScalar SResScalar;
typedef typename SwappedTraits::LhsPacket SLhsPacket;
typedef typename SwappedTraits::RhsPacket SRhsPacket;
typedef typename SwappedTraits::ResPacket SResPacket;
typedef typename SwappedTraits::AccPacket SAccPacket;
+ typedef typename HalfTraits::LhsPacket LhsPacketHalf;
+ typedef typename HalfTraits::RhsPacket RhsPacketHalf;
+ typedef typename HalfTraits::ResPacket ResPacketHalf;
+ typedef typename HalfTraits::AccPacket AccPacketHalf;
+
+ typedef typename QuarterTraits::LhsPacket LhsPacketQuarter;
+ typedef typename QuarterTraits::RhsPacket RhsPacketQuarter;
+ typedef typename QuarterTraits::ResPacket ResPacketQuarter;
+ typedef typename QuarterTraits::AccPacket AccPacketQuarter;
+
typedef typename DataMapper::LinearMapper LinearMapper;
enum {
Vectorizable = Traits::Vectorizable,
LhsProgress = Traits::LhsProgress,
+ LhsProgressHalf = HalfTraits::LhsProgress,
+ LhsProgressQuarter = QuarterTraits::LhsProgress,
RhsProgress = Traits::RhsProgress,
+ RhsProgressHalf = HalfTraits::RhsProgress,
+ RhsProgressQuarter = QuarterTraits::RhsProgress,
ResPacketSize = Traits::ResPacketSize
};
@@ -887,6 +1106,299 @@
Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
};
+template<typename LhsScalar, typename RhsScalar, typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs,
+int SwappedLhsProgress = gebp_traits<RhsScalar,LhsScalar,ConjugateRhs,ConjugateLhs,Architecture::Target>::LhsProgress>
+struct last_row_process_16_packets
+{
+ typedef gebp_traits<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs,Architecture::Target> Traits;
+ typedef gebp_traits<RhsScalar,LhsScalar,ConjugateRhs,ConjugateLhs,Architecture::Target> SwappedTraits;
+
+ typedef typename Traits::ResScalar ResScalar;
+ typedef typename SwappedTraits::LhsPacket SLhsPacket;
+ typedef typename SwappedTraits::RhsPacket SRhsPacket;
+ typedef typename SwappedTraits::ResPacket SResPacket;
+ typedef typename SwappedTraits::AccPacket SAccPacket;
+
+ EIGEN_STRONG_INLINE void operator()(const DataMapper& res, SwappedTraits &straits, const LhsScalar* blA,
+ const RhsScalar* blB, Index depth, const Index endk, Index i, Index j2,
+ ResScalar alpha, SAccPacket &C0)
+ {
+ EIGEN_UNUSED_VARIABLE(res);
+ EIGEN_UNUSED_VARIABLE(straits);
+ EIGEN_UNUSED_VARIABLE(blA);
+ EIGEN_UNUSED_VARIABLE(blB);
+ EIGEN_UNUSED_VARIABLE(depth);
+ EIGEN_UNUSED_VARIABLE(endk);
+ EIGEN_UNUSED_VARIABLE(i);
+ EIGEN_UNUSED_VARIABLE(j2);
+ EIGEN_UNUSED_VARIABLE(alpha);
+ EIGEN_UNUSED_VARIABLE(C0);
+ }
+};
+
+
+template<typename LhsScalar, typename RhsScalar, typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+struct last_row_process_16_packets<LhsScalar, RhsScalar, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs, 16> {
+ typedef gebp_traits<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs,Architecture::Target> Traits;
+ typedef gebp_traits<RhsScalar,LhsScalar,ConjugateRhs,ConjugateLhs,Architecture::Target> SwappedTraits;
+
+ typedef typename Traits::ResScalar ResScalar;
+ typedef typename SwappedTraits::LhsPacket SLhsPacket;
+ typedef typename SwappedTraits::RhsPacket SRhsPacket;
+ typedef typename SwappedTraits::ResPacket SResPacket;
+ typedef typename SwappedTraits::AccPacket SAccPacket;
+
+ EIGEN_STRONG_INLINE void operator()(const DataMapper& res, SwappedTraits &straits, const LhsScalar* blA,
+ const RhsScalar* blB, Index depth, const Index endk, Index i, Index j2,
+ ResScalar alpha, SAccPacket &C0)
+ {
+ typedef typename unpacket_traits<typename unpacket_traits<SResPacket>::half>::half SResPacketQuarter;
+ typedef typename unpacket_traits<typename unpacket_traits<SLhsPacket>::half>::half SLhsPacketQuarter;
+ typedef typename unpacket_traits<typename unpacket_traits<SRhsPacket>::half>::half SRhsPacketQuarter;
+ typedef typename unpacket_traits<typename unpacket_traits<SAccPacket>::half>::half SAccPacketQuarter;
+
+ SResPacketQuarter R = res.template gatherPacket<SResPacketQuarter>(i, j2);
+ SResPacketQuarter alphav = pset1<SResPacketQuarter>(alpha);
+
+ if (depth - endk > 0)
+ {
+ // We have to handle the last row(s) of the rhs, which
+ // correspond to a half-packet
+ SAccPacketQuarter c0 = predux_half_dowto4(predux_half_dowto4(C0));
+
+ for (Index kk = endk; kk < depth; kk++)
+ {
+ SLhsPacketQuarter a0;
+ SRhsPacketQuarter b0;
+ straits.loadLhsUnaligned(blB, a0);
+ straits.loadRhs(blA, b0);
+ straits.madd(a0,b0,c0,b0, fix<0>);
+ blB += SwappedTraits::LhsProgress/4;
+ blA += 1;
+ }
+ straits.acc(c0, alphav, R);
+ }
+ else
+ {
+ straits.acc(predux_half_dowto4(predux_half_dowto4(C0)), alphav, R);
+ }
+ res.scatterPacket(i, j2, R);
+ }
+};
+
+template<int nr, Index LhsProgress, Index RhsProgress, typename LhsScalar, typename RhsScalar, typename ResScalar, typename AccPacket, typename LhsPacket, typename RhsPacket, typename ResPacket, typename GEBPTraits, typename LinearMapper, typename DataMapper>
+struct lhs_process_one_packet
+{
+ typedef typename GEBPTraits::RhsPacketx4 RhsPacketx4;
+
+ EIGEN_STRONG_INLINE void peeled_kc_onestep(Index K, const LhsScalar* blA, const RhsScalar* blB, GEBPTraits traits, LhsPacket *A0, RhsPacketx4 *rhs_panel, RhsPacket *T0, AccPacket *C0, AccPacket *C1, AccPacket *C2, AccPacket *C3)
+ {
+ EIGEN_ASM_COMMENT("begin step of gebp micro kernel 1X4");
+ EIGEN_ASM_COMMENT("Note: these asm comments work around bug 935!");
+ traits.loadLhs(&blA[(0+1*K)*LhsProgress], *A0);
+ traits.loadRhs(&blB[(0+4*K)*RhsProgress], *rhs_panel);
+ traits.madd(*A0, *rhs_panel, *C0, *T0, fix<0>);
+ traits.madd(*A0, *rhs_panel, *C1, *T0, fix<1>);
+ traits.madd(*A0, *rhs_panel, *C2, *T0, fix<2>);
+ traits.madd(*A0, *rhs_panel, *C3, *T0, fix<3>);
+ #if EIGEN_GNUC_AT_LEAST(6,0) && defined(EIGEN_VECTORIZE_SSE)
+ __asm__ ("" : "+x,m" (*A0));
+ #endif
+ EIGEN_ASM_COMMENT("end step of gebp micro kernel 1X4");
+ }
+
+ EIGEN_STRONG_INLINE void operator()(
+ const DataMapper& res, const LhsScalar* blockA, const RhsScalar* blockB, ResScalar alpha,
+ Index peelStart, Index peelEnd, Index strideA, Index strideB, Index offsetA, Index offsetB,
+ int prefetch_res_offset, Index peeled_kc, Index pk, Index cols, Index depth, Index packet_cols4)
+ {
+ GEBPTraits traits;
+
+ // loops on each largest micro horizontal panel of lhs
+ // (LhsProgress x depth)
+ for(Index i=peelStart; i<peelEnd; i+=LhsProgress)
+ {
+ // loops on each largest micro vertical panel of rhs (depth * nr)
+ for(Index j2=0; j2<packet_cols4; j2+=nr)
+ {
+ // We select a LhsProgress x nr micro block of res
+ // which is entirely stored into 1 x nr registers.
+
+ const LhsScalar* blA = &blockA[i*strideA+offsetA*(LhsProgress)];
+ prefetch(&blA[0]);
+
+ // gets res block as register
+ AccPacket C0, C1, C2, C3;
+ traits.initAcc(C0);
+ traits.initAcc(C1);
+ traits.initAcc(C2);
+ traits.initAcc(C3);
+ // To improve instruction pipelining, let's double the accumulation registers:
+ // even k will accumulate in C*, while odd k will accumulate in D*.
+ // This trick is crutial to get good performance with FMA, otherwise it is
+ // actually faster to perform separated MUL+ADD because of a naturally
+ // better instruction-level parallelism.
+ AccPacket D0, D1, D2, D3;
+ traits.initAcc(D0);
+ traits.initAcc(D1);
+ traits.initAcc(D2);
+ traits.initAcc(D3);
+
+ LinearMapper r0 = res.getLinearMapper(i, j2 + 0);
+ LinearMapper r1 = res.getLinearMapper(i, j2 + 1);
+ LinearMapper r2 = res.getLinearMapper(i, j2 + 2);
+ LinearMapper r3 = res.getLinearMapper(i, j2 + 3);
+
+ r0.prefetch(prefetch_res_offset);
+ r1.prefetch(prefetch_res_offset);
+ r2.prefetch(prefetch_res_offset);
+ r3.prefetch(prefetch_res_offset);
+
+ // performs "inner" products
+ const RhsScalar* blB = &blockB[j2*strideB+offsetB*nr];
+ prefetch(&blB[0]);
+ LhsPacket A0, A1;
+
+ for(Index k=0; k<peeled_kc; k+=pk)
+ {
+ EIGEN_ASM_COMMENT("begin gebp micro kernel 1/half/quarterX4");
+ RhsPacketx4 rhs_panel;
+ RhsPacket T0;
+
+ internal::prefetch(blB+(48+0));
+ peeled_kc_onestep(0, blA, blB, traits, &A0, &rhs_panel, &T0, &C0, &C1, &C2, &C3);
+ peeled_kc_onestep(1, blA, blB, traits, &A1, &rhs_panel, &T0, &D0, &D1, &D2, &D3);
+ peeled_kc_onestep(2, blA, blB, traits, &A0, &rhs_panel, &T0, &C0, &C1, &C2, &C3);
+ peeled_kc_onestep(3, blA, blB, traits, &A1, &rhs_panel, &T0, &D0, &D1, &D2, &D3);
+ internal::prefetch(blB+(48+16));
+ peeled_kc_onestep(4, blA, blB, traits, &A0, &rhs_panel, &T0, &C0, &C1, &C2, &C3);
+ peeled_kc_onestep(5, blA, blB, traits, &A1, &rhs_panel, &T0, &D0, &D1, &D2, &D3);
+ peeled_kc_onestep(6, blA, blB, traits, &A0, &rhs_panel, &T0, &C0, &C1, &C2, &C3);
+ peeled_kc_onestep(7, blA, blB, traits, &A1, &rhs_panel, &T0, &D0, &D1, &D2, &D3);
+
+ blB += pk*4*RhsProgress;
+ blA += pk*LhsProgress;
+
+ EIGEN_ASM_COMMENT("end gebp micro kernel 1/half/quarterX4");
+ }
+ C0 = padd(C0,D0);
+ C1 = padd(C1,D1);
+ C2 = padd(C2,D2);
+ C3 = padd(C3,D3);
+
+ // process remaining peeled loop
+ for(Index k=peeled_kc; k<depth; k++)
+ {
+ RhsPacketx4 rhs_panel;
+ RhsPacket T0;
+ peeled_kc_onestep(0, blA, blB, traits, &A0, &rhs_panel, &T0, &C0, &C1, &C2, &C3);
+ blB += 4*RhsProgress;
+ blA += LhsProgress;
+ }
+
+ ResPacket R0, R1;
+ ResPacket alphav = pset1<ResPacket>(alpha);
+
+ R0 = r0.template loadPacket<ResPacket>(0);
+ R1 = r1.template loadPacket<ResPacket>(0);
+ traits.acc(C0, alphav, R0);
+ traits.acc(C1, alphav, R1);
+ r0.storePacket(0, R0);
+ r1.storePacket(0, R1);
+
+ R0 = r2.template loadPacket<ResPacket>(0);
+ R1 = r3.template loadPacket<ResPacket>(0);
+ traits.acc(C2, alphav, R0);
+ traits.acc(C3, alphav, R1);
+ r2.storePacket(0, R0);
+ r3.storePacket(0, R1);
+ }
+
+ // Deal with remaining columns of the rhs
+ for(Index j2=packet_cols4; j2<cols; j2++)
+ {
+ // One column at a time
+ const LhsScalar* blA = &blockA[i*strideA+offsetA*(LhsProgress)];
+ prefetch(&blA[0]);
+
+ // gets res block as register
+ AccPacket C0;
+ traits.initAcc(C0);
+
+ LinearMapper r0 = res.getLinearMapper(i, j2);
+
+ // performs "inner" products
+ const RhsScalar* blB = &blockB[j2*strideB+offsetB];
+ LhsPacket A0;
+
+ for(Index k= 0; k<peeled_kc; k+=pk)
+ {
+ EIGEN_ASM_COMMENT("begin gebp micro kernel 1/half/quarterX1");
+ RhsPacket B_0;
+
+#define EIGEN_GEBGP_ONESTEP(K) \
+ do { \
+ EIGEN_ASM_COMMENT("begin step of gebp micro kernel 1/half/quarterX1"); \
+ EIGEN_ASM_COMMENT("Note: these asm comments work around bug 935!"); \
+ /* FIXME: why unaligned???? */ \
+ traits.loadLhsUnaligned(&blA[(0+1*K)*LhsProgress], A0); \
+ traits.loadRhs(&blB[(0+K)*RhsProgress], B_0); \
+ traits.madd(A0, B_0, C0, B_0, fix<0>); \
+ EIGEN_ASM_COMMENT("end step of gebp micro kernel 1/half/quarterX1"); \
+ } while(false);
+
+ EIGEN_GEBGP_ONESTEP(0);
+ EIGEN_GEBGP_ONESTEP(1);
+ EIGEN_GEBGP_ONESTEP(2);
+ EIGEN_GEBGP_ONESTEP(3);
+ EIGEN_GEBGP_ONESTEP(4);
+ EIGEN_GEBGP_ONESTEP(5);
+ EIGEN_GEBGP_ONESTEP(6);
+ EIGEN_GEBGP_ONESTEP(7);
+
+ blB += pk*RhsProgress;
+ blA += pk*LhsProgress;
+
+ EIGEN_ASM_COMMENT("end gebp micro kernel 1/half/quarterX1");
+ }
+
+ // process remaining peeled loop
+ for(Index k=peeled_kc; k<depth; k++)
+ {
+ RhsPacket B_0;
+ EIGEN_GEBGP_ONESTEP(0);
+ blB += RhsProgress;
+ blA += LhsProgress;
+ }
+#undef EIGEN_GEBGP_ONESTEP
+ ResPacket R0;
+ ResPacket alphav = pset1<ResPacket>(alpha);
+ R0 = r0.template loadPacket<ResPacket>(0);
+ traits.acc(C0, alphav, R0);
+ r0.storePacket(0, R0);
+ }
+ }
+ }
+};
+
+template<int nr, Index LhsProgress, Index RhsProgress, typename LhsScalar, typename RhsScalar, typename ResScalar, typename AccPacket, typename LhsPacket, typename RhsPacket, typename ResPacket, typename GEBPTraits, typename LinearMapper, typename DataMapper>
+struct lhs_process_fraction_of_packet : lhs_process_one_packet<nr, LhsProgress, RhsProgress, LhsScalar, RhsScalar, ResScalar, AccPacket, LhsPacket, RhsPacket, ResPacket, GEBPTraits, LinearMapper, DataMapper>
+{
+
+EIGEN_STRONG_INLINE void peeled_kc_onestep(Index K, const LhsScalar* blA, const RhsScalar* blB, GEBPTraits traits, LhsPacket *A0, RhsPacket *B_0, RhsPacket *B1, RhsPacket *B2, RhsPacket *B3, AccPacket *C0, AccPacket *C1, AccPacket *C2, AccPacket *C3)
+ {
+ EIGEN_ASM_COMMENT("begin step of gebp micro kernel 1X4");
+ EIGEN_ASM_COMMENT("Note: these asm comments work around bug 935!");
+ traits.loadLhsUnaligned(&blA[(0+1*K)*(LhsProgress)], *A0);
+ traits.broadcastRhs(&blB[(0+4*K)*RhsProgress], *B_0, *B1, *B2, *B3);
+ traits.madd(*A0, *B_0, *C0, *B_0);
+ traits.madd(*A0, *B1, *C1, *B1);
+ traits.madd(*A0, *B2, *C2, *B2);
+ traits.madd(*A0, *B3, *C3, *B3);
+ EIGEN_ASM_COMMENT("end step of gebp micro kernel 1X4");
+ }
+};
+
template<typename LhsScalar, typename RhsScalar, typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
EIGEN_DONT_INLINE
void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,ConjugateRhs>
@@ -903,10 +1415,12 @@
Index packet_cols4 = nr>=4 ? (cols/4) * 4 : 0;
const Index peeled_mc3 = mr>=3*Traits::LhsProgress ? (rows/(3*LhsProgress))*(3*LhsProgress) : 0;
const Index peeled_mc2 = mr>=2*Traits::LhsProgress ? peeled_mc3+((rows-peeled_mc3)/(2*LhsProgress))*(2*LhsProgress) : 0;
- const Index peeled_mc1 = mr>=1*Traits::LhsProgress ? (rows/(1*LhsProgress))*(1*LhsProgress) : 0;
+ const Index peeled_mc1 = mr>=1*Traits::LhsProgress ? peeled_mc2+((rows-peeled_mc2)/(1*LhsProgress))*(1*LhsProgress) : 0;
+ const Index peeled_mc_half = mr>=LhsProgressHalf ? peeled_mc1+((rows-peeled_mc1)/(LhsProgressHalf))*(LhsProgressHalf) : 0;
+ const Index peeled_mc_quarter = mr>=LhsProgressQuarter ? peeled_mc_half+((rows-peeled_mc_half)/(LhsProgressQuarter))*(LhsProgressQuarter) : 0;
enum { pk = 8 }; // NOTE Such a large peeling factor is important for large matrices (~ +5% when >1000 on Haswell)
const Index peeled_kc = depth & ~(pk-1);
- const Index prefetch_res_offset = 32/sizeof(ResScalar);
+ const int prefetch_res_offset = 32/sizeof(ResScalar);
// const Index depth2 = depth & ~1;
//---------- Process 3 * LhsProgress rows at once ----------
@@ -964,36 +1478,48 @@
for(Index k=0; k<peeled_kc; k+=pk)
{
EIGEN_ASM_COMMENT("begin gebp micro kernel 3pX4");
- RhsPacket B_0, T0;
+ // 15 registers are taken (12 for acc, 2 for lhs).
+ RhsPanel15 rhs_panel;
+ RhsPacket T0;
LhsPacket A2;
-
-#define EIGEN_GEBP_ONESTEP(K) \
- do { \
- EIGEN_ASM_COMMENT("begin step of gebp micro kernel 3pX4"); \
+ #if EIGEN_COMP_GNUC_STRICT && EIGEN_ARCH_ARM64 && defined(EIGEN_VECTORIZE_NEON) && !(EIGEN_GNUC_AT_LEAST(9,0))
+ // see http://eigen.tuxfamily.org/bz/show_bug.cgi?id=1633
+ // without this workaround A0, A1, and A2 are loaded in the same register,
+ // which is not good for pipelining
+ #define EIGEN_GEBP_3PX4_REGISTER_ALLOC_WORKAROUND __asm__ ("" : "+w,m" (A0), "+w,m" (A1), "+w,m" (A2));
+ #else
+ #define EIGEN_GEBP_3PX4_REGISTER_ALLOC_WORKAROUND
+ #endif
+#define EIGEN_GEBP_ONESTEP(K) \
+ do { \
+ EIGEN_ASM_COMMENT("begin step of gebp micro kernel 3pX4"); \
EIGEN_ASM_COMMENT("Note: these asm comments work around bug 935!"); \
- internal::prefetch(blA+(3*K+16)*LhsProgress); \
- if (EIGEN_ARCH_ARM) { internal::prefetch(blB+(4*K+16)*RhsProgress); } /* Bug 953 */ \
- traits.loadLhs(&blA[(0+3*K)*LhsProgress], A0); \
- traits.loadLhs(&blA[(1+3*K)*LhsProgress], A1); \
- traits.loadLhs(&blA[(2+3*K)*LhsProgress], A2); \
- traits.loadRhs(blB + (0+4*K)*Traits::RhsProgress, B_0); \
- traits.madd(A0, B_0, C0, T0); \
- traits.madd(A1, B_0, C4, T0); \
- traits.madd(A2, B_0, C8, B_0); \
- traits.loadRhs(blB + (1+4*K)*Traits::RhsProgress, B_0); \
- traits.madd(A0, B_0, C1, T0); \
- traits.madd(A1, B_0, C5, T0); \
- traits.madd(A2, B_0, C9, B_0); \
- traits.loadRhs(blB + (2+4*K)*Traits::RhsProgress, B_0); \
- traits.madd(A0, B_0, C2, T0); \
- traits.madd(A1, B_0, C6, T0); \
- traits.madd(A2, B_0, C10, B_0); \
- traits.loadRhs(blB + (3+4*K)*Traits::RhsProgress, B_0); \
- traits.madd(A0, B_0, C3 , T0); \
- traits.madd(A1, B_0, C7, T0); \
- traits.madd(A2, B_0, C11, B_0); \
- EIGEN_ASM_COMMENT("end step of gebp micro kernel 3pX4"); \
- } while(false)
+ internal::prefetch(blA + (3 * K + 16) * LhsProgress); \
+ if (EIGEN_ARCH_ARM || EIGEN_ARCH_MIPS) { \
+ internal::prefetch(blB + (4 * K + 16) * RhsProgress); \
+ } /* Bug 953 */ \
+ traits.loadLhs(&blA[(0 + 3 * K) * LhsProgress], A0); \
+ traits.loadLhs(&blA[(1 + 3 * K) * LhsProgress], A1); \
+ traits.loadLhs(&blA[(2 + 3 * K) * LhsProgress], A2); \
+ EIGEN_GEBP_3PX4_REGISTER_ALLOC_WORKAROUND \
+ traits.loadRhs(blB + (0+4*K) * Traits::RhsProgress, rhs_panel); \
+ traits.madd(A0, rhs_panel, C0, T0, fix<0>); \
+ traits.madd(A1, rhs_panel, C4, T0, fix<0>); \
+ traits.madd(A2, rhs_panel, C8, T0, fix<0>); \
+ traits.updateRhs(blB + (1+4*K) * Traits::RhsProgress, rhs_panel); \
+ traits.madd(A0, rhs_panel, C1, T0, fix<1>); \
+ traits.madd(A1, rhs_panel, C5, T0, fix<1>); \
+ traits.madd(A2, rhs_panel, C9, T0, fix<1>); \
+ traits.updateRhs(blB + (2+4*K) * Traits::RhsProgress, rhs_panel); \
+ traits.madd(A0, rhs_panel, C2, T0, fix<2>); \
+ traits.madd(A1, rhs_panel, C6, T0, fix<2>); \
+ traits.madd(A2, rhs_panel, C10, T0, fix<2>); \
+ traits.updateRhs(blB + (3+4*K) * Traits::RhsProgress, rhs_panel); \
+ traits.madd(A0, rhs_panel, C3, T0, fix<3>); \
+ traits.madd(A1, rhs_panel, C7, T0, fix<3>); \
+ traits.madd(A2, rhs_panel, C11, T0, fix<3>); \
+ EIGEN_ASM_COMMENT("end step of gebp micro kernel 3pX4"); \
+ } while (false)
internal::prefetch(blB);
EIGEN_GEBP_ONESTEP(0);
@@ -1013,7 +1539,8 @@
// process remaining peeled loop
for(Index k=peeled_kc; k<depth; k++)
{
- RhsPacket B_0, T0;
+ RhsPanel15 rhs_panel;
+ RhsPacket T0;
LhsPacket A2;
EIGEN_GEBP_ONESTEP(0);
blB += 4*RhsProgress;
@@ -1025,9 +1552,9 @@
ResPacket R0, R1, R2;
ResPacket alphav = pset1<ResPacket>(alpha);
- R0 = r0.loadPacket(0 * Traits::ResPacketSize);
- R1 = r0.loadPacket(1 * Traits::ResPacketSize);
- R2 = r0.loadPacket(2 * Traits::ResPacketSize);
+ R0 = r0.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
+ R1 = r0.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
+ R2 = r0.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
traits.acc(C0, alphav, R0);
traits.acc(C4, alphav, R1);
traits.acc(C8, alphav, R2);
@@ -1035,9 +1562,9 @@
r0.storePacket(1 * Traits::ResPacketSize, R1);
r0.storePacket(2 * Traits::ResPacketSize, R2);
- R0 = r1.loadPacket(0 * Traits::ResPacketSize);
- R1 = r1.loadPacket(1 * Traits::ResPacketSize);
- R2 = r1.loadPacket(2 * Traits::ResPacketSize);
+ R0 = r1.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
+ R1 = r1.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
+ R2 = r1.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
traits.acc(C1, alphav, R0);
traits.acc(C5, alphav, R1);
traits.acc(C9, alphav, R2);
@@ -1045,9 +1572,9 @@
r1.storePacket(1 * Traits::ResPacketSize, R1);
r1.storePacket(2 * Traits::ResPacketSize, R2);
- R0 = r2.loadPacket(0 * Traits::ResPacketSize);
- R1 = r2.loadPacket(1 * Traits::ResPacketSize);
- R2 = r2.loadPacket(2 * Traits::ResPacketSize);
+ R0 = r2.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
+ R1 = r2.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
+ R2 = r2.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
traits.acc(C2, alphav, R0);
traits.acc(C6, alphav, R1);
traits.acc(C10, alphav, R2);
@@ -1055,9 +1582,9 @@
r2.storePacket(1 * Traits::ResPacketSize, R1);
r2.storePacket(2 * Traits::ResPacketSize, R2);
- R0 = r3.loadPacket(0 * Traits::ResPacketSize);
- R1 = r3.loadPacket(1 * Traits::ResPacketSize);
- R2 = r3.loadPacket(2 * Traits::ResPacketSize);
+ R0 = r3.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
+ R1 = r3.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
+ R2 = r3.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
traits.acc(C3, alphav, R0);
traits.acc(C7, alphav, R1);
traits.acc(C11, alphav, R2);
@@ -1093,20 +1620,20 @@
{
EIGEN_ASM_COMMENT("begin gebp micro kernel 3pX1");
RhsPacket B_0;
-#define EIGEN_GEBGP_ONESTEP(K) \
- do { \
- EIGEN_ASM_COMMENT("begin step of gebp micro kernel 3pX1"); \
+#define EIGEN_GEBGP_ONESTEP(K) \
+ do { \
+ EIGEN_ASM_COMMENT("begin step of gebp micro kernel 3pX1"); \
EIGEN_ASM_COMMENT("Note: these asm comments work around bug 935!"); \
- traits.loadLhs(&blA[(0+3*K)*LhsProgress], A0); \
- traits.loadLhs(&blA[(1+3*K)*LhsProgress], A1); \
- traits.loadLhs(&blA[(2+3*K)*LhsProgress], A2); \
- traits.loadRhs(&blB[(0+K)*RhsProgress], B_0); \
- traits.madd(A0, B_0, C0, B_0); \
- traits.madd(A1, B_0, C4, B_0); \
- traits.madd(A2, B_0, C8, B_0); \
- EIGEN_ASM_COMMENT("end step of gebp micro kernel 3pX1"); \
- } while(false)
-
+ traits.loadLhs(&blA[(0 + 3 * K) * LhsProgress], A0); \
+ traits.loadLhs(&blA[(1 + 3 * K) * LhsProgress], A1); \
+ traits.loadLhs(&blA[(2 + 3 * K) * LhsProgress], A2); \
+ traits.loadRhs(&blB[(0 + K) * RhsProgress], B_0); \
+ traits.madd(A0, B_0, C0, B_0, fix<0>); \
+ traits.madd(A1, B_0, C4, B_0, fix<0>); \
+ traits.madd(A2, B_0, C8, B_0, fix<0>); \
+ EIGEN_ASM_COMMENT("end step of gebp micro kernel 3pX1"); \
+ } while (false)
+
EIGEN_GEBGP_ONESTEP(0);
EIGEN_GEBGP_ONESTEP(1);
EIGEN_GEBGP_ONESTEP(2);
@@ -1116,8 +1643,8 @@
EIGEN_GEBGP_ONESTEP(6);
EIGEN_GEBGP_ONESTEP(7);
- blB += pk*RhsProgress;
- blA += pk*3*Traits::LhsProgress;
+ blB += int(pk) * int(RhsProgress);
+ blA += int(pk) * 3 * int(Traits::LhsProgress);
EIGEN_ASM_COMMENT("end gebp micro kernel 3pX1");
}
@@ -1134,9 +1661,9 @@
ResPacket R0, R1, R2;
ResPacket alphav = pset1<ResPacket>(alpha);
- R0 = r0.loadPacket(0 * Traits::ResPacketSize);
- R1 = r0.loadPacket(1 * Traits::ResPacketSize);
- R2 = r0.loadPacket(2 * Traits::ResPacketSize);
+ R0 = r0.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
+ R1 = r0.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
+ R2 = r0.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
traits.acc(C0, alphav, R0);
traits.acc(C4, alphav, R1);
traits.acc(C8, alphav, R2);
@@ -1195,7 +1722,8 @@
for(Index k=0; k<peeled_kc; k+=pk)
{
EIGEN_ASM_COMMENT("begin gebp micro kernel 2pX4");
- RhsPacket B_0, B1, B2, B3, T0;
+ RhsPacketx4 rhs_panel;
+ RhsPacket T0;
// NOTE: the begin/end asm comments below work around bug 935!
// but they are not enough for gcc>=6 without FMA (bug 1637)
@@ -1204,24 +1732,24 @@
#else
#define EIGEN_GEBP_2PX4_SPILLING_WORKAROUND
#endif
- #define EIGEN_GEBGP_ONESTEP(K) \
- do { \
- EIGEN_ASM_COMMENT("begin step of gebp micro kernel 2pX4"); \
- traits.loadLhs(&blA[(0+2*K)*LhsProgress], A0); \
- traits.loadLhs(&blA[(1+2*K)*LhsProgress], A1); \
- traits.broadcastRhs(&blB[(0+4*K)*RhsProgress], B_0, B1, B2, B3); \
- traits.madd(A0, B_0, C0, T0); \
- traits.madd(A1, B_0, C4, B_0); \
- traits.madd(A0, B1, C1, T0); \
- traits.madd(A1, B1, C5, B1); \
- traits.madd(A0, B2, C2, T0); \
- traits.madd(A1, B2, C6, B2); \
- traits.madd(A0, B3, C3, T0); \
- traits.madd(A1, B3, C7, B3); \
- EIGEN_GEBP_2PX4_SPILLING_WORKAROUND \
- EIGEN_ASM_COMMENT("end step of gebp micro kernel 2pX4"); \
- } while(false)
-
+#define EIGEN_GEBGP_ONESTEP(K) \
+ do { \
+ EIGEN_ASM_COMMENT("begin step of gebp micro kernel 2pX4"); \
+ traits.loadLhs(&blA[(0 + 2 * K) * LhsProgress], A0); \
+ traits.loadLhs(&blA[(1 + 2 * K) * LhsProgress], A1); \
+ traits.loadRhs(&blB[(0 + 4 * K) * RhsProgress], rhs_panel); \
+ traits.madd(A0, rhs_panel, C0, T0, fix<0>); \
+ traits.madd(A1, rhs_panel, C4, T0, fix<0>); \
+ traits.madd(A0, rhs_panel, C1, T0, fix<1>); \
+ traits.madd(A1, rhs_panel, C5, T0, fix<1>); \
+ traits.madd(A0, rhs_panel, C2, T0, fix<2>); \
+ traits.madd(A1, rhs_panel, C6, T0, fix<2>); \
+ traits.madd(A0, rhs_panel, C3, T0, fix<3>); \
+ traits.madd(A1, rhs_panel, C7, T0, fix<3>); \
+ EIGEN_GEBP_2PX4_SPILLING_WORKAROUND \
+ EIGEN_ASM_COMMENT("end step of gebp micro kernel 2pX4"); \
+ } while (false)
+
internal::prefetch(blB+(48+0));
EIGEN_GEBGP_ONESTEP(0);
EIGEN_GEBGP_ONESTEP(1);
@@ -1241,7 +1769,8 @@
// process remaining peeled loop
for(Index k=peeled_kc; k<depth; k++)
{
- RhsPacket B_0, B1, B2, B3, T0;
+ RhsPacketx4 rhs_panel;
+ RhsPacket T0;
EIGEN_GEBGP_ONESTEP(0);
blB += 4*RhsProgress;
blA += 2*Traits::LhsProgress;
@@ -1251,10 +1780,10 @@
ResPacket R0, R1, R2, R3;
ResPacket alphav = pset1<ResPacket>(alpha);
- R0 = r0.loadPacket(0 * Traits::ResPacketSize);
- R1 = r0.loadPacket(1 * Traits::ResPacketSize);
- R2 = r1.loadPacket(0 * Traits::ResPacketSize);
- R3 = r1.loadPacket(1 * Traits::ResPacketSize);
+ R0 = r0.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
+ R1 = r0.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
+ R2 = r1.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
+ R3 = r1.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
traits.acc(C0, alphav, R0);
traits.acc(C4, alphav, R1);
traits.acc(C1, alphav, R2);
@@ -1264,10 +1793,10 @@
r1.storePacket(0 * Traits::ResPacketSize, R2);
r1.storePacket(1 * Traits::ResPacketSize, R3);
- R0 = r2.loadPacket(0 * Traits::ResPacketSize);
- R1 = r2.loadPacket(1 * Traits::ResPacketSize);
- R2 = r3.loadPacket(0 * Traits::ResPacketSize);
- R3 = r3.loadPacket(1 * Traits::ResPacketSize);
+ R0 = r2.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
+ R1 = r2.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
+ R2 = r3.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
+ R3 = r3.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
traits.acc(C2, alphav, R0);
traits.acc(C6, alphav, R1);
traits.acc(C3, alphav, R2);
@@ -1312,8 +1841,8 @@
traits.loadLhs(&blA[(0+2*K)*LhsProgress], A0); \
traits.loadLhs(&blA[(1+2*K)*LhsProgress], A1); \
traits.loadRhs(&blB[(0+K)*RhsProgress], B_0); \
- traits.madd(A0, B_0, C0, B1); \
- traits.madd(A1, B_0, C4, B_0); \
+ traits.madd(A0, B_0, C0, B1, fix<0>); \
+ traits.madd(A1, B_0, C4, B_0, fix<0>); \
EIGEN_ASM_COMMENT("end step of gebp micro kernel 2pX1"); \
} while(false)
@@ -1326,8 +1855,8 @@
EIGEN_GEBGP_ONESTEP(6);
EIGEN_GEBGP_ONESTEP(7);
- blB += pk*RhsProgress;
- blA += pk*2*Traits::LhsProgress;
+ blB += int(pk) * int(RhsProgress);
+ blA += int(pk) * 2 * int(Traits::LhsProgress);
EIGEN_ASM_COMMENT("end gebp micro kernel 2pX1");
}
@@ -1344,8 +1873,8 @@
ResPacket R0, R1;
ResPacket alphav = pset1<ResPacket>(alpha);
- R0 = r0.loadPacket(0 * Traits::ResPacketSize);
- R1 = r0.loadPacket(1 * Traits::ResPacketSize);
+ R0 = r0.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
+ R1 = r0.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
traits.acc(C0, alphav, R0);
traits.acc(C4, alphav, R1);
r0.storePacket(0 * Traits::ResPacketSize, R0);
@@ -1357,186 +1886,43 @@
//---------- Process 1 * LhsProgress rows at once ----------
if(mr>=1*Traits::LhsProgress)
{
- // loops on each largest micro horizontal panel of lhs (1*LhsProgress x depth)
- for(Index i=peeled_mc2; i<peeled_mc1; i+=1*LhsProgress)
- {
- // loops on each largest micro vertical panel of rhs (depth * nr)
- for(Index j2=0; j2<packet_cols4; j2+=nr)
- {
- // We select a 1*Traits::LhsProgress x nr micro block of res which is entirely
- // stored into 1 x nr registers.
-
- const LhsScalar* blA = &blockA[i*strideA+offsetA*(1*Traits::LhsProgress)];
- prefetch(&blA[0]);
-
- // gets res block as register
- AccPacket C0, C1, C2, C3;
- traits.initAcc(C0);
- traits.initAcc(C1);
- traits.initAcc(C2);
- traits.initAcc(C3);
-
- LinearMapper r0 = res.getLinearMapper(i, j2 + 0);
- LinearMapper r1 = res.getLinearMapper(i, j2 + 1);
- LinearMapper r2 = res.getLinearMapper(i, j2 + 2);
- LinearMapper r3 = res.getLinearMapper(i, j2 + 3);
-
- r0.prefetch(prefetch_res_offset);
- r1.prefetch(prefetch_res_offset);
- r2.prefetch(prefetch_res_offset);
- r3.prefetch(prefetch_res_offset);
-
- // performs "inner" products
- const RhsScalar* blB = &blockB[j2*strideB+offsetB*nr];
- prefetch(&blB[0]);
- LhsPacket A0;
-
- for(Index k=0; k<peeled_kc; k+=pk)
- {
- EIGEN_ASM_COMMENT("begin gebp micro kernel 1pX4");
- RhsPacket B_0, B1, B2, B3;
-
-#define EIGEN_GEBGP_ONESTEP(K) \
- do { \
- EIGEN_ASM_COMMENT("begin step of gebp micro kernel 1pX4"); \
- EIGEN_ASM_COMMENT("Note: these asm comments work around bug 935!"); \
- traits.loadLhs(&blA[(0+1*K)*LhsProgress], A0); \
- traits.broadcastRhs(&blB[(0+4*K)*RhsProgress], B_0, B1, B2, B3); \
- traits.madd(A0, B_0, C0, B_0); \
- traits.madd(A0, B1, C1, B1); \
- traits.madd(A0, B2, C2, B2); \
- traits.madd(A0, B3, C3, B3); \
- EIGEN_ASM_COMMENT("end step of gebp micro kernel 1pX4"); \
- } while(false)
-
- internal::prefetch(blB+(48+0));
- EIGEN_GEBGP_ONESTEP(0);
- EIGEN_GEBGP_ONESTEP(1);
- EIGEN_GEBGP_ONESTEP(2);
- EIGEN_GEBGP_ONESTEP(3);
- internal::prefetch(blB+(48+16));
- EIGEN_GEBGP_ONESTEP(4);
- EIGEN_GEBGP_ONESTEP(5);
- EIGEN_GEBGP_ONESTEP(6);
- EIGEN_GEBGP_ONESTEP(7);
-
- blB += pk*4*RhsProgress;
- blA += pk*1*LhsProgress;
-
- EIGEN_ASM_COMMENT("end gebp micro kernel 1pX4");
- }
- // process remaining peeled loop
- for(Index k=peeled_kc; k<depth; k++)
- {
- RhsPacket B_0, B1, B2, B3;
- EIGEN_GEBGP_ONESTEP(0);
- blB += 4*RhsProgress;
- blA += 1*LhsProgress;
- }
-#undef EIGEN_GEBGP_ONESTEP
-
- ResPacket R0, R1;
- ResPacket alphav = pset1<ResPacket>(alpha);
-
- R0 = r0.loadPacket(0 * Traits::ResPacketSize);
- R1 = r1.loadPacket(0 * Traits::ResPacketSize);
- traits.acc(C0, alphav, R0);
- traits.acc(C1, alphav, R1);
- r0.storePacket(0 * Traits::ResPacketSize, R0);
- r1.storePacket(0 * Traits::ResPacketSize, R1);
-
- R0 = r2.loadPacket(0 * Traits::ResPacketSize);
- R1 = r3.loadPacket(0 * Traits::ResPacketSize);
- traits.acc(C2, alphav, R0);
- traits.acc(C3, alphav, R1);
- r2.storePacket(0 * Traits::ResPacketSize, R0);
- r3.storePacket(0 * Traits::ResPacketSize, R1);
- }
-
- // Deal with remaining columns of the rhs
- for(Index j2=packet_cols4; j2<cols; j2++)
- {
- // One column at a time
- const LhsScalar* blA = &blockA[i*strideA+offsetA*(1*Traits::LhsProgress)];
- prefetch(&blA[0]);
-
- // gets res block as register
- AccPacket C0;
- traits.initAcc(C0);
-
- LinearMapper r0 = res.getLinearMapper(i, j2);
-
- // performs "inner" products
- const RhsScalar* blB = &blockB[j2*strideB+offsetB];
- LhsPacket A0;
-
- for(Index k=0; k<peeled_kc; k+=pk)
- {
- EIGEN_ASM_COMMENT("begin gebp micro kernel 1pX1");
- RhsPacket B_0;
-
-#define EIGEN_GEBGP_ONESTEP(K) \
- do { \
- EIGEN_ASM_COMMENT("begin step of gebp micro kernel 1pX1"); \
- EIGEN_ASM_COMMENT("Note: these asm comments work around bug 935!"); \
- traits.loadLhs(&blA[(0+1*K)*LhsProgress], A0); \
- traits.loadRhs(&blB[(0+K)*RhsProgress], B_0); \
- traits.madd(A0, B_0, C0, B_0); \
- EIGEN_ASM_COMMENT("end step of gebp micro kernel 1pX1"); \
- } while(false);
-
- EIGEN_GEBGP_ONESTEP(0);
- EIGEN_GEBGP_ONESTEP(1);
- EIGEN_GEBGP_ONESTEP(2);
- EIGEN_GEBGP_ONESTEP(3);
- EIGEN_GEBGP_ONESTEP(4);
- EIGEN_GEBGP_ONESTEP(5);
- EIGEN_GEBGP_ONESTEP(6);
- EIGEN_GEBGP_ONESTEP(7);
-
- blB += pk*RhsProgress;
- blA += pk*1*Traits::LhsProgress;
-
- EIGEN_ASM_COMMENT("end gebp micro kernel 1pX1");
- }
-
- // process remaining peeled loop
- for(Index k=peeled_kc; k<depth; k++)
- {
- RhsPacket B_0;
- EIGEN_GEBGP_ONESTEP(0);
- blB += RhsProgress;
- blA += 1*Traits::LhsProgress;
- }
-#undef EIGEN_GEBGP_ONESTEP
- ResPacket R0;
- ResPacket alphav = pset1<ResPacket>(alpha);
- R0 = r0.loadPacket(0 * Traits::ResPacketSize);
- traits.acc(C0, alphav, R0);
- r0.storePacket(0 * Traits::ResPacketSize, R0);
- }
- }
+ lhs_process_one_packet<nr, LhsProgress, RhsProgress, LhsScalar, RhsScalar, ResScalar, AccPacket, LhsPacket, RhsPacket, ResPacket, Traits, LinearMapper, DataMapper> p;
+ p(res, blockA, blockB, alpha, peeled_mc2, peeled_mc1, strideA, strideB, offsetA, offsetB, prefetch_res_offset, peeled_kc, pk, cols, depth, packet_cols4);
+ }
+ //---------- Process LhsProgressHalf rows at once ----------
+ if((LhsProgressHalf < LhsProgress) && mr>=LhsProgressHalf)
+ {
+ lhs_process_fraction_of_packet<nr, LhsProgressHalf, RhsProgressHalf, LhsScalar, RhsScalar, ResScalar, AccPacketHalf, LhsPacketHalf, RhsPacketHalf, ResPacketHalf, HalfTraits, LinearMapper, DataMapper> p;
+ p(res, blockA, blockB, alpha, peeled_mc1, peeled_mc_half, strideA, strideB, offsetA, offsetB, prefetch_res_offset, peeled_kc, pk, cols, depth, packet_cols4);
+ }
+ //---------- Process LhsProgressQuarter rows at once ----------
+ if((LhsProgressQuarter < LhsProgressHalf) && mr>=LhsProgressQuarter)
+ {
+ lhs_process_fraction_of_packet<nr, LhsProgressQuarter, RhsProgressQuarter, LhsScalar, RhsScalar, ResScalar, AccPacketQuarter, LhsPacketQuarter, RhsPacketQuarter, ResPacketQuarter, QuarterTraits, LinearMapper, DataMapper> p;
+ p(res, blockA, blockB, alpha, peeled_mc_half, peeled_mc_quarter, strideA, strideB, offsetA, offsetB, prefetch_res_offset, peeled_kc, pk, cols, depth, packet_cols4);
}
//---------- Process remaining rows, 1 at once ----------
- if(peeled_mc1<rows)
+ if(peeled_mc_quarter<rows)
{
// loop on each panel of the rhs
for(Index j2=0; j2<packet_cols4; j2+=nr)
{
// loop on each row of the lhs (1*LhsProgress x depth)
- for(Index i=peeled_mc1; i<rows; i+=1)
+ for(Index i=peeled_mc_quarter; i<rows; i+=1)
{
const LhsScalar* blA = &blockA[i*strideA+offsetA];
prefetch(&blA[0]);
const RhsScalar* blB = &blockB[j2*strideB+offsetB*nr];
- // The following piece of code wont work for 512 bit registers
- // Moreover, if LhsProgress==8 it assumes that there is a half packet of the same size
- // as nr (which is currently 4) for the return type.
+ // If LhsProgress is 8 or 16, it assumes that there is a
+ // half or quarter packet, respectively, of the same size as
+ // nr (which is currently 4) for the return type.
const int SResPacketHalfSize = unpacket_traits<typename unpacket_traits<SResPacket>::half>::size;
+ const int SResPacketQuarterSize = unpacket_traits<typename unpacket_traits<typename unpacket_traits<SResPacket>::half>::half>::size;
if ((SwappedTraits::LhsProgress % 4) == 0 &&
- (SwappedTraits::LhsProgress <= 8) &&
- (SwappedTraits::LhsProgress!=8 || SResPacketHalfSize==nr))
+ (SwappedTraits::LhsProgress<=16) &&
+ (SwappedTraits::LhsProgress!=8 || SResPacketHalfSize==nr) &&
+ (SwappedTraits::LhsProgress!=16 || SResPacketQuarterSize==nr))
{
SAccPacket C0, C1, C2, C3;
straits.initAcc(C0);
@@ -1559,15 +1945,15 @@
straits.loadRhsQuad(blA+0*spk, B_0);
straits.loadRhsQuad(blA+1*spk, B_1);
- straits.madd(A0,B_0,C0,B_0);
- straits.madd(A1,B_1,C1,B_1);
+ straits.madd(A0,B_0,C0,B_0, fix<0>);
+ straits.madd(A1,B_1,C1,B_1, fix<0>);
straits.loadLhsUnaligned(blB+2*SwappedTraits::LhsProgress, A0);
straits.loadLhsUnaligned(blB+3*SwappedTraits::LhsProgress, A1);
straits.loadRhsQuad(blA+2*spk, B_0);
straits.loadRhsQuad(blA+3*spk, B_1);
- straits.madd(A0,B_0,C2,B_0);
- straits.madd(A1,B_1,C3,B_1);
+ straits.madd(A0,B_0,C2,B_0, fix<0>);
+ straits.madd(A1,B_1,C3,B_1, fix<0>);
blB += 4*SwappedTraits::LhsProgress;
blA += 4*spk;
@@ -1580,7 +1966,7 @@
straits.loadLhsUnaligned(blB, A0);
straits.loadRhsQuad(blA, B_0);
- straits.madd(A0,B_0,C0,B_0);
+ straits.madd(A0,B_0,C0,B_0, fix<0>);
blB += SwappedTraits::LhsProgress;
blA += spk;
@@ -1590,7 +1976,7 @@
// Special case where we have to first reduce the accumulation register C0
typedef typename conditional<SwappedTraits::LhsProgress>=8,typename unpacket_traits<SResPacket>::half,SResPacket>::type SResPacketHalf;
typedef typename conditional<SwappedTraits::LhsProgress>=8,typename unpacket_traits<SLhsPacket>::half,SLhsPacket>::type SLhsPacketHalf;
- typedef typename conditional<SwappedTraits::LhsProgress>=8,typename unpacket_traits<SLhsPacket>::half,SRhsPacket>::type SRhsPacketHalf;
+ typedef typename conditional<SwappedTraits::LhsProgress>=8,typename unpacket_traits<SRhsPacket>::half,SRhsPacket>::type SRhsPacketHalf;
typedef typename conditional<SwappedTraits::LhsProgress>=8,typename unpacket_traits<SAccPacket>::half,SAccPacket>::type SAccPacketHalf;
SResPacketHalf R = res.template gatherPacket<SResPacketHalf>(i, j2);
@@ -1603,16 +1989,25 @@
SRhsPacketHalf b0;
straits.loadLhsUnaligned(blB, a0);
straits.loadRhs(blA, b0);
- SAccPacketHalf c0 = predux_downto4(C0);
- straits.madd(a0,b0,c0,b0);
+ SAccPacketHalf c0 = predux_half_dowto4(C0);
+ straits.madd(a0,b0,c0,b0, fix<0>);
straits.acc(c0, alphav, R);
}
else
{
- straits.acc(predux_downto4(C0), alphav, R);
+ straits.acc(predux_half_dowto4(C0), alphav, R);
}
res.scatterPacket(i, j2, R);
}
+ else if (SwappedTraits::LhsProgress==16)
+ {
+ // Special case where we have to first reduce the
+ // accumulation register C0. We specialize the block in
+ // template form, so that LhsProgress < 16 paths don't
+ // fail to compile
+ last_row_process_16_packets<LhsScalar, RhsScalar, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> p;
+ p(res, straits, blA, blB, depth, endk, i, j2,alpha, C0);
+ }
else
{
SResPacket R = res.template gatherPacket<SResPacket>(i, j2);
@@ -1635,14 +2030,14 @@
B_0 = blB[0];
B_1 = blB[1];
- CJMADD(cj,A0,B_0,C0, B_0);
- CJMADD(cj,A0,B_1,C1, B_1);
-
+ C0 = cj.pmadd(A0,B_0,C0);
+ C1 = cj.pmadd(A0,B_1,C1);
+
B_0 = blB[2];
B_1 = blB[3];
- CJMADD(cj,A0,B_0,C2, B_0);
- CJMADD(cj,A0,B_1,C3, B_1);
-
+ C2 = cj.pmadd(A0,B_0,C2);
+ C3 = cj.pmadd(A0,B_1,C3);
+
blB += 4;
}
res(i, j2 + 0) += alpha * C0;
@@ -1656,7 +2051,7 @@
for(Index j2=packet_cols4; j2<cols; j2++)
{
// loop on each row of the lhs (1*LhsProgress x depth)
- for(Index i=peeled_mc1; i<rows; i+=1)
+ for(Index i=peeled_mc_quarter; i<rows; i+=1)
{
const LhsScalar* blA = &blockA[i*strideA+offsetA];
prefetch(&blA[0]);
@@ -1667,7 +2062,7 @@
{
LhsScalar A0 = blA[k];
RhsScalar B_0 = blB[k];
- CJMADD(cj, A0, B_0, C0, B_0);
+ C0 = cj.pmadd(A0, B_0, C0);
}
res(i, j2) += alpha * C0;
}
@@ -1676,8 +2071,6 @@
}
-#undef CJMADD
-
// pack a block of the lhs
// The traversal is as follow (mr==4):
// 0 4 8 12 ...
@@ -1692,19 +2085,24 @@
//
// 32 33 34 35 ...
// 36 36 38 39 ...
-template<typename Scalar, typename Index, typename DataMapper, int Pack1, int Pack2, bool Conjugate, bool PanelMode>
-struct gemm_pack_lhs<Scalar, Index, DataMapper, Pack1, Pack2, ColMajor, Conjugate, PanelMode>
+template<typename Scalar, typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
+struct gemm_pack_lhs<Scalar, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
{
typedef typename DataMapper::LinearMapper LinearMapper;
EIGEN_DONT_INLINE void operator()(Scalar* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
};
-template<typename Scalar, typename Index, typename DataMapper, int Pack1, int Pack2, bool Conjugate, bool PanelMode>
-EIGEN_DONT_INLINE void gemm_pack_lhs<Scalar, Index, DataMapper, Pack1, Pack2, ColMajor, Conjugate, PanelMode>
+template<typename Scalar, typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
+EIGEN_DONT_INLINE void gemm_pack_lhs<Scalar, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
::operator()(Scalar* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
{
- typedef typename packet_traits<Scalar>::type Packet;
- enum { PacketSize = packet_traits<Scalar>::size };
+ typedef typename unpacket_traits<Packet>::half HalfPacket;
+ typedef typename unpacket_traits<typename unpacket_traits<Packet>::half>::half QuarterPacket;
+ enum { PacketSize = unpacket_traits<Packet>::size,
+ HalfPacketSize = unpacket_traits<HalfPacket>::size,
+ QuarterPacketSize = unpacket_traits<QuarterPacket>::size,
+ HasHalf = (int)HalfPacketSize < (int)PacketSize,
+ HasQuarter = (int)QuarterPacketSize < (int)HalfPacketSize};
EIGEN_ASM_COMMENT("EIGEN PRODUCT PACK LHS");
EIGEN_UNUSED_VARIABLE(stride);
@@ -1716,9 +2114,12 @@
const Index peeled_mc3 = Pack1>=3*PacketSize ? (rows/(3*PacketSize))*(3*PacketSize) : 0;
const Index peeled_mc2 = Pack1>=2*PacketSize ? peeled_mc3+((rows-peeled_mc3)/(2*PacketSize))*(2*PacketSize) : 0;
- const Index peeled_mc1 = Pack1>=1*PacketSize ? (rows/(1*PacketSize))*(1*PacketSize) : 0;
- const Index peeled_mc0 = Pack2>=1*PacketSize ? peeled_mc1
- : Pack2>1 ? (rows/Pack2)*Pack2 : 0;
+ const Index peeled_mc1 = Pack1>=1*PacketSize ? peeled_mc2+((rows-peeled_mc2)/(1*PacketSize))*(1*PacketSize) : 0;
+ const Index peeled_mc_half = Pack1>=HalfPacketSize ? peeled_mc1+((rows-peeled_mc1)/(HalfPacketSize))*(HalfPacketSize) : 0;
+ const Index peeled_mc_quarter = Pack1>=QuarterPacketSize ? (rows/(QuarterPacketSize))*(QuarterPacketSize) : 0;
+ const Index last_lhs_progress = rows > peeled_mc_quarter ? (rows - peeled_mc_quarter) & ~1 : 0;
+ const Index peeled_mc0 = Pack2>=PacketSize ? peeled_mc_quarter
+ : Pack2>1 && last_lhs_progress ? (rows/last_lhs_progress)*last_lhs_progress : 0;
Index i=0;
@@ -1732,9 +2133,9 @@
for(Index k=0; k<depth; k++)
{
Packet A, B, C;
- A = lhs.loadPacket(i+0*PacketSize, k);
- B = lhs.loadPacket(i+1*PacketSize, k);
- C = lhs.loadPacket(i+2*PacketSize, k);
+ A = lhs.template loadPacket<Packet>(i+0*PacketSize, k);
+ B = lhs.template loadPacket<Packet>(i+1*PacketSize, k);
+ C = lhs.template loadPacket<Packet>(i+2*PacketSize, k);
pstore(blockA+count, cj.pconj(A)); count+=PacketSize;
pstore(blockA+count, cj.pconj(B)); count+=PacketSize;
pstore(blockA+count, cj.pconj(C)); count+=PacketSize;
@@ -1752,8 +2153,8 @@
for(Index k=0; k<depth; k++)
{
Packet A, B;
- A = lhs.loadPacket(i+0*PacketSize, k);
- B = lhs.loadPacket(i+1*PacketSize, k);
+ A = lhs.template loadPacket<Packet>(i+0*PacketSize, k);
+ B = lhs.template loadPacket<Packet>(i+1*PacketSize, k);
pstore(blockA+count, cj.pconj(A)); count+=PacketSize;
pstore(blockA+count, cj.pconj(B)); count+=PacketSize;
}
@@ -1770,27 +2171,67 @@
for(Index k=0; k<depth; k++)
{
Packet A;
- A = lhs.loadPacket(i+0*PacketSize, k);
+ A = lhs.template loadPacket<Packet>(i+0*PacketSize, k);
pstore(blockA+count, cj.pconj(A));
count+=PacketSize;
}
if(PanelMode) count += (1*PacketSize) * (stride-offset-depth);
}
}
- // Pack scalars
- if(Pack2<PacketSize && Pack2>1)
+ // Pack half packets
+ if(HasHalf && Pack1>=HalfPacketSize)
{
- for(; i<peeled_mc0; i+=Pack2)
+ for(; i<peeled_mc_half; i+=HalfPacketSize)
{
- if(PanelMode) count += Pack2 * offset;
+ if(PanelMode) count += (HalfPacketSize) * offset;
for(Index k=0; k<depth; k++)
- for(Index w=0; w<Pack2; w++)
- blockA[count++] = cj(lhs(i+w, k));
-
- if(PanelMode) count += Pack2 * (stride-offset-depth);
+ {
+ HalfPacket A;
+ A = lhs.template loadPacket<HalfPacket>(i+0*(HalfPacketSize), k);
+ pstoreu(blockA+count, cj.pconj(A));
+ count+=HalfPacketSize;
+ }
+ if(PanelMode) count += (HalfPacketSize) * (stride-offset-depth);
}
}
+ // Pack quarter packets
+ if(HasQuarter && Pack1>=QuarterPacketSize)
+ {
+ for(; i<peeled_mc_quarter; i+=QuarterPacketSize)
+ {
+ if(PanelMode) count += (QuarterPacketSize) * offset;
+
+ for(Index k=0; k<depth; k++)
+ {
+ QuarterPacket A;
+ A = lhs.template loadPacket<QuarterPacket>(i+0*(QuarterPacketSize), k);
+ pstoreu(blockA+count, cj.pconj(A));
+ count+=QuarterPacketSize;
+ }
+ if(PanelMode) count += (QuarterPacketSize) * (stride-offset-depth);
+ }
+ }
+ // Pack2 may be *smaller* than PacketSize—that happens for
+ // products like real * complex, where we have to go half the
+ // progress on the lhs in order to duplicate those operands to
+ // address both real & imaginary parts on the rhs. This portion will
+ // pack those half ones until they match the number expected on the
+ // last peeling loop at this point (for the rhs).
+ if(Pack2<PacketSize && Pack2>1)
+ {
+ for(; i<peeled_mc0; i+=last_lhs_progress)
+ {
+ if(PanelMode) count += last_lhs_progress * offset;
+
+ for(Index k=0; k<depth; k++)
+ for(Index w=0; w<last_lhs_progress; w++)
+ blockA[count++] = cj(lhs(i+w, k));
+
+ if(PanelMode) count += last_lhs_progress * (stride-offset-depth);
+ }
+ }
+ // Pack scalars
for(; i<rows; i++)
{
if(PanelMode) count += offset;
@@ -1800,19 +2241,24 @@
}
}
-template<typename Scalar, typename Index, typename DataMapper, int Pack1, int Pack2, bool Conjugate, bool PanelMode>
-struct gemm_pack_lhs<Scalar, Index, DataMapper, Pack1, Pack2, RowMajor, Conjugate, PanelMode>
+template<typename Scalar, typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
+struct gemm_pack_lhs<Scalar, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
{
typedef typename DataMapper::LinearMapper LinearMapper;
EIGEN_DONT_INLINE void operator()(Scalar* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
};
-template<typename Scalar, typename Index, typename DataMapper, int Pack1, int Pack2, bool Conjugate, bool PanelMode>
-EIGEN_DONT_INLINE void gemm_pack_lhs<Scalar, Index, DataMapper, Pack1, Pack2, RowMajor, Conjugate, PanelMode>
+template<typename Scalar, typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
+EIGEN_DONT_INLINE void gemm_pack_lhs<Scalar, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
::operator()(Scalar* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
{
- typedef typename packet_traits<Scalar>::type Packet;
- enum { PacketSize = packet_traits<Scalar>::size };
+ typedef typename unpacket_traits<Packet>::half HalfPacket;
+ typedef typename unpacket_traits<typename unpacket_traits<Packet>::half>::half QuarterPacket;
+ enum { PacketSize = unpacket_traits<Packet>::size,
+ HalfPacketSize = unpacket_traits<HalfPacket>::size,
+ QuarterPacketSize = unpacket_traits<QuarterPacket>::size,
+ HasHalf = (int)HalfPacketSize < (int)PacketSize,
+ HasQuarter = (int)QuarterPacketSize < (int)HalfPacketSize};
EIGEN_ASM_COMMENT("EIGEN PRODUCT PACK LHS");
EIGEN_UNUSED_VARIABLE(stride);
@@ -1820,37 +2266,51 @@
eigen_assert(((!PanelMode) && stride==0 && offset==0) || (PanelMode && stride>=depth && offset<=stride));
conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj;
Index count = 0;
+ bool gone_half = false, gone_quarter = false, gone_last = false;
-// const Index peeled_mc3 = Pack1>=3*PacketSize ? (rows/(3*PacketSize))*(3*PacketSize) : 0;
-// const Index peeled_mc2 = Pack1>=2*PacketSize ? peeled_mc3+((rows-peeled_mc3)/(2*PacketSize))*(2*PacketSize) : 0;
-// const Index peeled_mc1 = Pack1>=1*PacketSize ? (rows/(1*PacketSize))*(1*PacketSize) : 0;
-
- int pack = Pack1;
Index i = 0;
+ int pack = Pack1;
+ int psize = PacketSize;
while(pack>0)
{
Index remaining_rows = rows-i;
- Index peeled_mc = i+(remaining_rows/pack)*pack;
+ Index peeled_mc = gone_last ? Pack2>1 ? (rows/pack)*pack : 0 : i+(remaining_rows/pack)*pack;
+ Index starting_pos = i;
for(; i<peeled_mc; i+=pack)
{
if(PanelMode) count += pack * offset;
- const Index peeled_k = (depth/PacketSize)*PacketSize;
Index k=0;
- if(pack>=PacketSize)
+ if(pack>=psize && psize >= QuarterPacketSize)
{
- for(; k<peeled_k; k+=PacketSize)
+ const Index peeled_k = (depth/psize)*psize;
+ for(; k<peeled_k; k+=psize)
{
- for (Index m = 0; m < pack; m += PacketSize)
+ for (Index m = 0; m < pack; m += psize)
{
- PacketBlock<Packet> kernel;
- for (int p = 0; p < PacketSize; ++p) kernel.packet[p] = lhs.loadPacket(i+p+m, k);
- ptranspose(kernel);
- for (int p = 0; p < PacketSize; ++p) pstore(blockA+count+m+(pack)*p, cj.pconj(kernel.packet[p]));
+ if (psize == PacketSize) {
+ PacketBlock<Packet> kernel;
+ for (int p = 0; p < psize; ++p) kernel.packet[p] = lhs.template loadPacket<Packet>(i+p+m, k);
+ ptranspose(kernel);
+ for (int p = 0; p < psize; ++p) pstore(blockA+count+m+(pack)*p, cj.pconj(kernel.packet[p]));
+ } else if (HasHalf && psize == HalfPacketSize) {
+ gone_half = true;
+ PacketBlock<HalfPacket> kernel_half;
+ for (int p = 0; p < psize; ++p) kernel_half.packet[p] = lhs.template loadPacket<HalfPacket>(i+p+m, k);
+ ptranspose(kernel_half);
+ for (int p = 0; p < psize; ++p) pstore(blockA+count+m+(pack)*p, cj.pconj(kernel_half.packet[p]));
+ } else if (HasQuarter && psize == QuarterPacketSize) {
+ gone_quarter = true;
+ PacketBlock<QuarterPacket> kernel_quarter;
+ for (int p = 0; p < psize; ++p) kernel_quarter.packet[p] = lhs.template loadPacket<QuarterPacket>(i+p+m, k);
+ ptranspose(kernel_quarter);
+ for (int p = 0; p < psize; ++p) pstore(blockA+count+m+(pack)*p, cj.pconj(kernel_quarter.packet[p]));
+ }
}
- count += PacketSize*pack;
+ count += psize*pack;
}
}
+
for(; k<depth; k++)
{
Index w=0;
@@ -1873,9 +2333,28 @@
if(PanelMode) count += pack * (stride-offset-depth);
}
- pack -= PacketSize;
- if(pack<Pack2 && (pack+PacketSize)!=Pack2)
- pack = Pack2;
+ pack -= psize;
+ Index left = rows - i;
+ if (pack <= 0) {
+ if (!gone_last &&
+ (starting_pos == i || left >= psize/2 || left >= psize/4) &&
+ ((psize/2 == HalfPacketSize && HasHalf && !gone_half) ||
+ (psize/2 == QuarterPacketSize && HasQuarter && !gone_quarter))) {
+ psize /= 2;
+ pack = psize;
+ continue;
+ }
+ // Pack2 may be *smaller* than PacketSize—that happens for
+ // products like real * complex, where we have to go half the
+ // progress on the lhs in order to duplicate those operands to
+ // address both real & imaginary parts on the rhs. This portion will
+ // pack those half ones until they match the number expected on the
+ // last peeling loop at this point (for the rhs).
+ if (Pack2 < PacketSize && !gone_last) {
+ gone_last = true;
+ psize = pack = left & ~1;
+ }
+ }
}
for(; i<rows; i++)
@@ -1931,7 +2410,7 @@
// const Scalar* b6 = &rhs[(j2+6)*rhsStride];
// const Scalar* b7 = &rhs[(j2+7)*rhsStride];
// Index k=0;
-// if(PacketSize==8) // TODO enbale vectorized transposition for PacketSize==4
+// if(PacketSize==8) // TODO enable vectorized transposition for PacketSize==4
// {
// for(; k<peeled_k; k+=PacketSize) {
// PacketBlock<Packet> kernel;
@@ -1978,10 +2457,10 @@
{
for(; k<peeled_k; k+=PacketSize) {
PacketBlock<Packet,(PacketSize%4)==0?4:PacketSize> kernel;
- kernel.packet[0] = dm0.loadPacket(k);
- kernel.packet[1%PacketSize] = dm1.loadPacket(k);
- kernel.packet[2%PacketSize] = dm2.loadPacket(k);
- kernel.packet[3%PacketSize] = dm3.loadPacket(k);
+ kernel.packet[0 ] = dm0.template loadPacket<Packet>(k);
+ kernel.packet[1%PacketSize] = dm1.template loadPacket<Packet>(k);
+ kernel.packet[2%PacketSize] = dm2.template loadPacket<Packet>(k);
+ kernel.packet[3%PacketSize] = dm3.template loadPacket<Packet>(k);
ptranspose(kernel);
pstoreu(blockB+count+0*PacketSize, cj.pconj(kernel.packet[0]));
pstoreu(blockB+count+1*PacketSize, cj.pconj(kernel.packet[1%PacketSize]));
@@ -2022,94 +2501,104 @@
struct gemm_pack_rhs<Scalar, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
{
typedef typename packet_traits<Scalar>::type Packet;
+ typedef typename unpacket_traits<Packet>::half HalfPacket;
+ typedef typename unpacket_traits<typename unpacket_traits<Packet>::half>::half QuarterPacket;
typedef typename DataMapper::LinearMapper LinearMapper;
- enum { PacketSize = packet_traits<Scalar>::size };
- EIGEN_DONT_INLINE void operator()(Scalar* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
-};
-
-template<typename Scalar, typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
-EIGEN_DONT_INLINE void gemm_pack_rhs<Scalar, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
- ::operator()(Scalar* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
-{
- EIGEN_ASM_COMMENT("EIGEN PRODUCT PACK RHS ROWMAJOR");
- EIGEN_UNUSED_VARIABLE(stride);
- EIGEN_UNUSED_VARIABLE(offset);
- eigen_assert(((!PanelMode) && stride==0 && offset==0) || (PanelMode && stride>=depth && offset<=stride));
- conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj;
- Index packet_cols8 = nr>=8 ? (cols/8) * 8 : 0;
- Index packet_cols4 = nr>=4 ? (cols/4) * 4 : 0;
- Index count = 0;
-
-// if(nr>=8)
-// {
-// for(Index j2=0; j2<packet_cols8; j2+=8)
-// {
-// // skip what we have before
-// if(PanelMode) count += 8 * offset;
-// for(Index k=0; k<depth; k++)
-// {
-// if (PacketSize==8) {
-// Packet A = ploadu<Packet>(&rhs[k*rhsStride + j2]);
-// pstoreu(blockB+count, cj.pconj(A));
-// } else if (PacketSize==4) {
-// Packet A = ploadu<Packet>(&rhs[k*rhsStride + j2]);
-// Packet B = ploadu<Packet>(&rhs[k*rhsStride + j2 + PacketSize]);
-// pstoreu(blockB+count, cj.pconj(A));
-// pstoreu(blockB+count+PacketSize, cj.pconj(B));
-// } else {
-// const Scalar* b0 = &rhs[k*rhsStride + j2];
-// blockB[count+0] = cj(b0[0]);
-// blockB[count+1] = cj(b0[1]);
-// blockB[count+2] = cj(b0[2]);
-// blockB[count+3] = cj(b0[3]);
-// blockB[count+4] = cj(b0[4]);
-// blockB[count+5] = cj(b0[5]);
-// blockB[count+6] = cj(b0[6]);
-// blockB[count+7] = cj(b0[7]);
-// }
-// count += 8;
-// }
-// // skip what we have after
-// if(PanelMode) count += 8 * (stride-offset-depth);
-// }
-// }
- if(nr>=4)
+ enum { PacketSize = packet_traits<Scalar>::size,
+ HalfPacketSize = unpacket_traits<HalfPacket>::size,
+ QuarterPacketSize = unpacket_traits<QuarterPacket>::size};
+ EIGEN_DONT_INLINE void operator()(Scalar* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0)
{
- for(Index j2=packet_cols8; j2<packet_cols4; j2+=4)
+ EIGEN_ASM_COMMENT("EIGEN PRODUCT PACK RHS ROWMAJOR");
+ EIGEN_UNUSED_VARIABLE(stride);
+ EIGEN_UNUSED_VARIABLE(offset);
+ eigen_assert(((!PanelMode) && stride==0 && offset==0) || (PanelMode && stride>=depth && offset<=stride));
+ const bool HasHalf = (int)HalfPacketSize < (int)PacketSize;
+ const bool HasQuarter = (int)QuarterPacketSize < (int)HalfPacketSize;
+ conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj;
+ Index packet_cols8 = nr>=8 ? (cols/8) * 8 : 0;
+ Index packet_cols4 = nr>=4 ? (cols/4) * 4 : 0;
+ Index count = 0;
+
+ // if(nr>=8)
+ // {
+ // for(Index j2=0; j2<packet_cols8; j2+=8)
+ // {
+ // // skip what we have before
+ // if(PanelMode) count += 8 * offset;
+ // for(Index k=0; k<depth; k++)
+ // {
+ // if (PacketSize==8) {
+ // Packet A = ploadu<Packet>(&rhs[k*rhsStride + j2]);
+ // pstoreu(blockB+count, cj.pconj(A));
+ // } else if (PacketSize==4) {
+ // Packet A = ploadu<Packet>(&rhs[k*rhsStride + j2]);
+ // Packet B = ploadu<Packet>(&rhs[k*rhsStride + j2 + PacketSize]);
+ // pstoreu(blockB+count, cj.pconj(A));
+ // pstoreu(blockB+count+PacketSize, cj.pconj(B));
+ // } else {
+ // const Scalar* b0 = &rhs[k*rhsStride + j2];
+ // blockB[count+0] = cj(b0[0]);
+ // blockB[count+1] = cj(b0[1]);
+ // blockB[count+2] = cj(b0[2]);
+ // blockB[count+3] = cj(b0[3]);
+ // blockB[count+4] = cj(b0[4]);
+ // blockB[count+5] = cj(b0[5]);
+ // blockB[count+6] = cj(b0[6]);
+ // blockB[count+7] = cj(b0[7]);
+ // }
+ // count += 8;
+ // }
+ // // skip what we have after
+ // if(PanelMode) count += 8 * (stride-offset-depth);
+ // }
+ // }
+ if(nr>=4)
{
- // skip what we have before
- if(PanelMode) count += 4 * offset;
+ for(Index j2=packet_cols8; j2<packet_cols4; j2+=4)
+ {
+ // skip what we have before
+ if(PanelMode) count += 4 * offset;
+ for(Index k=0; k<depth; k++)
+ {
+ if (PacketSize==4) {
+ Packet A = rhs.template loadPacket<Packet>(k, j2);
+ pstoreu(blockB+count, cj.pconj(A));
+ count += PacketSize;
+ } else if (HasHalf && HalfPacketSize==4) {
+ HalfPacket A = rhs.template loadPacket<HalfPacket>(k, j2);
+ pstoreu(blockB+count, cj.pconj(A));
+ count += HalfPacketSize;
+ } else if (HasQuarter && QuarterPacketSize==4) {
+ QuarterPacket A = rhs.template loadPacket<QuarterPacket>(k, j2);
+ pstoreu(blockB+count, cj.pconj(A));
+ count += QuarterPacketSize;
+ } else {
+ const LinearMapper dm0 = rhs.getLinearMapper(k, j2);
+ blockB[count+0] = cj(dm0(0));
+ blockB[count+1] = cj(dm0(1));
+ blockB[count+2] = cj(dm0(2));
+ blockB[count+3] = cj(dm0(3));
+ count += 4;
+ }
+ }
+ // skip what we have after
+ if(PanelMode) count += 4 * (stride-offset-depth);
+ }
+ }
+ // copy the remaining columns one at a time (nr==1)
+ for(Index j2=packet_cols4; j2<cols; ++j2)
+ {
+ if(PanelMode) count += offset;
for(Index k=0; k<depth; k++)
{
- if (PacketSize==4) {
- Packet A = rhs.loadPacket(k, j2);
- pstoreu(blockB+count, cj.pconj(A));
- count += PacketSize;
- } else {
- const LinearMapper dm0 = rhs.getLinearMapper(k, j2);
- blockB[count+0] = cj(dm0(0));
- blockB[count+1] = cj(dm0(1));
- blockB[count+2] = cj(dm0(2));
- blockB[count+3] = cj(dm0(3));
- count += 4;
- }
+ blockB[count] = cj(rhs(k, j2));
+ count += 1;
}
- // skip what we have after
- if(PanelMode) count += 4 * (stride-offset-depth);
+ if(PanelMode) count += stride-offset-depth;
}
}
- // copy the remaining columns one at a time (nr==1)
- for(Index j2=packet_cols4; j2<cols; ++j2)
- {
- if(PanelMode) count += offset;
- for(Index k=0; k<depth; k++)
- {
- blockB[count] = cj(rhs(k, j2));
- count += 1;
- }
- if(PanelMode) count += stride-offset-depth;
- }
-}
+};
} // end namespace internal
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/products/GeneralMatrixMatrix.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/products/GeneralMatrixMatrix.h
index 6440e1d..caa65fc 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/products/GeneralMatrixMatrix.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/products/GeneralMatrixMatrix.h
@@ -20,8 +20,9 @@
template<
typename Index,
typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
- typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs>
-struct general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,RowMajor>
+ typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs,
+ int ResInnerStride>
+struct general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,RowMajor,ResInnerStride>
{
typedef gebp_traits<RhsScalar,LhsScalar> Traits;
@@ -30,7 +31,7 @@
Index rows, Index cols, Index depth,
const LhsScalar* lhs, Index lhsStride,
const RhsScalar* rhs, Index rhsStride,
- ResScalar* res, Index resStride,
+ ResScalar* res, Index resIncr, Index resStride,
ResScalar alpha,
level3_blocking<RhsScalar,LhsScalar>& blocking,
GemmParallelInfo<Index>* info = 0)
@@ -39,8 +40,8 @@
general_matrix_matrix_product<Index,
RhsScalar, RhsStorageOrder==RowMajor ? ColMajor : RowMajor, ConjugateRhs,
LhsScalar, LhsStorageOrder==RowMajor ? ColMajor : RowMajor, ConjugateLhs,
- ColMajor>
- ::run(cols,rows,depth,rhs,rhsStride,lhs,lhsStride,res,resStride,alpha,blocking,info);
+ ColMajor,ResInnerStride>
+ ::run(cols,rows,depth,rhs,rhsStride,lhs,lhsStride,res,resIncr,resStride,alpha,blocking,info);
}
};
@@ -49,8 +50,9 @@
template<
typename Index,
typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
- typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs>
-struct general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,ColMajor>
+ typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs,
+ int ResInnerStride>
+struct general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,ColMajor,ResInnerStride>
{
typedef gebp_traits<LhsScalar,RhsScalar> Traits;
@@ -59,23 +61,23 @@
static void run(Index rows, Index cols, Index depth,
const LhsScalar* _lhs, Index lhsStride,
const RhsScalar* _rhs, Index rhsStride,
- ResScalar* _res, Index resStride,
+ ResScalar* _res, Index resIncr, Index resStride,
ResScalar alpha,
level3_blocking<LhsScalar,RhsScalar>& blocking,
GemmParallelInfo<Index>* info = 0)
{
typedef const_blas_data_mapper<LhsScalar, Index, LhsStorageOrder> LhsMapper;
typedef const_blas_data_mapper<RhsScalar, Index, RhsStorageOrder> RhsMapper;
- typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor> ResMapper;
- LhsMapper lhs(_lhs,lhsStride);
- RhsMapper rhs(_rhs,rhsStride);
- ResMapper res(_res, resStride);
+ typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor,Unaligned,ResInnerStride> ResMapper;
+ LhsMapper lhs(_lhs, lhsStride);
+ RhsMapper rhs(_rhs, rhsStride);
+ ResMapper res(_res, resStride, resIncr);
Index kc = blocking.kc(); // cache block size along the K direction
Index mc = (std::min)(rows,blocking.mc()); // cache block size along the M direction
Index nc = (std::min)(cols,blocking.nc()); // cache block size along the N direction
- gemm_pack_lhs<LhsScalar, Index, LhsMapper, Traits::mr, Traits::LhsProgress, LhsStorageOrder> pack_lhs;
+ gemm_pack_lhs<LhsScalar, Index, LhsMapper, Traits::mr, Traits::LhsProgress, typename Traits::LhsPacket4Packing, LhsStorageOrder> pack_lhs;
gemm_pack_rhs<RhsScalar, Index, RhsMapper, Traits::nr, RhsStorageOrder> pack_rhs;
gebp_kernel<LhsScalar, RhsScalar, Index, ResMapper, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs> gebp;
@@ -108,7 +110,7 @@
// i.e., we test that info[tid].users equals 0.
// Then, we set info[tid].users to the number of threads to mark that all other threads are going to use it.
while(info[tid].users!=0) {}
- info[tid].users += threads;
+ info[tid].users = threads;
pack_lhs(blockA+info[tid].lhs_start*actual_kc, lhs.getSubMapper(info[tid].lhs_start,k), actual_kc, info[tid].lhs_length);
@@ -146,7 +148,9 @@
// Release all the sub blocks A'_i of A' for the current thread,
// i.e., we simply decrement the number of users by 1
for(Index i=0; i<threads; ++i)
+#if !EIGEN_HAS_CXX11_ATOMIC
#pragma omp atomic
+#endif
info[i].users -= 1;
}
}
@@ -226,7 +230,7 @@
Gemm::run(rows, cols, m_lhs.cols(),
&m_lhs.coeffRef(row,0), m_lhs.outerStride(),
&m_rhs.coeffRef(0,col), m_rhs.outerStride(),
- (Scalar*)&(m_dest.coeffRef(row,col)), m_dest.outerStride(),
+ (Scalar*)&(m_dest.coeffRef(row,col)), m_dest.innerStride(), m_dest.outerStride(),
m_actualAlpha, m_blocking, info);
}
@@ -427,8 +431,14 @@
template<typename Dst>
static void evalTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
{
- if((rhs.rows()+dst.rows()+dst.cols())<20 && rhs.rows()>0)
- lazyproduct::evalTo(dst, lhs, rhs);
+ // See http://eigen.tuxfamily.org/bz/show_bug.cgi?id=404 for a discussion and helper program
+ // to determine the following heuristic.
+ // EIGEN_GEMM_TO_COEFFBASED_THRESHOLD is typically defined to 20 in GeneralProduct.h,
+ // unless it has been specialized by the user or for a given architecture.
+ // Note that the condition rhs.rows()>0 was required because lazy product is (was?) not happy with empty inputs.
+ // I'm not sure it is still required.
+ if((rhs.rows()+dst.rows()+dst.cols())<EIGEN_GEMM_TO_COEFFBASED_THRESHOLD && rhs.rows()>0)
+ lazyproduct::eval_dynamic(dst, lhs, rhs, internal::assign_op<typename Dst::Scalar,Scalar>());
else
{
dst.setZero();
@@ -439,8 +449,8 @@
template<typename Dst>
static void addTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
{
- if((rhs.rows()+dst.rows()+dst.cols())<20 && rhs.rows()>0)
- lazyproduct::addTo(dst, lhs, rhs);
+ if((rhs.rows()+dst.rows()+dst.cols())<EIGEN_GEMM_TO_COEFFBASED_THRESHOLD && rhs.rows()>0)
+ lazyproduct::eval_dynamic(dst, lhs, rhs, internal::add_assign_op<typename Dst::Scalar,Scalar>());
else
scaleAndAddTo(dst,lhs, rhs, Scalar(1));
}
@@ -448,8 +458,8 @@
template<typename Dst>
static void subTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
{
- if((rhs.rows()+dst.rows()+dst.cols())<20 && rhs.rows()>0)
- lazyproduct::subTo(dst, lhs, rhs);
+ if((rhs.rows()+dst.rows()+dst.cols())<EIGEN_GEMM_TO_COEFFBASED_THRESHOLD && rhs.rows()>0)
+ lazyproduct::eval_dynamic(dst, lhs, rhs, internal::sub_assign_op<typename Dst::Scalar,Scalar>());
else
scaleAndAddTo(dst, lhs, rhs, Scalar(-1));
}
@@ -461,11 +471,25 @@
if(a_lhs.cols()==0 || a_lhs.rows()==0 || a_rhs.cols()==0)
return;
+ if (dst.cols() == 1)
+ {
+ // Fallback to GEMV if either the lhs or rhs is a runtime vector
+ typename Dest::ColXpr dst_vec(dst.col(0));
+ return internal::generic_product_impl<Lhs,typename Rhs::ConstColXpr,DenseShape,DenseShape,GemvProduct>
+ ::scaleAndAddTo(dst_vec, a_lhs, a_rhs.col(0), alpha);
+ }
+ else if (dst.rows() == 1)
+ {
+ // Fallback to GEMV if either the lhs or rhs is a runtime vector
+ typename Dest::RowXpr dst_vec(dst.row(0));
+ return internal::generic_product_impl<typename Lhs::ConstRowXpr,Rhs,DenseShape,DenseShape,GemvProduct>
+ ::scaleAndAddTo(dst_vec, a_lhs.row(0), a_rhs, alpha);
+ }
+
typename internal::add_const_on_value_type<ActualLhsType>::type lhs = LhsBlasTraits::extract(a_lhs);
typename internal::add_const_on_value_type<ActualRhsType>::type rhs = RhsBlasTraits::extract(a_rhs);
- Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(a_lhs)
- * RhsBlasTraits::extractScalarFactor(a_rhs);
+ Scalar actualAlpha = combine_scalar_factors(alpha, a_lhs, a_rhs);
typedef internal::gemm_blocking_space<(Dest::Flags&RowMajorBit) ? RowMajor : ColMajor,LhsScalar,RhsScalar,
Dest::MaxRowsAtCompileTime,Dest::MaxColsAtCompileTime,MaxDepthAtCompileTime> BlockingType;
@@ -476,7 +500,8 @@
Index,
LhsScalar, (ActualLhsTypeCleaned::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(LhsBlasTraits::NeedToConjugate),
RhsScalar, (ActualRhsTypeCleaned::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(RhsBlasTraits::NeedToConjugate),
- (Dest::Flags&RowMajorBit) ? RowMajor : ColMajor>,
+ (Dest::Flags&RowMajorBit) ? RowMajor : ColMajor,
+ Dest::InnerStrideAtCompileTime>,
ActualLhsTypeCleaned, ActualRhsTypeCleaned, Dest, BlockingType> GemmFunctor;
BlockingType blocking(dst.rows(), dst.cols(), lhs.cols(), 1, true);
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h
index e844e37..6ba0d9b 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h
@@ -25,51 +25,54 @@
**********************************************************************/
// forward declarations (defined at the end of this file)
-template<typename LhsScalar, typename RhsScalar, typename Index, int mr, int nr, bool ConjLhs, bool ConjRhs, int UpLo>
+template<typename LhsScalar, typename RhsScalar, typename Index, int mr, int nr, bool ConjLhs, bool ConjRhs, int ResInnerStride, int UpLo>
struct tribb_kernel;
/* Optimized matrix-matrix product evaluating only one triangular half */
template <typename Index,
typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs,
- int ResStorageOrder, int UpLo, int Version = Specialized>
+ int ResStorageOrder, int ResInnerStride, int UpLo, int Version = Specialized>
struct general_matrix_matrix_triangular_product;
// as usual if the result is row major => we transpose the product
template <typename Index, typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
- typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs, int UpLo, int Version>
-struct general_matrix_matrix_triangular_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,RowMajor,UpLo,Version>
+ typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs,
+ int ResInnerStride, int UpLo, int Version>
+struct general_matrix_matrix_triangular_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,RowMajor,ResInnerStride,UpLo,Version>
{
typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
static EIGEN_STRONG_INLINE void run(Index size, Index depth,const LhsScalar* lhs, Index lhsStride,
- const RhsScalar* rhs, Index rhsStride, ResScalar* res, Index resStride,
+ const RhsScalar* rhs, Index rhsStride, ResScalar* res, Index resIncr, Index resStride,
const ResScalar& alpha, level3_blocking<RhsScalar,LhsScalar>& blocking)
{
general_matrix_matrix_triangular_product<Index,
RhsScalar, RhsStorageOrder==RowMajor ? ColMajor : RowMajor, ConjugateRhs,
LhsScalar, LhsStorageOrder==RowMajor ? ColMajor : RowMajor, ConjugateLhs,
- ColMajor, UpLo==Lower?Upper:Lower>
- ::run(size,depth,rhs,rhsStride,lhs,lhsStride,res,resStride,alpha,blocking);
+ ColMajor, ResInnerStride, UpLo==Lower?Upper:Lower>
+ ::run(size,depth,rhs,rhsStride,lhs,lhsStride,res,resIncr,resStride,alpha,blocking);
}
};
template <typename Index, typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
- typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs, int UpLo, int Version>
-struct general_matrix_matrix_triangular_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,ColMajor,UpLo,Version>
+ typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs,
+ int ResInnerStride, int UpLo, int Version>
+struct general_matrix_matrix_triangular_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,ColMajor,ResInnerStride,UpLo,Version>
{
typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
static EIGEN_STRONG_INLINE void run(Index size, Index depth,const LhsScalar* _lhs, Index lhsStride,
- const RhsScalar* _rhs, Index rhsStride, ResScalar* _res, Index resStride,
+ const RhsScalar* _rhs, Index rhsStride,
+ ResScalar* _res, Index resIncr, Index resStride,
const ResScalar& alpha, level3_blocking<LhsScalar,RhsScalar>& blocking)
{
typedef gebp_traits<LhsScalar,RhsScalar> Traits;
typedef const_blas_data_mapper<LhsScalar, Index, LhsStorageOrder> LhsMapper;
typedef const_blas_data_mapper<RhsScalar, Index, RhsStorageOrder> RhsMapper;
- typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor> ResMapper;
+ typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper;
LhsMapper lhs(_lhs,lhsStride);
RhsMapper rhs(_rhs,rhsStride);
- ResMapper res(_res, resStride);
+ ResMapper res(_res, resStride, resIncr);
Index kc = blocking.kc();
Index mc = (std::min)(size,blocking.mc());
@@ -84,10 +87,10 @@
ei_declare_aligned_stack_constructed_variable(LhsScalar, blockA, sizeA, blocking.blockA());
ei_declare_aligned_stack_constructed_variable(RhsScalar, blockB, sizeB, blocking.blockB());
- gemm_pack_lhs<LhsScalar, Index, LhsMapper, Traits::mr, Traits::LhsProgress, LhsStorageOrder> pack_lhs;
+ gemm_pack_lhs<LhsScalar, Index, LhsMapper, Traits::mr, Traits::LhsProgress, typename Traits::LhsPacket4Packing, LhsStorageOrder> pack_lhs;
gemm_pack_rhs<RhsScalar, Index, RhsMapper, Traits::nr, RhsStorageOrder> pack_rhs;
gebp_kernel<LhsScalar, RhsScalar, Index, ResMapper, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs> gebp;
- tribb_kernel<LhsScalar, RhsScalar, Index, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs, UpLo> sybb;
+ tribb_kernel<LhsScalar, RhsScalar, Index, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs, ResInnerStride, UpLo> sybb;
for(Index k2=0; k2<depth; k2+=kc)
{
@@ -110,8 +113,7 @@
gebp(res.getSubMapper(i2, 0), blockA, blockB, actual_mc, actual_kc,
(std::min)(size,i2), alpha, -1, -1, 0, 0);
-
- sybb(_res+resStride*i2 + i2, resStride, blockA, blockB + actual_kc*i2, actual_mc, actual_kc, alpha);
+ sybb(_res+resStride*i2 + resIncr*i2, resIncr, resStride, blockA, blockB + actual_kc*i2, actual_mc, actual_kc, alpha);
if (UpLo==Upper)
{
@@ -133,7 +135,7 @@
// while the triangular block overlapping the diagonal is evaluated into a
// small temporary buffer which is then accumulated into the result using a
// triangular traversal.
-template<typename LhsScalar, typename RhsScalar, typename Index, int mr, int nr, bool ConjLhs, bool ConjRhs, int UpLo>
+template<typename LhsScalar, typename RhsScalar, typename Index, int mr, int nr, bool ConjLhs, bool ConjRhs, int ResInnerStride, int UpLo>
struct tribb_kernel
{
typedef gebp_traits<LhsScalar,RhsScalar,ConjLhs,ConjRhs> Traits;
@@ -142,11 +144,13 @@
enum {
BlockSize = meta_least_common_multiple<EIGEN_PLAIN_ENUM_MAX(mr,nr),EIGEN_PLAIN_ENUM_MIN(mr,nr)>::ret
};
- void operator()(ResScalar* _res, Index resStride, const LhsScalar* blockA, const RhsScalar* blockB, Index size, Index depth, const ResScalar& alpha)
+ void operator()(ResScalar* _res, Index resIncr, Index resStride, const LhsScalar* blockA, const RhsScalar* blockB, Index size, Index depth, const ResScalar& alpha)
{
- typedef blas_data_mapper<ResScalar, Index, ColMajor> ResMapper;
- ResMapper res(_res, resStride);
- gebp_kernel<LhsScalar, RhsScalar, Index, ResMapper, mr, nr, ConjLhs, ConjRhs> gebp_kernel;
+ typedef blas_data_mapper<ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper;
+ typedef blas_data_mapper<ResScalar, Index, ColMajor, Unaligned> BufferMapper;
+ ResMapper res(_res, resStride, resIncr);
+ gebp_kernel<LhsScalar, RhsScalar, Index, ResMapper, mr, nr, ConjLhs, ConjRhs> gebp_kernel1;
+ gebp_kernel<LhsScalar, RhsScalar, Index, BufferMapper, mr, nr, ConjLhs, ConjRhs> gebp_kernel2;
Matrix<ResScalar,BlockSize,BlockSize,ColMajor> buffer((internal::constructor_without_unaligned_array_assert()));
@@ -158,31 +162,32 @@
const RhsScalar* actual_b = blockB+j*depth;
if(UpLo==Upper)
- gebp_kernel(res.getSubMapper(0, j), blockA, actual_b, j, depth, actualBlockSize, alpha,
- -1, -1, 0, 0);
-
+ gebp_kernel1(res.getSubMapper(0, j), blockA, actual_b, j, depth, actualBlockSize, alpha,
+ -1, -1, 0, 0);
+
// selfadjoint micro block
{
Index i = j;
buffer.setZero();
// 1 - apply the kernel on the temporary buffer
- gebp_kernel(ResMapper(buffer.data(), BlockSize), blockA+depth*i, actual_b, actualBlockSize, depth, actualBlockSize, alpha,
- -1, -1, 0, 0);
+ gebp_kernel2(BufferMapper(buffer.data(), BlockSize), blockA+depth*i, actual_b, actualBlockSize, depth, actualBlockSize, alpha,
+ -1, -1, 0, 0);
+
// 2 - triangular accumulation
for(Index j1=0; j1<actualBlockSize; ++j1)
{
- ResScalar* r = &res(i, j + j1);
+ typename ResMapper::LinearMapper r = res.getLinearMapper(i,j+j1);
for(Index i1=UpLo==Lower ? j1 : 0;
UpLo==Lower ? i1<actualBlockSize : i1<=j1; ++i1)
- r[i1] += buffer(i1,j1);
+ r(i1) += buffer(i1,j1);
}
}
if(UpLo==Lower)
{
Index i = j+actualBlockSize;
- gebp_kernel(res.getSubMapper(i, j), blockA+depth*i, actual_b, size-i,
- depth, actualBlockSize, alpha, -1, -1, 0, 0);
+ gebp_kernel1(res.getSubMapper(i, j), blockA+depth*i, actual_b, size-i,
+ depth, actualBlockSize, alpha, -1, -1, 0, 0);
}
}
}
@@ -286,23 +291,24 @@
internal::general_matrix_matrix_triangular_product<Index,
typename Lhs::Scalar, LhsIsRowMajor ? RowMajor : ColMajor, LhsBlasTraits::NeedToConjugate,
typename Rhs::Scalar, RhsIsRowMajor ? RowMajor : ColMajor, RhsBlasTraits::NeedToConjugate,
- IsRowMajor ? RowMajor : ColMajor, UpLo&(Lower|Upper)>
+ IsRowMajor ? RowMajor : ColMajor, MatrixType::InnerStrideAtCompileTime, UpLo&(Lower|Upper)>
::run(size, depth,
&actualLhs.coeffRef(SkipDiag&&(UpLo&Lower)==Lower ? 1 : 0,0), actualLhs.outerStride(),
&actualRhs.coeffRef(0,SkipDiag&&(UpLo&Upper)==Upper ? 1 : 0), actualRhs.outerStride(),
- mat.data() + (SkipDiag ? (bool(IsRowMajor) != ((UpLo&Lower)==Lower) ? 1 : mat.outerStride() ) : 0), mat.outerStride(), actualAlpha, blocking);
+ mat.data() + (SkipDiag ? (bool(IsRowMajor) != ((UpLo&Lower)==Lower) ? mat.innerStride() : mat.outerStride() ) : 0),
+ mat.innerStride(), mat.outerStride(), actualAlpha, blocking);
}
};
template<typename MatrixType, unsigned int UpLo>
template<typename ProductType>
-TriangularView<MatrixType,UpLo>& TriangularViewImpl<MatrixType,UpLo,Dense>::_assignProduct(const ProductType& prod, const Scalar& alpha, bool beta)
+EIGEN_DEVICE_FUNC TriangularView<MatrixType,UpLo>& TriangularViewImpl<MatrixType,UpLo,Dense>::_assignProduct(const ProductType& prod, const Scalar& alpha, bool beta)
{
EIGEN_STATIC_ASSERT((UpLo&UnitDiag)==0, WRITING_TO_TRIANGULAR_PART_WITH_UNIT_DIAGONAL_IS_NOT_SUPPORTED);
eigen_assert(derived().nestedExpression().rows() == prod.rows() && derived().cols() == prod.cols());
-
+
general_product_to_triangular_selector<MatrixType, ProductType, UpLo, internal::traits<ProductType>::InnerSize==1>::run(derived().nestedExpression().const_cast_derived(), prod, alpha, beta);
-
+
return derived();
}
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/products/GeneralMatrixVector.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/products/GeneralMatrixVector.h
index a597c1f..dfb6aeb 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/products/GeneralMatrixVector.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/products/GeneralMatrixVector.h
@@ -1,7 +1,7 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
-// Copyright (C) 2008-2009 Gael Guennebaud <gael.guennebaud@inria.fr>
+// Copyright (C) 2008-2016 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
@@ -14,11 +14,57 @@
namespace internal {
+enum GEMVPacketSizeType { // packet width selector used as gemv_traits' _PacketSize
+ GEMVPacketFull = 0, // selects packet_traits<Scalar>::type
+ GEMVPacketHalf, // selects packet_traits<Scalar>::half
+ GEMVPacketQuarter // not specialized below: falls to the primary template (half of the half packet)
+};
+
+template <int N, typename T1, typename T2, typename T3>
+struct gemv_packet_cond { typedef T3 type; }; // primary template: any N other than Full/Half picks T3
+
+template <typename T1, typename T2, typename T3>
+struct gemv_packet_cond<GEMVPacketFull, T1, T2, T3> { typedef T1 type; }; // Full -> T1
+
+template <typename T1, typename T2, typename T3>
+struct gemv_packet_cond<GEMVPacketHalf, T1, T2, T3> { typedef T2 type; }; // Half -> T2
+
+template<typename LhsScalar, typename RhsScalar, int _PacketSize=GEMVPacketFull>
+class gemv_traits // Lhs/Rhs/Res packet types and sizes at the width chosen by _PacketSize
+{
+ typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar; // needed by the Res expansion below (name ## Scalar)
+
+#define PACKET_DECL_COND_PREFIX(prefix, name, packet_size) \
+ typedef typename gemv_packet_cond<packet_size, \
+ typename packet_traits<name ## Scalar>::type, \
+ typename packet_traits<name ## Scalar>::half, \
+ typename unpacket_traits<typename packet_traits<name ## Scalar>::half>::half>::type \
+ prefix ## name ## Packet
+
+ PACKET_DECL_COND_PREFIX(_, Lhs, _PacketSize); // declares _LhsPacket
+ PACKET_DECL_COND_PREFIX(_, Rhs, _PacketSize); // declares _RhsPacket
+ PACKET_DECL_COND_PREFIX(_, Res, _PacketSize); // declares _ResPacket
+#undef PACKET_DECL_COND_PREFIX
+
+public:
+ enum { // Vectorizable only if both Lhs and Rhs packets are vectorizable and of equal size
+ Vectorizable = unpacket_traits<_LhsPacket>::vectorizable &&
+ unpacket_traits<_RhsPacket>::vectorizable &&
+ int(unpacket_traits<_LhsPacket>::size)==int(unpacket_traits<_RhsPacket>::size),
+ LhsPacketSize = Vectorizable ? unpacket_traits<_LhsPacket>::size : 1, // sizes fall back to 1 when not vectorizable
+ RhsPacketSize = Vectorizable ? unpacket_traits<_RhsPacket>::size : 1,
+ ResPacketSize = Vectorizable ? unpacket_traits<_ResPacket>::size : 1
+ };
+
+ typedef typename conditional<Vectorizable,_LhsPacket,LhsScalar>::type LhsPacket; // scalar fallback when !Vectorizable
+ typedef typename conditional<Vectorizable,_RhsPacket,RhsScalar>::type RhsPacket;
+ typedef typename conditional<Vectorizable,_ResPacket,ResScalar>::type ResPacket;
+};
+
+
/* Optimized col-major matrix * vector product:
- * This algorithm processes 4 columns at onces that allows to both reduce
- * the number of load/stores of the result by a factor 4 and to reduce
- * the instruction dependency. Moreover, we know that all bands have the
- * same alignment pattern.
+ * This algorithm processes the matrix per vertical panels,
+ * which are then processed horizontally per chunk of 8*PacketSize x 1 vertical segments.
*
* Mixing type logic: C += alpha * A * B
* | A | B |alpha| comments
@@ -27,56 +73,30 @@
* |cplx |real |cplx | invalid, the caller has to do tmp: = A * B; C += alpha*tmp
* |cplx |real |real | optimal case, vectorization possible via real-cplx mul
*
- * Accesses to the matrix coefficients follow the following logic:
- *
- * - if all columns have the same alignment then
- * - if the columns have the same alignment as the result vector, then easy! (-> AllAligned case)
- * - otherwise perform unaligned loads only (-> NoneAligned case)
- * - otherwise
- * - if even columns have the same alignment then
- * // odd columns are guaranteed to have the same alignment too
- * - if even or odd columns have the same alignment as the result, then
- * // for a register size of 2 scalars, this is guarantee to be the case (e.g., SSE with double)
- * - perform half aligned and half unaligned loads (-> EvenAligned case)
- * - otherwise perform unaligned loads only (-> NoneAligned case)
- * - otherwise, if the register size is 4 scalars (e.g., SSE with float) then
- * - one over 4 consecutive columns is guaranteed to be aligned with the result vector,
- * perform simple aligned loads for this column and aligned loads plus re-alignment for the other. (-> FirstAligned case)
- * // this re-alignment is done by the palign function implemented for SSE in Eigen/src/Core/arch/SSE/PacketMath.h
- * - otherwise,
- * // if we get here, this means the register size is greater than 4 (e.g., AVX with floats),
- * // we currently fall back to the NoneAligned case
- *
* The same reasoning apply for the transposed case.
- *
- * The last case (PacketSize>4) could probably be improved by generalizing the FirstAligned case, but since we do not support AVX yet...
- * One might also wonder why in the EvenAligned case we perform unaligned loads instead of using the aligned-loads plus re-alignment
- * strategy as in the FirstAligned case. The reason is that we observed that unaligned loads on a 8 byte boundary are not too slow
- * compared to unaligned loads on a 4 byte boundary.
- *
*/
template<typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar, typename RhsMapper, bool ConjugateRhs, int Version>
struct general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,ConjugateLhs,RhsScalar,RhsMapper,ConjugateRhs,Version>
{
+ typedef gemv_traits<LhsScalar,RhsScalar> Traits;
+ typedef gemv_traits<LhsScalar,RhsScalar,GEMVPacketHalf> HalfTraits;
+ typedef gemv_traits<LhsScalar,RhsScalar,GEMVPacketQuarter> QuarterTraits;
+
typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
-enum {
- Vectorizable = packet_traits<LhsScalar>::Vectorizable && packet_traits<RhsScalar>::Vectorizable
- && int(packet_traits<LhsScalar>::size)==int(packet_traits<RhsScalar>::size),
- LhsPacketSize = Vectorizable ? packet_traits<LhsScalar>::size : 1,
- RhsPacketSize = Vectorizable ? packet_traits<RhsScalar>::size : 1,
- ResPacketSize = Vectorizable ? packet_traits<ResScalar>::size : 1
-};
+ typedef typename Traits::LhsPacket LhsPacket;
+ typedef typename Traits::RhsPacket RhsPacket;
+ typedef typename Traits::ResPacket ResPacket;
-typedef typename packet_traits<LhsScalar>::type _LhsPacket;
-typedef typename packet_traits<RhsScalar>::type _RhsPacket;
-typedef typename packet_traits<ResScalar>::type _ResPacket;
+ typedef typename HalfTraits::LhsPacket LhsPacketHalf;
+ typedef typename HalfTraits::RhsPacket RhsPacketHalf;
+ typedef typename HalfTraits::ResPacket ResPacketHalf;
-typedef typename conditional<Vectorizable,_LhsPacket,LhsScalar>::type LhsPacket;
-typedef typename conditional<Vectorizable,_RhsPacket,RhsScalar>::type RhsPacket;
-typedef typename conditional<Vectorizable,_ResPacket,ResScalar>::type ResPacket;
+ typedef typename QuarterTraits::LhsPacket LhsPacketQuarter;
+ typedef typename QuarterTraits::RhsPacket RhsPacketQuarter;
+ typedef typename QuarterTraits::ResPacket ResPacketQuarter;
-EIGEN_DONT_INLINE static void run(
+EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run(
Index rows, Index cols,
const LhsMapper& lhs,
const RhsMapper& rhs,
@@ -85,244 +105,187 @@
};
template<typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar, typename RhsMapper, bool ConjugateRhs, int Version>
-EIGEN_DONT_INLINE void general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,ConjugateLhs,RhsScalar,RhsMapper,ConjugateRhs,Version>::run(
+EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,ConjugateLhs,RhsScalar,RhsMapper,ConjugateRhs,Version>::run(
Index rows, Index cols,
- const LhsMapper& lhs,
+ const LhsMapper& alhs,
const RhsMapper& rhs,
ResScalar* res, Index resIncr,
RhsScalar alpha)
{
EIGEN_UNUSED_VARIABLE(resIncr);
eigen_internal_assert(resIncr==1);
- #ifdef _EIGEN_ACCUMULATE_PACKETS
- #error _EIGEN_ACCUMULATE_PACKETS has already been defined
- #endif
- #define _EIGEN_ACCUMULATE_PACKETS(Alignment0,Alignment13,Alignment2) \
- pstore(&res[j], \
- padd(pload<ResPacket>(&res[j]), \
- padd( \
- padd(pcj.pmul(lhs0.template load<LhsPacket, Alignment0>(j), ptmp0), \
- pcj.pmul(lhs1.template load<LhsPacket, Alignment13>(j), ptmp1)), \
- padd(pcj.pmul(lhs2.template load<LhsPacket, Alignment2>(j), ptmp2), \
- pcj.pmul(lhs3.template load<LhsPacket, Alignment13>(j), ptmp3)) )))
- typedef typename LhsMapper::VectorMapper LhsScalars;
+ // The following copy tells the compiler that lhs's attributes are not modified outside this function
+ // This helps GCC to generate proper code.
+ LhsMapper lhs(alhs);
conj_helper<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs> cj;
conj_helper<LhsPacket,RhsPacket,ConjugateLhs,ConjugateRhs> pcj;
- if(ConjugateRhs)
- alpha = numext::conj(alpha);
-
- enum { AllAligned = 0, EvenAligned, FirstAligned, NoneAligned };
- const Index columnsAtOnce = 4;
- const Index peels = 2;
- const Index LhsPacketAlignedMask = LhsPacketSize-1;
- const Index ResPacketAlignedMask = ResPacketSize-1;
-// const Index PeelAlignedMask = ResPacketSize*peels-1;
- const Index size = rows;
+ conj_helper<LhsPacketHalf,RhsPacketHalf,ConjugateLhs,ConjugateRhs> pcj_half;
+ conj_helper<LhsPacketQuarter,RhsPacketQuarter,ConjugateLhs,ConjugateRhs> pcj_quarter;
const Index lhsStride = lhs.stride();
+ // TODO: for padded aligned inputs, we could enable aligned reads
+ enum { LhsAlignment = Unaligned,
+ ResPacketSize = Traits::ResPacketSize,
+ ResPacketSizeHalf = HalfTraits::ResPacketSize,
+ ResPacketSizeQuarter = QuarterTraits::ResPacketSize,
+ LhsPacketSize = Traits::LhsPacketSize,
+ HasHalf = (int)ResPacketSizeHalf < (int)ResPacketSize,
+ HasQuarter = (int)ResPacketSizeQuarter < (int)ResPacketSizeHalf
+ };
- // How many coeffs of the result do we have to skip to be aligned.
- // Here we assume data are at least aligned on the base scalar type.
- Index alignedStart = internal::first_default_aligned(res,size);
- Index alignedSize = ResPacketSize>1 ? alignedStart + ((size-alignedStart) & ~ResPacketAlignedMask) : 0;
- const Index peeledSize = alignedSize - RhsPacketSize*peels - RhsPacketSize + 1;
+ const Index n8 = rows-8*ResPacketSize+1;
+ const Index n4 = rows-4*ResPacketSize+1;
+ const Index n3 = rows-3*ResPacketSize+1;
+ const Index n2 = rows-2*ResPacketSize+1;
+ const Index n1 = rows-1*ResPacketSize+1;
+ const Index n_half = rows-1*ResPacketSizeHalf+1;
+ const Index n_quarter = rows-1*ResPacketSizeQuarter+1;
- const Index alignmentStep = LhsPacketSize>1 ? (LhsPacketSize - lhsStride % LhsPacketSize) & LhsPacketAlignedMask : 0;
- Index alignmentPattern = alignmentStep==0 ? AllAligned
- : alignmentStep==(LhsPacketSize/2) ? EvenAligned
- : FirstAligned;
+ // TODO: improve the following heuristic:
+ const Index block_cols = cols<128 ? cols : (lhsStride*sizeof(LhsScalar)<32000?16:4);
+ ResPacket palpha = pset1<ResPacket>(alpha);
+ ResPacketHalf palpha_half = pset1<ResPacketHalf>(alpha);
+ ResPacketQuarter palpha_quarter = pset1<ResPacketQuarter>(alpha);
- // we cannot assume the first element is aligned because of sub-matrices
- const Index lhsAlignmentOffset = lhs.firstAligned(size);
-
- // find how many columns do we have to skip to be aligned with the result (if possible)
- Index skipColumns = 0;
- // if the data cannot be aligned (TODO add some compile time tests when possible, e.g. for floats)
- if( (lhsAlignmentOffset < 0) || (lhsAlignmentOffset == size) || (UIntPtr(res)%sizeof(ResScalar)) )
+ for(Index j2=0; j2<cols; j2+=block_cols)
{
- alignedSize = 0;
- alignedStart = 0;
- alignmentPattern = NoneAligned;
- }
- else if(LhsPacketSize > 4)
- {
- // TODO: extend the code to support aligned loads whenever possible when LhsPacketSize > 4.
- // Currently, it seems to be better to perform unaligned loads anyway
- alignmentPattern = NoneAligned;
- }
- else if (LhsPacketSize>1)
- {
- // eigen_internal_assert(size_t(firstLhs+lhsAlignmentOffset)%sizeof(LhsPacket)==0 || size<LhsPacketSize);
-
- while (skipColumns<LhsPacketSize &&
- alignedStart != ((lhsAlignmentOffset + alignmentStep*skipColumns)%LhsPacketSize))
- ++skipColumns;
- if (skipColumns==LhsPacketSize)
+ Index jend = numext::mini(j2+block_cols,cols);
+ Index i=0;
+ for(; i<n8; i+=ResPacketSize*8)
{
- // nothing can be aligned, no need to skip any column
- alignmentPattern = NoneAligned;
- skipColumns = 0;
- }
- else
- {
- skipColumns = (std::min)(skipColumns,cols);
- // note that the skiped columns are processed later.
- }
+ ResPacket c0 = pset1<ResPacket>(ResScalar(0)),
+ c1 = pset1<ResPacket>(ResScalar(0)),
+ c2 = pset1<ResPacket>(ResScalar(0)),
+ c3 = pset1<ResPacket>(ResScalar(0)),
+ c4 = pset1<ResPacket>(ResScalar(0)),
+ c5 = pset1<ResPacket>(ResScalar(0)),
+ c6 = pset1<ResPacket>(ResScalar(0)),
+ c7 = pset1<ResPacket>(ResScalar(0));
- /* eigen_internal_assert( (alignmentPattern==NoneAligned)
- || (skipColumns + columnsAtOnce >= cols)
- || LhsPacketSize > size
- || (size_t(firstLhs+alignedStart+lhsStride*skipColumns)%sizeof(LhsPacket))==0);*/
- }
- else if(Vectorizable)
- {
- alignedStart = 0;
- alignedSize = size;
- alignmentPattern = AllAligned;
- }
-
- const Index offset1 = (alignmentPattern==FirstAligned && alignmentStep==1)?3:1;
- const Index offset3 = (alignmentPattern==FirstAligned && alignmentStep==1)?1:3;
-
- Index columnBound = ((cols-skipColumns)/columnsAtOnce)*columnsAtOnce + skipColumns;
- for (Index i=skipColumns; i<columnBound; i+=columnsAtOnce)
- {
- RhsPacket ptmp0 = pset1<RhsPacket>(alpha*rhs(i, 0)),
- ptmp1 = pset1<RhsPacket>(alpha*rhs(i+offset1, 0)),
- ptmp2 = pset1<RhsPacket>(alpha*rhs(i+2, 0)),
- ptmp3 = pset1<RhsPacket>(alpha*rhs(i+offset3, 0));
-
- // this helps a lot generating better binary code
- const LhsScalars lhs0 = lhs.getVectorMapper(0, i+0), lhs1 = lhs.getVectorMapper(0, i+offset1),
- lhs2 = lhs.getVectorMapper(0, i+2), lhs3 = lhs.getVectorMapper(0, i+offset3);
-
- if (Vectorizable)
- {
- /* explicit vectorization */
- // process initial unaligned coeffs
- for (Index j=0; j<alignedStart; ++j)
+ for(Index j=j2; j<jend; j+=1)
{
- res[j] = cj.pmadd(lhs0(j), pfirst(ptmp0), res[j]);
- res[j] = cj.pmadd(lhs1(j), pfirst(ptmp1), res[j]);
- res[j] = cj.pmadd(lhs2(j), pfirst(ptmp2), res[j]);
- res[j] = cj.pmadd(lhs3(j), pfirst(ptmp3), res[j]);
+ RhsPacket b0 = pset1<RhsPacket>(rhs(j,0));
+ c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*0,j),b0,c0);
+ c1 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*1,j),b0,c1);
+ c2 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*2,j),b0,c2);
+ c3 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*3,j),b0,c3);
+ c4 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*4,j),b0,c4);
+ c5 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*5,j),b0,c5);
+ c6 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*6,j),b0,c6);
+ c7 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*7,j),b0,c7);
}
-
- if (alignedSize>alignedStart)
- {
- switch(alignmentPattern)
- {
- case AllAligned:
- for (Index j = alignedStart; j<alignedSize; j+=ResPacketSize)
- _EIGEN_ACCUMULATE_PACKETS(Aligned,Aligned,Aligned);
- break;
- case EvenAligned:
- for (Index j = alignedStart; j<alignedSize; j+=ResPacketSize)
- _EIGEN_ACCUMULATE_PACKETS(Aligned,Unaligned,Aligned);
- break;
- case FirstAligned:
- {
- Index j = alignedStart;
- if(peels>1)
- {
- LhsPacket A00, A01, A02, A03, A10, A11, A12, A13;
- ResPacket T0, T1;
-
- A01 = lhs1.template load<LhsPacket, Aligned>(alignedStart-1);
- A02 = lhs2.template load<LhsPacket, Aligned>(alignedStart-2);
- A03 = lhs3.template load<LhsPacket, Aligned>(alignedStart-3);
-
- for (; j<peeledSize; j+=peels*ResPacketSize)
- {
- A11 = lhs1.template load<LhsPacket, Aligned>(j-1+LhsPacketSize); palign<1>(A01,A11);
- A12 = lhs2.template load<LhsPacket, Aligned>(j-2+LhsPacketSize); palign<2>(A02,A12);
- A13 = lhs3.template load<LhsPacket, Aligned>(j-3+LhsPacketSize); palign<3>(A03,A13);
-
- A00 = lhs0.template load<LhsPacket, Aligned>(j);
- A10 = lhs0.template load<LhsPacket, Aligned>(j+LhsPacketSize);
- T0 = pcj.pmadd(A00, ptmp0, pload<ResPacket>(&res[j]));
- T1 = pcj.pmadd(A10, ptmp0, pload<ResPacket>(&res[j+ResPacketSize]));
-
- T0 = pcj.pmadd(A01, ptmp1, T0);
- A01 = lhs1.template load<LhsPacket, Aligned>(j-1+2*LhsPacketSize); palign<1>(A11,A01);
- T0 = pcj.pmadd(A02, ptmp2, T0);
- A02 = lhs2.template load<LhsPacket, Aligned>(j-2+2*LhsPacketSize); palign<2>(A12,A02);
- T0 = pcj.pmadd(A03, ptmp3, T0);
- pstore(&res[j],T0);
- A03 = lhs3.template load<LhsPacket, Aligned>(j-3+2*LhsPacketSize); palign<3>(A13,A03);
- T1 = pcj.pmadd(A11, ptmp1, T1);
- T1 = pcj.pmadd(A12, ptmp2, T1);
- T1 = pcj.pmadd(A13, ptmp3, T1);
- pstore(&res[j+ResPacketSize],T1);
- }
- }
- for (; j<alignedSize; j+=ResPacketSize)
- _EIGEN_ACCUMULATE_PACKETS(Aligned,Unaligned,Unaligned);
- break;
- }
- default:
- for (Index j = alignedStart; j<alignedSize; j+=ResPacketSize)
- _EIGEN_ACCUMULATE_PACKETS(Unaligned,Unaligned,Unaligned);
- break;
- }
- }
- } // end explicit vectorization
-
- /* process remaining coeffs (or all if there is no explicit vectorization) */
- for (Index j=alignedSize; j<size; ++j)
+ pstoreu(res+i+ResPacketSize*0, pmadd(c0,palpha,ploadu<ResPacket>(res+i+ResPacketSize*0)));
+ pstoreu(res+i+ResPacketSize*1, pmadd(c1,palpha,ploadu<ResPacket>(res+i+ResPacketSize*1)));
+ pstoreu(res+i+ResPacketSize*2, pmadd(c2,palpha,ploadu<ResPacket>(res+i+ResPacketSize*2)));
+ pstoreu(res+i+ResPacketSize*3, pmadd(c3,palpha,ploadu<ResPacket>(res+i+ResPacketSize*3)));
+ pstoreu(res+i+ResPacketSize*4, pmadd(c4,palpha,ploadu<ResPacket>(res+i+ResPacketSize*4)));
+ pstoreu(res+i+ResPacketSize*5, pmadd(c5,palpha,ploadu<ResPacket>(res+i+ResPacketSize*5)));
+ pstoreu(res+i+ResPacketSize*6, pmadd(c6,palpha,ploadu<ResPacket>(res+i+ResPacketSize*6)));
+ pstoreu(res+i+ResPacketSize*7, pmadd(c7,palpha,ploadu<ResPacket>(res+i+ResPacketSize*7)));
+ }
+ if(i<n4)
{
- res[j] = cj.pmadd(lhs0(j), pfirst(ptmp0), res[j]);
- res[j] = cj.pmadd(lhs1(j), pfirst(ptmp1), res[j]);
- res[j] = cj.pmadd(lhs2(j), pfirst(ptmp2), res[j]);
- res[j] = cj.pmadd(lhs3(j), pfirst(ptmp3), res[j]);
+ ResPacket c0 = pset1<ResPacket>(ResScalar(0)),
+ c1 = pset1<ResPacket>(ResScalar(0)),
+ c2 = pset1<ResPacket>(ResScalar(0)),
+ c3 = pset1<ResPacket>(ResScalar(0));
+
+ for(Index j=j2; j<jend; j+=1)
+ {
+ RhsPacket b0 = pset1<RhsPacket>(rhs(j,0));
+ c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*0,j),b0,c0);
+ c1 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*1,j),b0,c1);
+ c2 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*2,j),b0,c2);
+ c3 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*3,j),b0,c3);
+ }
+ pstoreu(res+i+ResPacketSize*0, pmadd(c0,palpha,ploadu<ResPacket>(res+i+ResPacketSize*0)));
+ pstoreu(res+i+ResPacketSize*1, pmadd(c1,palpha,ploadu<ResPacket>(res+i+ResPacketSize*1)));
+ pstoreu(res+i+ResPacketSize*2, pmadd(c2,palpha,ploadu<ResPacket>(res+i+ResPacketSize*2)));
+ pstoreu(res+i+ResPacketSize*3, pmadd(c3,palpha,ploadu<ResPacket>(res+i+ResPacketSize*3)));
+
+ i+=ResPacketSize*4;
+ }
+ if(i<n3)
+ {
+ ResPacket c0 = pset1<ResPacket>(ResScalar(0)),
+ c1 = pset1<ResPacket>(ResScalar(0)),
+ c2 = pset1<ResPacket>(ResScalar(0));
+
+ for(Index j=j2; j<jend; j+=1)
+ {
+ RhsPacket b0 = pset1<RhsPacket>(rhs(j,0));
+ c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*0,j),b0,c0);
+ c1 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*1,j),b0,c1);
+ c2 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*2,j),b0,c2);
+ }
+ pstoreu(res+i+ResPacketSize*0, pmadd(c0,palpha,ploadu<ResPacket>(res+i+ResPacketSize*0)));
+ pstoreu(res+i+ResPacketSize*1, pmadd(c1,palpha,ploadu<ResPacket>(res+i+ResPacketSize*1)));
+ pstoreu(res+i+ResPacketSize*2, pmadd(c2,palpha,ploadu<ResPacket>(res+i+ResPacketSize*2)));
+
+ i+=ResPacketSize*3;
+ }
+ if(i<n2)
+ {
+ ResPacket c0 = pset1<ResPacket>(ResScalar(0)),
+ c1 = pset1<ResPacket>(ResScalar(0));
+
+ for(Index j=j2; j<jend; j+=1)
+ {
+ RhsPacket b0 = pset1<RhsPacket>(rhs(j,0));
+ c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*0,j),b0,c0);
+ c1 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*1,j),b0,c1);
+ }
+ pstoreu(res+i+ResPacketSize*0, pmadd(c0,palpha,ploadu<ResPacket>(res+i+ResPacketSize*0)));
+ pstoreu(res+i+ResPacketSize*1, pmadd(c1,palpha,ploadu<ResPacket>(res+i+ResPacketSize*1)));
+ i+=ResPacketSize*2;
+ }
+ if(i<n1)
+ {
+ ResPacket c0 = pset1<ResPacket>(ResScalar(0));
+ for(Index j=j2; j<jend; j+=1)
+ {
+ RhsPacket b0 = pset1<RhsPacket>(rhs(j,0));
+ c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+0,j),b0,c0);
+ }
+ pstoreu(res+i+ResPacketSize*0, pmadd(c0,palpha,ploadu<ResPacket>(res+i+ResPacketSize*0)));
+ i+=ResPacketSize;
+ }
+ if(HasHalf && i<n_half)
+ {
+ ResPacketHalf c0 = pset1<ResPacketHalf>(ResScalar(0));
+ for(Index j=j2; j<jend; j+=1)
+ {
+ RhsPacketHalf b0 = pset1<RhsPacketHalf>(rhs(j,0));
+ c0 = pcj_half.pmadd(lhs.template load<LhsPacketHalf,LhsAlignment>(i+0,j),b0,c0);
+ }
+ pstoreu(res+i+ResPacketSizeHalf*0, pmadd(c0,palpha_half,ploadu<ResPacketHalf>(res+i+ResPacketSizeHalf*0)));
+ i+=ResPacketSizeHalf;
+ }
+ if(HasQuarter && i<n_quarter)
+ {
+ ResPacketQuarter c0 = pset1<ResPacketQuarter>(ResScalar(0));
+ for(Index j=j2; j<jend; j+=1)
+ {
+ RhsPacketQuarter b0 = pset1<RhsPacketQuarter>(rhs(j,0));
+ c0 = pcj_quarter.pmadd(lhs.template load<LhsPacketQuarter,LhsAlignment>(i+0,j),b0,c0);
+ }
+ pstoreu(res+i+ResPacketSizeQuarter*0, pmadd(c0,palpha_quarter,ploadu<ResPacketQuarter>(res+i+ResPacketSizeQuarter*0)));
+ i+=ResPacketSizeQuarter;
+ }
+ for(;i<rows;++i)
+ {
+ ResScalar c0(0);
+ for(Index j=j2; j<jend; j+=1)
+ c0 += cj.pmul(lhs(i,j), rhs(j,0));
+ res[i] += alpha*c0;
}
}
-
- // process remaining first and last columns (at most columnsAtOnce-1)
- Index end = cols;
- Index start = columnBound;
- do
- {
- for (Index k=start; k<end; ++k)
- {
- RhsPacket ptmp0 = pset1<RhsPacket>(alpha*rhs(k, 0));
- const LhsScalars lhs0 = lhs.getVectorMapper(0, k);
-
- if (Vectorizable)
- {
- /* explicit vectorization */
- // process first unaligned result's coeffs
- for (Index j=0; j<alignedStart; ++j)
- res[j] += cj.pmul(lhs0(j), pfirst(ptmp0));
- // process aligned result's coeffs
- if (lhs0.template aligned<LhsPacket>(alignedStart))
- for (Index i = alignedStart;i<alignedSize;i+=ResPacketSize)
- pstore(&res[i], pcj.pmadd(lhs0.template load<LhsPacket, Aligned>(i), ptmp0, pload<ResPacket>(&res[i])));
- else
- for (Index i = alignedStart;i<alignedSize;i+=ResPacketSize)
- pstore(&res[i], pcj.pmadd(lhs0.template load<LhsPacket, Unaligned>(i), ptmp0, pload<ResPacket>(&res[i])));
- }
-
- // process remaining scalars (or all if no explicit vectorization)
- for (Index i=alignedSize; i<size; ++i)
- res[i] += cj.pmul(lhs0(i), pfirst(ptmp0));
- }
- if (skipColumns)
- {
- start = 0;
- end = skipColumns;
- skipColumns = 0;
- }
- else
- break;
- } while(Vectorizable);
- #undef _EIGEN_ACCUMULATE_PACKETS
}
/* Optimized row-major matrix * vector product:
- * This algorithm processes 4 rows at onces that allows to both reduce
+ * This algorithm processes 4 rows at once, which allows us to both reduce
* the number of load/stores of the result by a factor 4 and to reduce
* the instruction dependency. Moreover, we know that all bands have the
* same alignment pattern.
@@ -334,25 +297,25 @@
template<typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar, typename RhsMapper, bool ConjugateRhs, int Version>
struct general_matrix_vector_product<Index,LhsScalar,LhsMapper,RowMajor,ConjugateLhs,RhsScalar,RhsMapper,ConjugateRhs,Version>
{
-typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
+ typedef gemv_traits<LhsScalar,RhsScalar> Traits;
+ typedef gemv_traits<LhsScalar,RhsScalar,GEMVPacketHalf> HalfTraits;
+ typedef gemv_traits<LhsScalar,RhsScalar,GEMVPacketQuarter> QuarterTraits;
-enum {
- Vectorizable = packet_traits<LhsScalar>::Vectorizable && packet_traits<RhsScalar>::Vectorizable
- && int(packet_traits<LhsScalar>::size)==int(packet_traits<RhsScalar>::size),
- LhsPacketSize = Vectorizable ? packet_traits<LhsScalar>::size : 1,
- RhsPacketSize = Vectorizable ? packet_traits<RhsScalar>::size : 1,
- ResPacketSize = Vectorizable ? packet_traits<ResScalar>::size : 1
-};
+ typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
-typedef typename packet_traits<LhsScalar>::type _LhsPacket;
-typedef typename packet_traits<RhsScalar>::type _RhsPacket;
-typedef typename packet_traits<ResScalar>::type _ResPacket;
+ typedef typename Traits::LhsPacket LhsPacket;
+ typedef typename Traits::RhsPacket RhsPacket;
+ typedef typename Traits::ResPacket ResPacket;
-typedef typename conditional<Vectorizable,_LhsPacket,LhsScalar>::type LhsPacket;
-typedef typename conditional<Vectorizable,_RhsPacket,RhsScalar>::type RhsPacket;
-typedef typename conditional<Vectorizable,_ResPacket,ResScalar>::type ResPacket;
+ typedef typename HalfTraits::LhsPacket LhsPacketHalf;
+ typedef typename HalfTraits::RhsPacket RhsPacketHalf;
+ typedef typename HalfTraits::ResPacket ResPacketHalf;
-EIGEN_DONT_INLINE static void run(
+ typedef typename QuarterTraits::LhsPacket LhsPacketQuarter;
+ typedef typename QuarterTraits::RhsPacket RhsPacketQuarter;
+ typedef typename QuarterTraits::ResPacket ResPacketQuarter;
+
+EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run(
Index rows, Index cols,
const LhsMapper& lhs,
const RhsMapper& rhs,
@@ -361,255 +324,191 @@
};
template<typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar, typename RhsMapper, bool ConjugateRhs, int Version>
-EIGEN_DONT_INLINE void general_matrix_vector_product<Index,LhsScalar,LhsMapper,RowMajor,ConjugateLhs,RhsScalar,RhsMapper,ConjugateRhs,Version>::run(
+EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void general_matrix_vector_product<Index,LhsScalar,LhsMapper,RowMajor,ConjugateLhs,RhsScalar,RhsMapper,ConjugateRhs,Version>::run(
Index rows, Index cols,
- const LhsMapper& lhs,
+ const LhsMapper& alhs,
const RhsMapper& rhs,
ResScalar* res, Index resIncr,
ResScalar alpha)
{
+ // The following copy tells the compiler that lhs's attributes are not modified outside this function
+ // This helps GCC to generate propoer code.
+ LhsMapper lhs(alhs);
+
eigen_internal_assert(rhs.stride()==1);
-
- #ifdef _EIGEN_ACCUMULATE_PACKETS
- #error _EIGEN_ACCUMULATE_PACKETS has already been defined
- #endif
-
- #define _EIGEN_ACCUMULATE_PACKETS(Alignment0,Alignment13,Alignment2) {\
- RhsPacket b = rhs.getVectorMapper(j, 0).template load<RhsPacket, Aligned>(0); \
- ptmp0 = pcj.pmadd(lhs0.template load<LhsPacket, Alignment0>(j), b, ptmp0); \
- ptmp1 = pcj.pmadd(lhs1.template load<LhsPacket, Alignment13>(j), b, ptmp1); \
- ptmp2 = pcj.pmadd(lhs2.template load<LhsPacket, Alignment2>(j), b, ptmp2); \
- ptmp3 = pcj.pmadd(lhs3.template load<LhsPacket, Alignment13>(j), b, ptmp3); }
-
conj_helper<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs> cj;
conj_helper<LhsPacket,RhsPacket,ConjugateLhs,ConjugateRhs> pcj;
+ conj_helper<LhsPacketHalf,RhsPacketHalf,ConjugateLhs,ConjugateRhs> pcj_half;
+ conj_helper<LhsPacketQuarter,RhsPacketQuarter,ConjugateLhs,ConjugateRhs> pcj_quarter;
- typedef typename LhsMapper::VectorMapper LhsScalars;
+ // TODO: fine tune the following heuristic. The rationale is that if the matrix is very large,
+ // processing 8 rows at once might be counter productive wrt cache.
+ const Index n8 = lhs.stride()*sizeof(LhsScalar)>32000 ? 0 : rows-7;
+ const Index n4 = rows-3;
+ const Index n2 = rows-1;
- enum { AllAligned=0, EvenAligned=1, FirstAligned=2, NoneAligned=3 };
- const Index rowsAtOnce = 4;
- const Index peels = 2;
- const Index RhsPacketAlignedMask = RhsPacketSize-1;
- const Index LhsPacketAlignedMask = LhsPacketSize-1;
- const Index depth = cols;
- const Index lhsStride = lhs.stride();
+ // TODO: for padded aligned inputs, we could enable aligned reads
+ enum { LhsAlignment = Unaligned,
+ ResPacketSize = Traits::ResPacketSize,
+ ResPacketSizeHalf = HalfTraits::ResPacketSize,
+ ResPacketSizeQuarter = QuarterTraits::ResPacketSize,
+ LhsPacketSize = Traits::LhsPacketSize,
+ LhsPacketSizeHalf = HalfTraits::LhsPacketSize,
+ LhsPacketSizeQuarter = QuarterTraits::LhsPacketSize,
+ HasHalf = (int)ResPacketSizeHalf < (int)ResPacketSize,
+ HasQuarter = (int)ResPacketSizeQuarter < (int)ResPacketSizeHalf
+ };
- // How many coeffs of the result do we have to skip to be aligned.
- // Here we assume data are at least aligned on the base scalar type
- // if that's not the case then vectorization is discarded, see below.
- Index alignedStart = rhs.firstAligned(depth);
- Index alignedSize = RhsPacketSize>1 ? alignedStart + ((depth-alignedStart) & ~RhsPacketAlignedMask) : 0;
- const Index peeledSize = alignedSize - RhsPacketSize*peels - RhsPacketSize + 1;
-
- const Index alignmentStep = LhsPacketSize>1 ? (LhsPacketSize - lhsStride % LhsPacketSize) & LhsPacketAlignedMask : 0;
- Index alignmentPattern = alignmentStep==0 ? AllAligned
- : alignmentStep==(LhsPacketSize/2) ? EvenAligned
- : FirstAligned;
-
- // we cannot assume the first element is aligned because of sub-matrices
- const Index lhsAlignmentOffset = lhs.firstAligned(depth);
- const Index rhsAlignmentOffset = rhs.firstAligned(rows);
-
- // find how many rows do we have to skip to be aligned with rhs (if possible)
- Index skipRows = 0;
- // if the data cannot be aligned (TODO add some compile time tests when possible, e.g. for floats)
- if( (sizeof(LhsScalar)!=sizeof(RhsScalar)) ||
- (lhsAlignmentOffset < 0) || (lhsAlignmentOffset == depth) ||
- (rhsAlignmentOffset < 0) || (rhsAlignmentOffset == rows) )
+ Index i=0;
+ for(; i<n8; i+=8)
{
- alignedSize = 0;
- alignedStart = 0;
- alignmentPattern = NoneAligned;
- }
- else if(LhsPacketSize > 4)
- {
- // TODO: extend the code to support aligned loads whenever possible when LhsPacketSize > 4.
- alignmentPattern = NoneAligned;
- }
- else if (LhsPacketSize>1)
- {
- // eigen_internal_assert(size_t(firstLhs+lhsAlignmentOffset)%sizeof(LhsPacket)==0 || depth<LhsPacketSize);
+ ResPacket c0 = pset1<ResPacket>(ResScalar(0)),
+ c1 = pset1<ResPacket>(ResScalar(0)),
+ c2 = pset1<ResPacket>(ResScalar(0)),
+ c3 = pset1<ResPacket>(ResScalar(0)),
+ c4 = pset1<ResPacket>(ResScalar(0)),
+ c5 = pset1<ResPacket>(ResScalar(0)),
+ c6 = pset1<ResPacket>(ResScalar(0)),
+ c7 = pset1<ResPacket>(ResScalar(0));
- while (skipRows<LhsPacketSize &&
- alignedStart != ((lhsAlignmentOffset + alignmentStep*skipRows)%LhsPacketSize))
- ++skipRows;
- if (skipRows==LhsPacketSize)
+ Index j=0;
+ for(; j+LhsPacketSize<=cols; j+=LhsPacketSize)
{
- // nothing can be aligned, no need to skip any column
- alignmentPattern = NoneAligned;
- skipRows = 0;
+ RhsPacket b0 = rhs.template load<RhsPacket, Unaligned>(j,0);
+
+ c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+0,j),b0,c0);
+ c1 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+1,j),b0,c1);
+ c2 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+2,j),b0,c2);
+ c3 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+3,j),b0,c3);
+ c4 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+4,j),b0,c4);
+ c5 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+5,j),b0,c5);
+ c6 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+6,j),b0,c6);
+ c7 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+7,j),b0,c7);
}
- else
+ ResScalar cc0 = predux(c0);
+ ResScalar cc1 = predux(c1);
+ ResScalar cc2 = predux(c2);
+ ResScalar cc3 = predux(c3);
+ ResScalar cc4 = predux(c4);
+ ResScalar cc5 = predux(c5);
+ ResScalar cc6 = predux(c6);
+ ResScalar cc7 = predux(c7);
+ for(; j<cols; ++j)
{
- skipRows = (std::min)(skipRows,Index(rows));
- // note that the skiped columns are processed later.
+ RhsScalar b0 = rhs(j,0);
+
+ cc0 += cj.pmul(lhs(i+0,j), b0);
+ cc1 += cj.pmul(lhs(i+1,j), b0);
+ cc2 += cj.pmul(lhs(i+2,j), b0);
+ cc3 += cj.pmul(lhs(i+3,j), b0);
+ cc4 += cj.pmul(lhs(i+4,j), b0);
+ cc5 += cj.pmul(lhs(i+5,j), b0);
+ cc6 += cj.pmul(lhs(i+6,j), b0);
+ cc7 += cj.pmul(lhs(i+7,j), b0);
}
- /* eigen_internal_assert( alignmentPattern==NoneAligned
- || LhsPacketSize==1
- || (skipRows + rowsAtOnce >= rows)
- || LhsPacketSize > depth
- || (size_t(firstLhs+alignedStart+lhsStride*skipRows)%sizeof(LhsPacket))==0);*/
+ res[(i+0)*resIncr] += alpha*cc0;
+ res[(i+1)*resIncr] += alpha*cc1;
+ res[(i+2)*resIncr] += alpha*cc2;
+ res[(i+3)*resIncr] += alpha*cc3;
+ res[(i+4)*resIncr] += alpha*cc4;
+ res[(i+5)*resIncr] += alpha*cc5;
+ res[(i+6)*resIncr] += alpha*cc6;
+ res[(i+7)*resIncr] += alpha*cc7;
}
- else if(Vectorizable)
+ for(; i<n4; i+=4)
{
- alignedStart = 0;
- alignedSize = depth;
- alignmentPattern = AllAligned;
- }
+ ResPacket c0 = pset1<ResPacket>(ResScalar(0)),
+ c1 = pset1<ResPacket>(ResScalar(0)),
+ c2 = pset1<ResPacket>(ResScalar(0)),
+ c3 = pset1<ResPacket>(ResScalar(0));
- const Index offset1 = (alignmentPattern==FirstAligned && alignmentStep==1)?3:1;
- const Index offset3 = (alignmentPattern==FirstAligned && alignmentStep==1)?1:3;
-
- Index rowBound = ((rows-skipRows)/rowsAtOnce)*rowsAtOnce + skipRows;
- for (Index i=skipRows; i<rowBound; i+=rowsAtOnce)
- {
- // FIXME: what is the purpose of this EIGEN_ALIGN_DEFAULT ??
- EIGEN_ALIGN_MAX ResScalar tmp0 = ResScalar(0);
- ResScalar tmp1 = ResScalar(0), tmp2 = ResScalar(0), tmp3 = ResScalar(0);
-
- // this helps the compiler generating good binary code
- const LhsScalars lhs0 = lhs.getVectorMapper(i+0, 0), lhs1 = lhs.getVectorMapper(i+offset1, 0),
- lhs2 = lhs.getVectorMapper(i+2, 0), lhs3 = lhs.getVectorMapper(i+offset3, 0);
-
- if (Vectorizable)
+ Index j=0;
+ for(; j+LhsPacketSize<=cols; j+=LhsPacketSize)
{
- /* explicit vectorization */
- ResPacket ptmp0 = pset1<ResPacket>(ResScalar(0)), ptmp1 = pset1<ResPacket>(ResScalar(0)),
- ptmp2 = pset1<ResPacket>(ResScalar(0)), ptmp3 = pset1<ResPacket>(ResScalar(0));
+ RhsPacket b0 = rhs.template load<RhsPacket, Unaligned>(j,0);
- // process initial unaligned coeffs
- // FIXME this loop get vectorized by the compiler !
- for (Index j=0; j<alignedStart; ++j)
- {
- RhsScalar b = rhs(j, 0);
- tmp0 += cj.pmul(lhs0(j),b); tmp1 += cj.pmul(lhs1(j),b);
- tmp2 += cj.pmul(lhs2(j),b); tmp3 += cj.pmul(lhs3(j),b);
- }
+ c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+0,j),b0,c0);
+ c1 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+1,j),b0,c1);
+ c2 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+2,j),b0,c2);
+ c3 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+3,j),b0,c3);
+ }
+ ResScalar cc0 = predux(c0);
+ ResScalar cc1 = predux(c1);
+ ResScalar cc2 = predux(c2);
+ ResScalar cc3 = predux(c3);
+ for(; j<cols; ++j)
+ {
+ RhsScalar b0 = rhs(j,0);
- if (alignedSize>alignedStart)
- {
- switch(alignmentPattern)
+ cc0 += cj.pmul(lhs(i+0,j), b0);
+ cc1 += cj.pmul(lhs(i+1,j), b0);
+ cc2 += cj.pmul(lhs(i+2,j), b0);
+ cc3 += cj.pmul(lhs(i+3,j), b0);
+ }
+ res[(i+0)*resIncr] += alpha*cc0;
+ res[(i+1)*resIncr] += alpha*cc1;
+ res[(i+2)*resIncr] += alpha*cc2;
+ res[(i+3)*resIncr] += alpha*cc3;
+ }
+ for(; i<n2; i+=2)
+ {
+ ResPacket c0 = pset1<ResPacket>(ResScalar(0)),
+ c1 = pset1<ResPacket>(ResScalar(0));
+
+ Index j=0;
+ for(; j+LhsPacketSize<=cols; j+=LhsPacketSize)
+ {
+ RhsPacket b0 = rhs.template load<RhsPacket, Unaligned>(j,0);
+
+ c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+0,j),b0,c0);
+ c1 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+1,j),b0,c1);
+ }
+ ResScalar cc0 = predux(c0);
+ ResScalar cc1 = predux(c1);
+ for(; j<cols; ++j)
+ {
+ RhsScalar b0 = rhs(j,0);
+
+ cc0 += cj.pmul(lhs(i+0,j), b0);
+ cc1 += cj.pmul(lhs(i+1,j), b0);
+ }
+ res[(i+0)*resIncr] += alpha*cc0;
+ res[(i+1)*resIncr] += alpha*cc1;
+ }
+ for(; i<rows; ++i)
+ {
+ ResPacket c0 = pset1<ResPacket>(ResScalar(0));
+ ResPacketHalf c0_h = pset1<ResPacketHalf>(ResScalar(0));
+ ResPacketQuarter c0_q = pset1<ResPacketQuarter>(ResScalar(0));
+ Index j=0;
+ for(; j+LhsPacketSize<=cols; j+=LhsPacketSize)
+ {
+ RhsPacket b0 = rhs.template load<RhsPacket,Unaligned>(j,0);
+ c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i,j),b0,c0);
+ }
+ ResScalar cc0 = predux(c0);
+ if (HasHalf) {
+ for(; j+LhsPacketSizeHalf<=cols; j+=LhsPacketSizeHalf)
{
- case AllAligned:
- for (Index j = alignedStart; j<alignedSize; j+=RhsPacketSize)
- _EIGEN_ACCUMULATE_PACKETS(Aligned,Aligned,Aligned);
- break;
- case EvenAligned:
- for (Index j = alignedStart; j<alignedSize; j+=RhsPacketSize)
- _EIGEN_ACCUMULATE_PACKETS(Aligned,Unaligned,Aligned);
- break;
- case FirstAligned:
- {
- Index j = alignedStart;
- if (peels>1)
- {
- /* Here we proccess 4 rows with with two peeled iterations to hide
- * the overhead of unaligned loads. Moreover unaligned loads are handled
- * using special shift/move operations between the two aligned packets
- * overlaping the desired unaligned packet. This is *much* more efficient
- * than basic unaligned loads.
- */
- LhsPacket A01, A02, A03, A11, A12, A13;
- A01 = lhs1.template load<LhsPacket, Aligned>(alignedStart-1);
- A02 = lhs2.template load<LhsPacket, Aligned>(alignedStart-2);
- A03 = lhs3.template load<LhsPacket, Aligned>(alignedStart-3);
-
- for (; j<peeledSize; j+=peels*RhsPacketSize)
- {
- RhsPacket b = rhs.getVectorMapper(j, 0).template load<RhsPacket, Aligned>(0);
- A11 = lhs1.template load<LhsPacket, Aligned>(j-1+LhsPacketSize); palign<1>(A01,A11);
- A12 = lhs2.template load<LhsPacket, Aligned>(j-2+LhsPacketSize); palign<2>(A02,A12);
- A13 = lhs3.template load<LhsPacket, Aligned>(j-3+LhsPacketSize); palign<3>(A03,A13);
-
- ptmp0 = pcj.pmadd(lhs0.template load<LhsPacket, Aligned>(j), b, ptmp0);
- ptmp1 = pcj.pmadd(A01, b, ptmp1);
- A01 = lhs1.template load<LhsPacket, Aligned>(j-1+2*LhsPacketSize); palign<1>(A11,A01);
- ptmp2 = pcj.pmadd(A02, b, ptmp2);
- A02 = lhs2.template load<LhsPacket, Aligned>(j-2+2*LhsPacketSize); palign<2>(A12,A02);
- ptmp3 = pcj.pmadd(A03, b, ptmp3);
- A03 = lhs3.template load<LhsPacket, Aligned>(j-3+2*LhsPacketSize); palign<3>(A13,A03);
-
- b = rhs.getVectorMapper(j+RhsPacketSize, 0).template load<RhsPacket, Aligned>(0);
- ptmp0 = pcj.pmadd(lhs0.template load<LhsPacket, Aligned>(j+LhsPacketSize), b, ptmp0);
- ptmp1 = pcj.pmadd(A11, b, ptmp1);
- ptmp2 = pcj.pmadd(A12, b, ptmp2);
- ptmp3 = pcj.pmadd(A13, b, ptmp3);
- }
- }
- for (; j<alignedSize; j+=RhsPacketSize)
- _EIGEN_ACCUMULATE_PACKETS(Aligned,Unaligned,Unaligned);
- break;
- }
- default:
- for (Index j = alignedStart; j<alignedSize; j+=RhsPacketSize)
- _EIGEN_ACCUMULATE_PACKETS(Unaligned,Unaligned,Unaligned);
- break;
+ RhsPacketHalf b0 = rhs.template load<RhsPacketHalf,Unaligned>(j,0);
+ c0_h = pcj_half.pmadd(lhs.template load<LhsPacketHalf,LhsAlignment>(i,j),b0,c0_h);
}
- tmp0 += predux(ptmp0);
- tmp1 += predux(ptmp1);
- tmp2 += predux(ptmp2);
- tmp3 += predux(ptmp3);
- }
- } // end explicit vectorization
-
- // process remaining coeffs (or all if no explicit vectorization)
- // FIXME this loop get vectorized by the compiler !
- for (Index j=alignedSize; j<depth; ++j)
- {
- RhsScalar b = rhs(j, 0);
- tmp0 += cj.pmul(lhs0(j),b); tmp1 += cj.pmul(lhs1(j),b);
- tmp2 += cj.pmul(lhs2(j),b); tmp3 += cj.pmul(lhs3(j),b);
+ cc0 += predux(c0_h);
}
- res[i*resIncr] += alpha*tmp0;
- res[(i+offset1)*resIncr] += alpha*tmp1;
- res[(i+2)*resIncr] += alpha*tmp2;
- res[(i+offset3)*resIncr] += alpha*tmp3;
+ if (HasQuarter) {
+ for(; j+LhsPacketSizeQuarter<=cols; j+=LhsPacketSizeQuarter)
+ {
+ RhsPacketQuarter b0 = rhs.template load<RhsPacketQuarter,Unaligned>(j,0);
+ c0_q = pcj_quarter.pmadd(lhs.template load<LhsPacketQuarter,LhsAlignment>(i,j),b0,c0_q);
+ }
+ cc0 += predux(c0_q);
+ }
+ for(; j<cols; ++j)
+ {
+ cc0 += cj.pmul(lhs(i,j), rhs(j,0));
+ }
+ res[i*resIncr] += alpha*cc0;
}
-
- // process remaining first and last rows (at most columnsAtOnce-1)
- Index end = rows;
- Index start = rowBound;
- do
- {
- for (Index i=start; i<end; ++i)
- {
- EIGEN_ALIGN_MAX ResScalar tmp0 = ResScalar(0);
- ResPacket ptmp0 = pset1<ResPacket>(tmp0);
- const LhsScalars lhs0 = lhs.getVectorMapper(i, 0);
- // process first unaligned result's coeffs
- // FIXME this loop get vectorized by the compiler !
- for (Index j=0; j<alignedStart; ++j)
- tmp0 += cj.pmul(lhs0(j), rhs(j, 0));
-
- if (alignedSize>alignedStart)
- {
- // process aligned rhs coeffs
- if (lhs0.template aligned<LhsPacket>(alignedStart))
- for (Index j = alignedStart;j<alignedSize;j+=RhsPacketSize)
- ptmp0 = pcj.pmadd(lhs0.template load<LhsPacket, Aligned>(j), rhs.getVectorMapper(j, 0).template load<RhsPacket, Aligned>(0), ptmp0);
- else
- for (Index j = alignedStart;j<alignedSize;j+=RhsPacketSize)
- ptmp0 = pcj.pmadd(lhs0.template load<LhsPacket, Unaligned>(j), rhs.getVectorMapper(j, 0).template load<RhsPacket, Aligned>(0), ptmp0);
- tmp0 += predux(ptmp0);
- }
-
- // process remaining scalars
- // FIXME this loop get vectorized by the compiler !
- for (Index j=alignedSize; j<depth; ++j)
- tmp0 += cj.pmul(lhs0(j), rhs(j, 0));
- res[i*resIncr] += alpha*tmp0;
- }
- if (skipRows)
- {
- start = 0;
- end = skipRows;
- skipRows = 0;
- }
- else
- break;
- } while(Vectorizable);
-
- #undef _EIGEN_ACCUMULATE_PACKETS
}
} // end namespace internal
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/products/Parallelizer.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/products/Parallelizer.h
index c2f084c..8f91879 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/products/Parallelizer.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/products/Parallelizer.h
@@ -10,6 +10,10 @@
#ifndef EIGEN_PARALLELIZER_H
#define EIGEN_PARALLELIZER_H
+#if EIGEN_HAS_CXX11_ATOMIC
+#include <atomic>
+#endif
+
namespace Eigen {
namespace internal {
@@ -17,7 +21,8 @@
/** \internal */
inline void manage_multi_threading(Action action, int* v)
{
- static EIGEN_UNUSED int m_maxThreads = -1;
+ static int m_maxThreads = -1;
+ EIGEN_UNUSED_VARIABLE(m_maxThreads)
if(action==SetAction)
{
@@ -75,8 +80,17 @@
{
GemmParallelInfo() : sync(-1), users(0), lhs_start(0), lhs_length(0) {}
+ // volatile is not enough on all architectures (see bug 1572)
+ // to guarantee that when thread A says to thread B that it is
+ // done with packing a block, then all writes have been really
+ // carried out... C++11 memory model+atomic guarantees this.
+#if EIGEN_HAS_CXX11_ATOMIC
+ std::atomic<Index> sync;
+ std::atomic<int> users;
+#else
Index volatile sync;
int volatile users;
+#endif
Index lhs_start;
Index lhs_length;
@@ -87,11 +101,14 @@
{
// TODO when EIGEN_USE_BLAS is defined,
// we should still enable OMP for other scalar types
-#if !(defined (EIGEN_HAS_OPENMP)) || defined (EIGEN_USE_BLAS)
+ // Without C++11, we have to disable GEMM's parallelization on
+ // non x86 architectures because there volatile is not enough for our purpose.
+ // See bug 1572.
+#if (! defined(EIGEN_HAS_OPENMP)) || defined(EIGEN_USE_BLAS) || ((!EIGEN_HAS_CXX11_ATOMIC) && !(EIGEN_ARCH_i386_OR_x86_64))
// FIXME the transpose variable is only needed to properly split
// the matrix product when multithreading is enabled. This is a temporary
// fix to support row-major destination matrices. This whole
- // parallelizer mechanism has to be redisigned anyway.
+ // parallelizer mechanism has to be redesigned anyway.
EIGEN_UNUSED_VARIABLE(depth);
EIGEN_UNUSED_VARIABLE(transpose);
func(0,rows, 0,cols);
@@ -112,12 +129,12 @@
double work = static_cast<double>(rows) * static_cast<double>(cols) *
static_cast<double>(depth);
double kMinTaskSize = 50000; // FIXME improve this heuristic.
- pb_max_threads = std::max<Index>(1, std::min<Index>(pb_max_threads, work / kMinTaskSize));
+ pb_max_threads = std::max<Index>(1, std::min<Index>(pb_max_threads, static_cast<Index>( work / kMinTaskSize ) ));
// compute the number of threads we are going to use
Index threads = std::min<Index>(nbThreads(), pb_max_threads);
- // if multi-threading is explicitely disabled, not useful, or if we already are in a parallel session,
+ // if multi-threading is explicitly disabled, not useful, or if we already are in a parallel session,
// then abort multi-threading
// FIXME omp_get_num_threads()>1 only works for openmp, what if the user does not use openmp?
if((!Condition) || (threads==1) || (omp_get_num_threads()>1))
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/products/SelfadjointMatrixMatrix.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/products/SelfadjointMatrixMatrix.h
index da6f82a..33ecf10 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/products/SelfadjointMatrixMatrix.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/products/SelfadjointMatrixMatrix.h
@@ -45,14 +45,23 @@
}
void operator()(Scalar* blockA, const Scalar* _lhs, Index lhsStride, Index cols, Index rows)
{
- enum { PacketSize = packet_traits<Scalar>::size };
+ typedef typename unpacket_traits<typename packet_traits<Scalar>::type>::half HalfPacket;
+ typedef typename unpacket_traits<typename unpacket_traits<typename packet_traits<Scalar>::type>::half>::half QuarterPacket;
+ enum { PacketSize = packet_traits<Scalar>::size,
+ HalfPacketSize = unpacket_traits<HalfPacket>::size,
+ QuarterPacketSize = unpacket_traits<QuarterPacket>::size,
+ HasHalf = (int)HalfPacketSize < (int)PacketSize,
+ HasQuarter = (int)QuarterPacketSize < (int)HalfPacketSize};
+
const_blas_data_mapper<Scalar,Index,StorageOrder> lhs(_lhs,lhsStride);
Index count = 0;
//Index peeled_mc3 = (rows/Pack1)*Pack1;
const Index peeled_mc3 = Pack1>=3*PacketSize ? (rows/(3*PacketSize))*(3*PacketSize) : 0;
const Index peeled_mc2 = Pack1>=2*PacketSize ? peeled_mc3+((rows-peeled_mc3)/(2*PacketSize))*(2*PacketSize) : 0;
- const Index peeled_mc1 = Pack1>=1*PacketSize ? (rows/(1*PacketSize))*(1*PacketSize) : 0;
+ const Index peeled_mc1 = Pack1>=1*PacketSize ? peeled_mc2+((rows-peeled_mc2)/(1*PacketSize))*(1*PacketSize) : 0;
+ const Index peeled_mc_half = Pack1>=HalfPacketSize ? peeled_mc1+((rows-peeled_mc1)/(HalfPacketSize))*(HalfPacketSize) : 0;
+ const Index peeled_mc_quarter = Pack1>=QuarterPacketSize ? peeled_mc_half+((rows-peeled_mc_half)/(QuarterPacketSize))*(QuarterPacketSize) : 0;
if(Pack1>=3*PacketSize)
for(Index i=0; i<peeled_mc3; i+=3*PacketSize)
@@ -66,8 +75,16 @@
for(Index i=peeled_mc2; i<peeled_mc1; i+=1*PacketSize)
pack<1*PacketSize>(blockA, lhs, cols, i, count);
+ if(HasHalf && Pack1>=HalfPacketSize)
+ for(Index i=peeled_mc1; i<peeled_mc_half; i+=HalfPacketSize)
+ pack<HalfPacketSize>(blockA, lhs, cols, i, count);
+
+ if(HasQuarter && Pack1>=QuarterPacketSize)
+ for(Index i=peeled_mc_half; i<peeled_mc_quarter; i+=QuarterPacketSize)
+ pack<QuarterPacketSize>(blockA, lhs, cols, i, count);
+
// do the same with mr==1
- for(Index i=peeled_mc1; i<rows; i++)
+ for(Index i=peeled_mc_quarter; i<rows; i++)
{
for(Index k=0; k<i; k++)
blockA[count++] = lhs(i, k); // normal
@@ -277,20 +294,21 @@
template <typename Scalar, typename Index,
int LhsStorageOrder, bool LhsSelfAdjoint, bool ConjugateLhs,
int RhsStorageOrder, bool RhsSelfAdjoint, bool ConjugateRhs,
- int ResStorageOrder>
+ int ResStorageOrder, int ResInnerStride>
struct product_selfadjoint_matrix;
template <typename Scalar, typename Index,
int LhsStorageOrder, bool LhsSelfAdjoint, bool ConjugateLhs,
- int RhsStorageOrder, bool RhsSelfAdjoint, bool ConjugateRhs>
-struct product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,LhsSelfAdjoint,ConjugateLhs, RhsStorageOrder,RhsSelfAdjoint,ConjugateRhs,RowMajor>
+ int RhsStorageOrder, bool RhsSelfAdjoint, bool ConjugateRhs,
+ int ResInnerStride>
+struct product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,LhsSelfAdjoint,ConjugateLhs, RhsStorageOrder,RhsSelfAdjoint,ConjugateRhs,RowMajor,ResInnerStride>
{
static EIGEN_STRONG_INLINE void run(
Index rows, Index cols,
const Scalar* lhs, Index lhsStride,
const Scalar* rhs, Index rhsStride,
- Scalar* res, Index resStride,
+ Scalar* res, Index resIncr, Index resStride,
const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking)
{
product_selfadjoint_matrix<Scalar, Index,
@@ -298,33 +316,35 @@
RhsSelfAdjoint, NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(RhsSelfAdjoint,ConjugateRhs),
EIGEN_LOGICAL_XOR(LhsSelfAdjoint,LhsStorageOrder==RowMajor) ? ColMajor : RowMajor,
LhsSelfAdjoint, NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(LhsSelfAdjoint,ConjugateLhs),
- ColMajor>
- ::run(cols, rows, rhs, rhsStride, lhs, lhsStride, res, resStride, alpha, blocking);
+ ColMajor,ResInnerStride>
+ ::run(cols, rows, rhs, rhsStride, lhs, lhsStride, res, resIncr, resStride, alpha, blocking);
}
};
template <typename Scalar, typename Index,
int LhsStorageOrder, bool ConjugateLhs,
- int RhsStorageOrder, bool ConjugateRhs>
-struct product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,true,ConjugateLhs, RhsStorageOrder,false,ConjugateRhs,ColMajor>
+ int RhsStorageOrder, bool ConjugateRhs,
+ int ResInnerStride>
+struct product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,true,ConjugateLhs, RhsStorageOrder,false,ConjugateRhs,ColMajor,ResInnerStride>
{
static EIGEN_DONT_INLINE void run(
Index rows, Index cols,
const Scalar* _lhs, Index lhsStride,
const Scalar* _rhs, Index rhsStride,
- Scalar* res, Index resStride,
+ Scalar* res, Index resIncr, Index resStride,
const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking);
};
template <typename Scalar, typename Index,
int LhsStorageOrder, bool ConjugateLhs,
- int RhsStorageOrder, bool ConjugateRhs>
-EIGEN_DONT_INLINE void product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,true,ConjugateLhs, RhsStorageOrder,false,ConjugateRhs,ColMajor>::run(
+ int RhsStorageOrder, bool ConjugateRhs,
+ int ResInnerStride>
+EIGEN_DONT_INLINE void product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,true,ConjugateLhs, RhsStorageOrder,false,ConjugateRhs,ColMajor,ResInnerStride>::run(
Index rows, Index cols,
const Scalar* _lhs, Index lhsStride,
const Scalar* _rhs, Index rhsStride,
- Scalar* _res, Index resStride,
+ Scalar* _res, Index resIncr, Index resStride,
const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking)
{
Index size = rows;
@@ -334,11 +354,11 @@
typedef const_blas_data_mapper<Scalar, Index, LhsStorageOrder> LhsMapper;
typedef const_blas_data_mapper<Scalar, Index, (LhsStorageOrder == RowMajor) ? ColMajor : RowMajor> LhsTransposeMapper;
typedef const_blas_data_mapper<Scalar, Index, RhsStorageOrder> RhsMapper;
- typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor> ResMapper;
+ typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper;
LhsMapper lhs(_lhs,lhsStride);
LhsTransposeMapper lhs_transpose(_lhs,lhsStride);
RhsMapper rhs(_rhs,rhsStride);
- ResMapper res(_res, resStride);
+ ResMapper res(_res, resStride, resIncr);
Index kc = blocking.kc(); // cache block size along the K direction
Index mc = (std::min)(rows,blocking.mc()); // cache block size along the M direction
@@ -352,7 +372,7 @@
gebp_kernel<Scalar, Scalar, Index, ResMapper, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs> gebp_kernel;
symm_pack_lhs<Scalar, Index, Traits::mr, Traits::LhsProgress, LhsStorageOrder> pack_lhs;
gemm_pack_rhs<Scalar, Index, RhsMapper, Traits::nr,RhsStorageOrder> pack_rhs;
- gemm_pack_lhs<Scalar, Index, LhsTransposeMapper, Traits::mr, Traits::LhsProgress, LhsStorageOrder==RowMajor?ColMajor:RowMajor, true> pack_lhs_transposed;
+ gemm_pack_lhs<Scalar, Index, LhsTransposeMapper, Traits::mr, Traits::LhsProgress, typename Traits::LhsPacket4Packing, LhsStorageOrder==RowMajor?ColMajor:RowMajor, true> pack_lhs_transposed;
for(Index k2=0; k2<size; k2+=kc)
{
@@ -387,7 +407,7 @@
for(Index i2=k2+kc; i2<size; i2+=mc)
{
const Index actual_mc = (std::min)(i2+mc,size)-i2;
- gemm_pack_lhs<Scalar, Index, LhsMapper, Traits::mr, Traits::LhsProgress, LhsStorageOrder,false>()
+ gemm_pack_lhs<Scalar, Index, LhsMapper, Traits::mr, Traits::LhsProgress, typename Traits::LhsPacket4Packing, LhsStorageOrder,false>()
(blockA, lhs.getSubMapper(i2, k2), actual_kc, actual_mc);
gebp_kernel(res.getSubMapper(i2, 0), blockA, blockB, actual_mc, actual_kc, cols, alpha);
@@ -398,26 +418,28 @@
// matrix * selfadjoint product
template <typename Scalar, typename Index,
int LhsStorageOrder, bool ConjugateLhs,
- int RhsStorageOrder, bool ConjugateRhs>
-struct product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,false,ConjugateLhs, RhsStorageOrder,true,ConjugateRhs,ColMajor>
+ int RhsStorageOrder, bool ConjugateRhs,
+ int ResInnerStride>
+struct product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,false,ConjugateLhs, RhsStorageOrder,true,ConjugateRhs,ColMajor,ResInnerStride>
{
static EIGEN_DONT_INLINE void run(
Index rows, Index cols,
const Scalar* _lhs, Index lhsStride,
const Scalar* _rhs, Index rhsStride,
- Scalar* res, Index resStride,
+ Scalar* res, Index resIncr, Index resStride,
const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking);
};
template <typename Scalar, typename Index,
int LhsStorageOrder, bool ConjugateLhs,
- int RhsStorageOrder, bool ConjugateRhs>
-EIGEN_DONT_INLINE void product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,false,ConjugateLhs, RhsStorageOrder,true,ConjugateRhs,ColMajor>::run(
+ int RhsStorageOrder, bool ConjugateRhs,
+ int ResInnerStride>
+EIGEN_DONT_INLINE void product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,false,ConjugateLhs, RhsStorageOrder,true,ConjugateRhs,ColMajor,ResInnerStride>::run(
Index rows, Index cols,
const Scalar* _lhs, Index lhsStride,
const Scalar* _rhs, Index rhsStride,
- Scalar* _res, Index resStride,
+ Scalar* _res, Index resIncr, Index resStride,
const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking)
{
Index size = cols;
@@ -425,9 +447,9 @@
typedef gebp_traits<Scalar,Scalar> Traits;
typedef const_blas_data_mapper<Scalar, Index, LhsStorageOrder> LhsMapper;
- typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor> ResMapper;
+ typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper;
LhsMapper lhs(_lhs,lhsStride);
- ResMapper res(_res,resStride);
+ ResMapper res(_res,resStride, resIncr);
Index kc = blocking.kc(); // cache block size along the K direction
Index mc = (std::min)(rows,blocking.mc()); // cache block size along the M direction
@@ -437,7 +459,7 @@
ei_declare_aligned_stack_constructed_variable(Scalar, blockB, sizeB, blocking.blockB());
gebp_kernel<Scalar, Scalar, Index, ResMapper, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs> gebp_kernel;
- gemm_pack_lhs<Scalar, Index, LhsMapper, Traits::mr, Traits::LhsProgress, LhsStorageOrder> pack_lhs;
+ gemm_pack_lhs<Scalar, Index, LhsMapper, Traits::mr, Traits::LhsProgress, typename Traits::LhsPacket4Packing, LhsStorageOrder> pack_lhs;
symm_pack_rhs<Scalar, Index, Traits::nr,RhsStorageOrder> pack_rhs;
for(Index k2=0; k2<size; k2+=kc)
@@ -503,12 +525,13 @@
NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(LhsIsUpper,bool(LhsBlasTraits::NeedToConjugate)),
EIGEN_LOGICAL_XOR(RhsIsUpper,internal::traits<Rhs>::Flags &RowMajorBit) ? RowMajor : ColMajor, RhsIsSelfAdjoint,
NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(RhsIsUpper,bool(RhsBlasTraits::NeedToConjugate)),
- internal::traits<Dest>::Flags&RowMajorBit ? RowMajor : ColMajor>
+ internal::traits<Dest>::Flags&RowMajorBit ? RowMajor : ColMajor,
+ Dest::InnerStrideAtCompileTime>
::run(
lhs.rows(), rhs.cols(), // sizes
&lhs.coeffRef(0,0), lhs.outerStride(), // lhs info
&rhs.coeffRef(0,0), rhs.outerStride(), // rhs info
- &dst.coeffRef(0,0), dst.outerStride(), // result info
+ &dst.coeffRef(0,0), dst.innerStride(), dst.outerStride(), // result info
actualAlpha, blocking // alpha
);
}
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/products/SelfadjointMatrixVector.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/products/SelfadjointMatrixVector.h
index 3fd180e..d38fd72 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/products/SelfadjointMatrixVector.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/products/SelfadjointMatrixVector.h
@@ -15,7 +15,7 @@
namespace internal {
/* Optimized selfadjoint matrix * vector product:
- * This algorithm processes 2 columns at onces that allows to both reduce
+ * This algorithm processes 2 columns at once that allows to both reduce
* the number of load/stores of the result by a factor 2 and to reduce
* the instruction dependency.
*/
@@ -27,7 +27,8 @@
struct selfadjoint_matrix_vector_product
{
-static EIGEN_DONT_INLINE void run(
+static EIGEN_DONT_INLINE EIGEN_DEVICE_FUNC
+void run(
Index size,
const Scalar* lhs, Index lhsStride,
const Scalar* rhs,
@@ -36,7 +37,8 @@
};
template<typename Scalar, typename Index, int StorageOrder, int UpLo, bool ConjugateLhs, bool ConjugateRhs, int Version>
-EIGEN_DONT_INLINE void selfadjoint_matrix_vector_product<Scalar,Index,StorageOrder,UpLo,ConjugateLhs,ConjugateRhs,Version>::run(
+EIGEN_DONT_INLINE EIGEN_DEVICE_FUNC
+void selfadjoint_matrix_vector_product<Scalar,Index,StorageOrder,UpLo,ConjugateLhs,ConjugateRhs,Version>::run(
Index size,
const Scalar* lhs, Index lhsStride,
const Scalar* rhs,
@@ -62,8 +64,7 @@
Scalar cjAlpha = ConjugateRhs ? numext::conj(alpha) : alpha;
-
- Index bound = (std::max)(Index(0),size-8) & 0xfffffffe;
+ Index bound = numext::maxi(Index(0), size-8) & 0xfffffffe;
if (FirstTriangular)
bound = size - bound;
@@ -175,7 +176,8 @@
enum { LhsUpLo = LhsMode&(Upper|Lower) };
template<typename Dest>
- static void run(Dest& dest, const Lhs &a_lhs, const Rhs &a_rhs, const Scalar& alpha)
+ static EIGEN_DEVICE_FUNC
+ void run(Dest& dest, const Lhs &a_lhs, const Rhs &a_rhs, const Scalar& alpha)
{
typedef typename Dest::Scalar ResScalar;
typedef typename Rhs::Scalar RhsScalar;
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/products/SelfadjointProduct.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/products/SelfadjointProduct.h
index f038d68..a21be80 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/products/SelfadjointProduct.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/products/SelfadjointProduct.h
@@ -109,10 +109,10 @@
internal::general_matrix_matrix_triangular_product<Index,
Scalar, OtherIsRowMajor ? RowMajor : ColMajor, OtherBlasTraits::NeedToConjugate && NumTraits<Scalar>::IsComplex,
Scalar, OtherIsRowMajor ? ColMajor : RowMajor, (!OtherBlasTraits::NeedToConjugate) && NumTraits<Scalar>::IsComplex,
- IsRowMajor ? RowMajor : ColMajor, UpLo>
+ IsRowMajor ? RowMajor : ColMajor, MatrixType::InnerStrideAtCompileTime, UpLo>
::run(size, depth,
- &actualOther.coeffRef(0,0), actualOther.outerStride(), &actualOther.coeffRef(0,0), actualOther.outerStride(),
- mat.data(), mat.outerStride(), actualAlpha, blocking);
+ actualOther.data(), actualOther.outerStride(), actualOther.data(), actualOther.outerStride(),
+ mat.data(), mat.innerStride(), mat.outerStride(), actualAlpha, blocking);
}
};
@@ -120,7 +120,7 @@
template<typename MatrixType, unsigned int UpLo>
template<typename DerivedU>
-SelfAdjointView<MatrixType,UpLo>& SelfAdjointView<MatrixType,UpLo>
+EIGEN_DEVICE_FUNC SelfAdjointView<MatrixType,UpLo>& SelfAdjointView<MatrixType,UpLo>
::rankUpdate(const MatrixBase<DerivedU>& u, const Scalar& alpha)
{
selfadjoint_product_selector<MatrixType,DerivedU,UpLo>::run(_expression().const_cast_derived(), u.derived(), alpha);
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/products/SelfadjointRank2Update.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/products/SelfadjointRank2Update.h
index 2ae3641..f752a0b 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/products/SelfadjointRank2Update.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/products/SelfadjointRank2Update.h
@@ -24,7 +24,8 @@
template<typename Scalar, typename Index, typename UType, typename VType>
struct selfadjoint_rank2_update_selector<Scalar,Index,UType,VType,Lower>
{
- static void run(Scalar* mat, Index stride, const UType& u, const VType& v, const Scalar& alpha)
+ static EIGEN_DEVICE_FUNC
+ void run(Scalar* mat, Index stride, const UType& u, const VType& v, const Scalar& alpha)
{
const Index size = u.size();
for (Index i=0; i<size; ++i)
@@ -57,7 +58,7 @@
template<typename MatrixType, unsigned int UpLo>
template<typename DerivedU, typename DerivedV>
-SelfAdjointView<MatrixType,UpLo>& SelfAdjointView<MatrixType,UpLo>
+EIGEN_DEVICE_FUNC SelfAdjointView<MatrixType,UpLo>& SelfAdjointView<MatrixType,UpLo>
::rankUpdate(const MatrixBase<DerivedU>& u, const MatrixBase<DerivedV>& v, const Scalar& alpha)
{
typedef internal::blas_traits<DerivedU> UBlasTraits;
@@ -79,8 +80,8 @@
if (IsRowMajor)
actualAlpha = numext::conj(actualAlpha);
- typedef typename internal::remove_all<typename internal::conj_expr_if<IsRowMajor ^ UBlasTraits::NeedToConjugate,_ActualUType>::type>::type UType;
- typedef typename internal::remove_all<typename internal::conj_expr_if<IsRowMajor ^ VBlasTraits::NeedToConjugate,_ActualVType>::type>::type VType;
+ typedef typename internal::remove_all<typename internal::conj_expr_if<int(IsRowMajor) ^ int(UBlasTraits::NeedToConjugate), _ActualUType>::type>::type UType;
+ typedef typename internal::remove_all<typename internal::conj_expr_if<int(IsRowMajor) ^ int(VBlasTraits::NeedToConjugate), _ActualVType>::type>::type VType;
internal::selfadjoint_rank2_update_selector<Scalar, Index, UType, VType,
(IsRowMajor ? int(UpLo==Upper ? Lower : Upper) : UpLo)>
::run(_expression().const_cast_derived().data(),_expression().outerStride(),UType(actualU),VType(actualV),actualAlpha);
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/products/TriangularMatrixMatrix.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/products/TriangularMatrixMatrix.h
index f784507..f0c6050 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/products/TriangularMatrixMatrix.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/products/TriangularMatrixMatrix.h
@@ -45,22 +45,24 @@
int Mode, bool LhsIsTriangular,
int LhsStorageOrder, bool ConjugateLhs,
int RhsStorageOrder, bool ConjugateRhs,
- int ResStorageOrder, int Version = Specialized>
+ int ResStorageOrder, int ResInnerStride,
+ int Version = Specialized>
struct product_triangular_matrix_matrix;
template <typename Scalar, typename Index,
int Mode, bool LhsIsTriangular,
int LhsStorageOrder, bool ConjugateLhs,
- int RhsStorageOrder, bool ConjugateRhs, int Version>
+ int RhsStorageOrder, bool ConjugateRhs,
+ int ResInnerStride, int Version>
struct product_triangular_matrix_matrix<Scalar,Index,Mode,LhsIsTriangular,
LhsStorageOrder,ConjugateLhs,
- RhsStorageOrder,ConjugateRhs,RowMajor,Version>
+ RhsStorageOrder,ConjugateRhs,RowMajor,ResInnerStride,Version>
{
static EIGEN_STRONG_INLINE void run(
Index rows, Index cols, Index depth,
const Scalar* lhs, Index lhsStride,
const Scalar* rhs, Index rhsStride,
- Scalar* res, Index resStride,
+ Scalar* res, Index resIncr, Index resStride,
const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking)
{
product_triangular_matrix_matrix<Scalar, Index,
@@ -70,18 +72,19 @@
ConjugateRhs,
LhsStorageOrder==RowMajor ? ColMajor : RowMajor,
ConjugateLhs,
- ColMajor>
- ::run(cols, rows, depth, rhs, rhsStride, lhs, lhsStride, res, resStride, alpha, blocking);
+ ColMajor, ResInnerStride>
+ ::run(cols, rows, depth, rhs, rhsStride, lhs, lhsStride, res, resIncr, resStride, alpha, blocking);
}
};
// implements col-major += alpha * op(triangular) * op(general)
template <typename Scalar, typename Index, int Mode,
int LhsStorageOrder, bool ConjugateLhs,
- int RhsStorageOrder, bool ConjugateRhs, int Version>
+ int RhsStorageOrder, bool ConjugateRhs,
+ int ResInnerStride, int Version>
struct product_triangular_matrix_matrix<Scalar,Index,Mode,true,
LhsStorageOrder,ConjugateLhs,
- RhsStorageOrder,ConjugateRhs,ColMajor,Version>
+ RhsStorageOrder,ConjugateRhs,ColMajor,ResInnerStride,Version>
{
typedef gebp_traits<Scalar,Scalar> Traits;
@@ -95,20 +98,21 @@
Index _rows, Index _cols, Index _depth,
const Scalar* _lhs, Index lhsStride,
const Scalar* _rhs, Index rhsStride,
- Scalar* res, Index resStride,
+ Scalar* res, Index resIncr, Index resStride,
const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking);
};
template <typename Scalar, typename Index, int Mode,
int LhsStorageOrder, bool ConjugateLhs,
- int RhsStorageOrder, bool ConjugateRhs, int Version>
+ int RhsStorageOrder, bool ConjugateRhs,
+ int ResInnerStride, int Version>
EIGEN_DONT_INLINE void product_triangular_matrix_matrix<Scalar,Index,Mode,true,
LhsStorageOrder,ConjugateLhs,
- RhsStorageOrder,ConjugateRhs,ColMajor,Version>::run(
+ RhsStorageOrder,ConjugateRhs,ColMajor,ResInnerStride,Version>::run(
Index _rows, Index _cols, Index _depth,
const Scalar* _lhs, Index lhsStride,
const Scalar* _rhs, Index rhsStride,
- Scalar* _res, Index resStride,
+ Scalar* _res, Index resIncr, Index resStride,
const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking)
{
// strip zeros
@@ -119,10 +123,10 @@
typedef const_blas_data_mapper<Scalar, Index, LhsStorageOrder> LhsMapper;
typedef const_blas_data_mapper<Scalar, Index, RhsStorageOrder> RhsMapper;
- typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor> ResMapper;
+ typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper;
LhsMapper lhs(_lhs,lhsStride);
RhsMapper rhs(_rhs,rhsStride);
- ResMapper res(_res, resStride);
+ ResMapper res(_res, resStride, resIncr);
Index kc = blocking.kc(); // cache block size along the K direction
Index mc = (std::min)(rows,blocking.mc()); // cache block size along the M direction
@@ -151,7 +155,7 @@
triangularBuffer.diagonal().setOnes();
gebp_kernel<Scalar, Scalar, Index, ResMapper, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs> gebp_kernel;
- gemm_pack_lhs<Scalar, Index, LhsMapper, Traits::mr, Traits::LhsProgress, LhsStorageOrder> pack_lhs;
+ gemm_pack_lhs<Scalar, Index, LhsMapper, Traits::mr, Traits::LhsProgress, typename Traits::LhsPacket4Packing, LhsStorageOrder> pack_lhs;
gemm_pack_rhs<Scalar, Index, RhsMapper, Traits::nr,RhsStorageOrder> pack_rhs;
for(Index k2=IsLower ? depth : 0;
@@ -222,7 +226,7 @@
for(Index i2=start; i2<end; i2+=mc)
{
const Index actual_mc = (std::min)(i2+mc,end)-i2;
- gemm_pack_lhs<Scalar, Index, LhsMapper, Traits::mr,Traits::LhsProgress, LhsStorageOrder,false>()
+ gemm_pack_lhs<Scalar, Index, LhsMapper, Traits::mr,Traits::LhsProgress, typename Traits::LhsPacket4Packing, LhsStorageOrder,false>()
(blockA, lhs.getSubMapper(i2, actual_k2), actual_kc, actual_mc);
gebp_kernel(res.getSubMapper(i2, 0), blockA, blockB, actual_mc,
@@ -235,10 +239,11 @@
// implements col-major += alpha * op(general) * op(triangular)
template <typename Scalar, typename Index, int Mode,
int LhsStorageOrder, bool ConjugateLhs,
- int RhsStorageOrder, bool ConjugateRhs, int Version>
+ int RhsStorageOrder, bool ConjugateRhs,
+ int ResInnerStride, int Version>
struct product_triangular_matrix_matrix<Scalar,Index,Mode,false,
LhsStorageOrder,ConjugateLhs,
- RhsStorageOrder,ConjugateRhs,ColMajor,Version>
+ RhsStorageOrder,ConjugateRhs,ColMajor,ResInnerStride,Version>
{
typedef gebp_traits<Scalar,Scalar> Traits;
enum {
@@ -251,20 +256,21 @@
Index _rows, Index _cols, Index _depth,
const Scalar* _lhs, Index lhsStride,
const Scalar* _rhs, Index rhsStride,
- Scalar* res, Index resStride,
+ Scalar* res, Index resIncr, Index resStride,
const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking);
};
template <typename Scalar, typename Index, int Mode,
int LhsStorageOrder, bool ConjugateLhs,
- int RhsStorageOrder, bool ConjugateRhs, int Version>
+ int RhsStorageOrder, bool ConjugateRhs,
+ int ResInnerStride, int Version>
EIGEN_DONT_INLINE void product_triangular_matrix_matrix<Scalar,Index,Mode,false,
LhsStorageOrder,ConjugateLhs,
- RhsStorageOrder,ConjugateRhs,ColMajor,Version>::run(
+ RhsStorageOrder,ConjugateRhs,ColMajor,ResInnerStride,Version>::run(
Index _rows, Index _cols, Index _depth,
const Scalar* _lhs, Index lhsStride,
const Scalar* _rhs, Index rhsStride,
- Scalar* _res, Index resStride,
+ Scalar* _res, Index resIncr, Index resStride,
const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking)
{
const Index PacketBytes = packet_traits<Scalar>::size*sizeof(Scalar);
@@ -276,10 +282,10 @@
typedef const_blas_data_mapper<Scalar, Index, LhsStorageOrder> LhsMapper;
typedef const_blas_data_mapper<Scalar, Index, RhsStorageOrder> RhsMapper;
- typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor> ResMapper;
+ typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper;
LhsMapper lhs(_lhs,lhsStride);
RhsMapper rhs(_rhs,rhsStride);
- ResMapper res(_res, resStride);
+ ResMapper res(_res, resStride, resIncr);
Index kc = blocking.kc(); // cache block size along the K direction
Index mc = (std::min)(rows,blocking.mc()); // cache block size along the M direction
@@ -299,7 +305,7 @@
triangularBuffer.diagonal().setOnes();
gebp_kernel<Scalar, Scalar, Index, ResMapper, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs> gebp_kernel;
- gemm_pack_lhs<Scalar, Index, LhsMapper, Traits::mr, Traits::LhsProgress, LhsStorageOrder> pack_lhs;
+ gemm_pack_lhs<Scalar, Index, LhsMapper, Traits::mr, Traits::LhsProgress, typename Traits::LhsPacket4Packing, LhsStorageOrder> pack_lhs;
gemm_pack_rhs<Scalar, Index, RhsMapper, Traits::nr,RhsStorageOrder> pack_rhs;
gemm_pack_rhs<Scalar, Index, RhsMapper, Traits::nr,RhsStorageOrder,false,true> pack_rhs_panel;
@@ -433,12 +439,12 @@
Mode, LhsIsTriangular,
(internal::traits<ActualLhsTypeCleaned>::Flags&RowMajorBit) ? RowMajor : ColMajor, LhsBlasTraits::NeedToConjugate,
(internal::traits<ActualRhsTypeCleaned>::Flags&RowMajorBit) ? RowMajor : ColMajor, RhsBlasTraits::NeedToConjugate,
- (internal::traits<Dest >::Flags&RowMajorBit) ? RowMajor : ColMajor>
+ (internal::traits<Dest >::Flags&RowMajorBit) ? RowMajor : ColMajor, Dest::InnerStrideAtCompileTime>
::run(
stripedRows, stripedCols, stripedDepth, // sizes
&lhs.coeffRef(0,0), lhs.outerStride(), // lhs info
&rhs.coeffRef(0,0), rhs.outerStride(), // rhs info
- &dst.coeffRef(0,0), dst.outerStride(), // result info
+ &dst.coeffRef(0,0), dst.innerStride(), dst.outerStride(), // result info
actualAlpha, blocking
);
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/products/TriangularSolverMatrix.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/products/TriangularSolverMatrix.h
index 223c38b..6d879ba 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/products/TriangularSolverMatrix.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/products/TriangularSolverMatrix.h
@@ -15,48 +15,48 @@
namespace internal {
// if the rhs is row major, let's transpose the product
-template <typename Scalar, typename Index, int Side, int Mode, bool Conjugate, int TriStorageOrder>
-struct triangular_solve_matrix<Scalar,Index,Side,Mode,Conjugate,TriStorageOrder,RowMajor>
+template <typename Scalar, typename Index, int Side, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride>
+struct triangular_solve_matrix<Scalar,Index,Side,Mode,Conjugate,TriStorageOrder,RowMajor,OtherInnerStride>
{
static void run(
Index size, Index cols,
const Scalar* tri, Index triStride,
- Scalar* _other, Index otherStride,
+ Scalar* _other, Index otherIncr, Index otherStride,
level3_blocking<Scalar,Scalar>& blocking)
{
triangular_solve_matrix<
Scalar, Index, Side==OnTheLeft?OnTheRight:OnTheLeft,
(Mode&UnitDiag) | ((Mode&Upper) ? Lower : Upper),
NumTraits<Scalar>::IsComplex && Conjugate,
- TriStorageOrder==RowMajor ? ColMajor : RowMajor, ColMajor>
- ::run(size, cols, tri, triStride, _other, otherStride, blocking);
+ TriStorageOrder==RowMajor ? ColMajor : RowMajor, ColMajor, OtherInnerStride>
+ ::run(size, cols, tri, triStride, _other, otherIncr, otherStride, blocking);
}
};
/* Optimized triangular solver with multiple right hand side and the triangular matrix on the left
*/
-template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder>
-struct triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conjugate,TriStorageOrder,ColMajor>
+template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder,int OtherInnerStride>
+struct triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conjugate,TriStorageOrder,ColMajor,OtherInnerStride>
{
static EIGEN_DONT_INLINE void run(
Index size, Index otherSize,
const Scalar* _tri, Index triStride,
- Scalar* _other, Index otherStride,
+ Scalar* _other, Index otherIncr, Index otherStride,
level3_blocking<Scalar,Scalar>& blocking);
};
-template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder>
-EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conjugate,TriStorageOrder,ColMajor>::run(
+template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride>
+EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conjugate,TriStorageOrder,ColMajor,OtherInnerStride>::run(
Index size, Index otherSize,
const Scalar* _tri, Index triStride,
- Scalar* _other, Index otherStride,
+ Scalar* _other, Index otherIncr, Index otherStride,
level3_blocking<Scalar,Scalar>& blocking)
{
Index cols = otherSize;
typedef const_blas_data_mapper<Scalar, Index, TriStorageOrder> TriMapper;
- typedef blas_data_mapper<Scalar, Index, ColMajor> OtherMapper;
+ typedef blas_data_mapper<Scalar, Index, ColMajor, Unaligned, OtherInnerStride> OtherMapper;
TriMapper tri(_tri, triStride);
- OtherMapper other(_other, otherStride);
+ OtherMapper other(_other, otherStride, otherIncr);
typedef gebp_traits<Scalar,Scalar> Traits;
@@ -76,7 +76,7 @@
conj_if<Conjugate> conj;
gebp_kernel<Scalar, Scalar, Index, OtherMapper, Traits::mr, Traits::nr, Conjugate, false> gebp_kernel;
- gemm_pack_lhs<Scalar, Index, TriMapper, Traits::mr, Traits::LhsProgress, TriStorageOrder> pack_lhs;
+ gemm_pack_lhs<Scalar, Index, TriMapper, Traits::mr, Traits::LhsProgress, typename Traits::LhsPacket4Packing, TriStorageOrder> pack_lhs;
gemm_pack_rhs<Scalar, Index, OtherMapper, Traits::nr, ColMajor, false, true> pack_rhs;
// the goal here is to subdivise the Rhs panels such that we keep some cache
@@ -128,19 +128,21 @@
{
Scalar b(0);
const Scalar* l = &tri(i,s);
- Scalar* r = &other(s,j);
+ typename OtherMapper::LinearMapper r = other.getLinearMapper(s,j);
for (Index i3=0; i3<k; ++i3)
- b += conj(l[i3]) * r[i3];
+ b += conj(l[i3]) * r(i3);
other(i,j) = (other(i,j) - b)*a;
}
else
{
- Scalar b = (other(i,j) *= a);
- Scalar* r = &other(s,j);
- const Scalar* l = &tri(s,i);
+ Scalar& otherij = other(i,j);
+ otherij *= a;
+ Scalar b = otherij;
+ typename OtherMapper::LinearMapper r = other.getLinearMapper(s,j);
+ typename TriMapper::LinearMapper l = tri.getLinearMapper(s,i);
for (Index i3=0;i3<rs;++i3)
- r[i3] -= b * conj(l[i3]);
+ r(i3) -= b * conj(l(i3));
}
}
}
@@ -185,28 +187,28 @@
/* Optimized triangular solver with multiple left hand sides and the triangular matrix on the right
*/
-template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder>
-struct triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conjugate,TriStorageOrder,ColMajor>
+template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride>
+struct triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conjugate,TriStorageOrder,ColMajor,OtherInnerStride>
{
static EIGEN_DONT_INLINE void run(
Index size, Index otherSize,
const Scalar* _tri, Index triStride,
- Scalar* _other, Index otherStride,
+ Scalar* _other, Index otherIncr, Index otherStride,
level3_blocking<Scalar,Scalar>& blocking);
};
-template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder>
-EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conjugate,TriStorageOrder,ColMajor>::run(
+template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride>
+EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conjugate,TriStorageOrder,ColMajor,OtherInnerStride>::run(
Index size, Index otherSize,
const Scalar* _tri, Index triStride,
- Scalar* _other, Index otherStride,
+ Scalar* _other, Index otherIncr, Index otherStride,
level3_blocking<Scalar,Scalar>& blocking)
{
Index rows = otherSize;
typedef typename NumTraits<Scalar>::Real RealScalar;
- typedef blas_data_mapper<Scalar, Index, ColMajor> LhsMapper;
+ typedef blas_data_mapper<Scalar, Index, ColMajor, Unaligned, OtherInnerStride> LhsMapper;
typedef const_blas_data_mapper<Scalar, Index, TriStorageOrder> RhsMapper;
- LhsMapper lhs(_other, otherStride);
+ LhsMapper lhs(_other, otherStride, otherIncr);
RhsMapper rhs(_tri, triStride);
typedef gebp_traits<Scalar,Scalar> Traits;
@@ -229,7 +231,7 @@
gebp_kernel<Scalar, Scalar, Index, LhsMapper, Traits::mr, Traits::nr, false, Conjugate> gebp_kernel;
gemm_pack_rhs<Scalar, Index, RhsMapper, Traits::nr, RhsStorageOrder> pack_rhs;
gemm_pack_rhs<Scalar, Index, RhsMapper, Traits::nr, RhsStorageOrder,false,true> pack_rhs_panel;
- gemm_pack_lhs<Scalar, Index, LhsMapper, Traits::mr, Traits::LhsProgress, ColMajor, false, true> pack_lhs_panel;
+ gemm_pack_lhs<Scalar, Index, LhsMapper, Traits::mr, Traits::LhsProgress, typename Traits::LhsPacket4Packing, ColMajor, false, true> pack_lhs_panel;
for(Index k2=IsLower ? size : 0;
IsLower ? k2>0 : k2<size;
@@ -297,24 +299,24 @@
{
Index j = IsLower ? absolute_j2+actualPanelWidth-k-1 : absolute_j2+k;
- Scalar* r = &lhs(i2,j);
+ typename LhsMapper::LinearMapper r = lhs.getLinearMapper(i2,j);
for (Index k3=0; k3<k; ++k3)
{
Scalar b = conj(rhs(IsLower ? j+1+k3 : absolute_j2+k3,j));
- Scalar* a = &lhs(i2,IsLower ? j+1+k3 : absolute_j2+k3);
+ typename LhsMapper::LinearMapper a = lhs.getLinearMapper(i2,IsLower ? j+1+k3 : absolute_j2+k3);
for (Index i=0; i<actual_mc; ++i)
- r[i] -= a[i] * b;
+ r(i) -= a(i) * b;
}
if((Mode & UnitDiag)==0)
{
Scalar inv_rjj = RealScalar(1)/conj(rhs(j,j));
for (Index i=0; i<actual_mc; ++i)
- r[i] *= inv_rjj;
+ r(i) *= inv_rjj;
}
}
// pack the just computed part of lhs to A
- pack_lhs_panel(blockA, LhsMapper(_other+absolute_j2*otherStride+i2, otherStride),
+ pack_lhs_panel(blockA, lhs.getSubMapper(i2,absolute_j2),
actualPanelWidth, actual_mc,
actual_kc, j2);
}
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/products/TriangularSolverVector.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/products/TriangularSolverVector.h
index b994759..6473170 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/products/TriangularSolverVector.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/products/TriangularSolverVector.h
@@ -58,7 +58,7 @@
{
// let's directly call the low level product function because:
// 1 - it is faster to compile
- // 2 - it is slighlty faster at runtime
+ // 2 - it is slightly faster at runtime
Index startRow = IsLower ? pi : pi-actualPanelWidth;
Index startCol = IsLower ? 0 : pi;
@@ -77,7 +77,7 @@
if (k>0)
rhs[i] -= (cjLhs.row(i).segment(s,k).transpose().cwiseProduct(Map<const Matrix<RhsScalar,Dynamic,1> >(rhs+s,k))).sum();
- if(!(Mode & UnitDiag))
+ if((!(Mode & UnitDiag)) && numext::not_equal_strict(rhs[i],RhsScalar(0)))
rhs[i] /= cjLhs(i,i);
}
}
@@ -114,20 +114,23 @@
for(Index k=0; k<actualPanelWidth; ++k)
{
Index i = IsLower ? pi+k : pi-k-1;
- if(!(Mode & UnitDiag))
- rhs[i] /= cjLhs.coeff(i,i);
+ if(numext::not_equal_strict(rhs[i],RhsScalar(0)))
+ {
+ if(!(Mode & UnitDiag))
+ rhs[i] /= cjLhs.coeff(i,i);
- Index r = actualPanelWidth - k - 1; // remaining size
- Index s = IsLower ? i+1 : i-r;
- if (r>0)
- Map<Matrix<RhsScalar,Dynamic,1> >(rhs+s,r) -= rhs[i] * cjLhs.col(i).segment(s,r);
+ Index r = actualPanelWidth - k - 1; // remaining size
+ Index s = IsLower ? i+1 : i-r;
+ if (r>0)
+ Map<Matrix<RhsScalar,Dynamic,1> >(rhs+s,r) -= rhs[i] * cjLhs.col(i).segment(s,r);
+ }
}
Index r = IsLower ? size - endBlock : startBlock; // remaining size
if (r > 0)
{
// let's directly call the low level product function because:
// 1 - it is faster to compile
- // 2 - it is slighlty faster at runtime
+ // 2 - it is slightly faster at runtime
general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,Conjugate,RhsScalar,RhsMapper,false>::run(
r, actualPanelWidth,
LhsMapper(&lhs.coeffRef(endBlock,startBlock), lhsStride),
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/BlasUtil.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/BlasUtil.h
index 6e6ee11..e16a564 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/BlasUtil.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/BlasUtil.h
@@ -24,14 +24,14 @@
template<typename Scalar, typename Index, typename DataMapper, int nr, int StorageOrder, bool Conjugate = false, bool PanelMode=false>
struct gemm_pack_rhs;
-template<typename Scalar, typename Index, typename DataMapper, int Pack1, int Pack2, int StorageOrder, bool Conjugate = false, bool PanelMode = false>
+template<typename Scalar, typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, int StorageOrder, bool Conjugate = false, bool PanelMode = false>
struct gemm_pack_lhs;
template<
typename Index,
typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs,
- int ResStorageOrder>
+ int ResStorageOrder, int ResInnerStride>
struct general_matrix_matrix_product;
template<typename Index,
@@ -39,90 +39,6 @@
typename RhsScalar, typename RhsMapper, bool ConjugateRhs, int Version=Specialized>
struct general_matrix_vector_product;
-
-template<bool Conjugate> struct conj_if;
-
-template<> struct conj_if<true> {
- template<typename T>
- inline T operator()(const T& x) const { return numext::conj(x); }
- template<typename T>
- inline T pconj(const T& x) const { return internal::pconj(x); }
-};
-
-template<> struct conj_if<false> {
- template<typename T>
- inline const T& operator()(const T& x) const { return x; }
- template<typename T>
- inline const T& pconj(const T& x) const { return x; }
-};
-
-// Generic implementation for custom complex types.
-template<typename LhsScalar, typename RhsScalar, bool ConjLhs, bool ConjRhs>
-struct conj_helper
-{
- typedef typename ScalarBinaryOpTraits<LhsScalar,RhsScalar>::ReturnType Scalar;
-
- EIGEN_STRONG_INLINE Scalar pmadd(const LhsScalar& x, const RhsScalar& y, const Scalar& c) const
- { return padd(c, pmul(x,y)); }
-
- EIGEN_STRONG_INLINE Scalar pmul(const LhsScalar& x, const RhsScalar& y) const
- { return conj_if<ConjLhs>()(x) * conj_if<ConjRhs>()(y); }
-};
-
-template<typename Scalar> struct conj_helper<Scalar,Scalar,false,false>
-{
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const { return internal::pmadd(x,y,c); }
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const { return internal::pmul(x,y); }
-};
-
-template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, false,true>
-{
- typedef std::complex<RealScalar> Scalar;
- EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const
- { return c + pmul(x,y); }
-
- EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const
- { return Scalar(numext::real(x)*numext::real(y) + numext::imag(x)*numext::imag(y), numext::imag(x)*numext::real(y) - numext::real(x)*numext::imag(y)); }
-};
-
-template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, true,false>
-{
- typedef std::complex<RealScalar> Scalar;
- EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const
- { return c + pmul(x,y); }
-
- EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const
- { return Scalar(numext::real(x)*numext::real(y) + numext::imag(x)*numext::imag(y), numext::real(x)*numext::imag(y) - numext::imag(x)*numext::real(y)); }
-};
-
-template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, true,true>
-{
- typedef std::complex<RealScalar> Scalar;
- EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const
- { return c + pmul(x,y); }
-
- EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const
- { return Scalar(numext::real(x)*numext::real(y) - numext::imag(x)*numext::imag(y), - numext::real(x)*numext::imag(y) - numext::imag(x)*numext::real(y)); }
-};
-
-template<typename RealScalar,bool Conj> struct conj_helper<std::complex<RealScalar>, RealScalar, Conj,false>
-{
- typedef std::complex<RealScalar> Scalar;
- EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const RealScalar& y, const Scalar& c) const
- { return padd(c, pmul(x,y)); }
- EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const RealScalar& y) const
- { return conj_if<Conj>()(x)*y; }
-};
-
-template<typename RealScalar,bool Conj> struct conj_helper<RealScalar, std::complex<RealScalar>, false,Conj>
-{
- typedef std::complex<RealScalar> Scalar;
- EIGEN_STRONG_INLINE Scalar pmadd(const RealScalar& x, const Scalar& y, const Scalar& c) const
- { return padd(c, pmul(x,y)); }
- EIGEN_STRONG_INLINE Scalar pmul(const RealScalar& x, const Scalar& y) const
- { return x*conj_if<Conj>()(y); }
-};
-
template<typename From,typename To> struct get_factor {
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE To run(const From& x) { return To(x); }
};
@@ -155,13 +71,19 @@
Scalar* m_data;
};
-template<typename Scalar, typename Index, int AlignmentType>
-class BlasLinearMapper {
- public:
- typedef typename packet_traits<Scalar>::type Packet;
- typedef typename packet_traits<Scalar>::half HalfPacket;
+template<typename Scalar, typename Index, int AlignmentType, int Incr=1>
+class BlasLinearMapper;
- EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlasLinearMapper(Scalar *data) : m_data(data) {}
+template<typename Scalar, typename Index, int AlignmentType>
+class BlasLinearMapper<Scalar,Index,AlignmentType>
+{
+public:
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlasLinearMapper(Scalar *data, Index incr=1)
+ : m_data(data)
+ {
+ EIGEN_ONLY_USED_FOR_DEBUG(incr);
+ eigen_assert(incr==1);
+ }
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void prefetch(int i) const {
internal::prefetch(&operator()(i));
@@ -171,33 +93,86 @@
return m_data[i];
}
- EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const {
- return ploadt<Packet, AlignmentType>(m_data + i);
+ template<typename PacketType>
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketType loadPacket(Index i) const {
+ return ploadt<PacketType, AlignmentType>(m_data + i);
}
- EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE HalfPacket loadHalfPacket(Index i) const {
- return ploadt<HalfPacket, AlignmentType>(m_data + i);
+ template<typename PacketType>
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, const PacketType &p) const {
+ pstoret<Scalar, PacketType, AlignmentType>(m_data + i, p);
}
- EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, const Packet &p) const {
- pstoret<Scalar, Packet, AlignmentType>(m_data + i, p);
- }
-
- protected:
+protected:
Scalar *m_data;
};
// Lightweight helper class to access matrix coefficients.
-template<typename Scalar, typename Index, int StorageOrder, int AlignmentType = Unaligned>
-class blas_data_mapper {
- public:
- typedef typename packet_traits<Scalar>::type Packet;
- typedef typename packet_traits<Scalar>::half HalfPacket;
+template<typename Scalar, typename Index, int StorageOrder, int AlignmentType = Unaligned, int Incr = 1>
+class blas_data_mapper;
+// TMP to help PacketBlock store implementation.
+// There's currently no known use case for PacketBlock load.
+// The default implementation assumes ColMajor order.
+// It always store each packet sequentially one `stride` apart.
+template<typename Index, typename Scalar, typename Packet, int n, int idx, int StorageOrder>
+struct PacketBlockManagement
+{
+ PacketBlockManagement<Index, Scalar, Packet, n, idx - 1, StorageOrder> pbm;
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(Scalar *to, const Index stride, Index i, Index j, const PacketBlock<Packet, n> &block) const {
+ pbm.store(to, stride, i, j, block);
+ pstoreu<Scalar>(to + i + (j + idx)*stride, block.packet[idx]);
+ }
+};
+
+// PacketBlockManagement specialization to take care of RowMajor order without ifs.
+template<typename Index, typename Scalar, typename Packet, int n, int idx>
+struct PacketBlockManagement<Index, Scalar, Packet, n, idx, RowMajor>
+{
+ PacketBlockManagement<Index, Scalar, Packet, n, idx - 1, RowMajor> pbm;
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(Scalar *to, const Index stride, Index i, Index j, const PacketBlock<Packet, n> &block) const {
+ pbm.store(to, stride, i, j, block);
+ pstoreu<Scalar>(to + j + (i + idx)*stride, block.packet[idx]);
+ }
+};
+
+template<typename Index, typename Scalar, typename Packet, int n, int StorageOrder>
+struct PacketBlockManagement<Index, Scalar, Packet, n, -1, StorageOrder>
+{
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(Scalar *to, const Index stride, Index i, Index j, const PacketBlock<Packet, n> &block) const {
+ EIGEN_UNUSED_VARIABLE(to);
+ EIGEN_UNUSED_VARIABLE(stride);
+ EIGEN_UNUSED_VARIABLE(i);
+ EIGEN_UNUSED_VARIABLE(j);
+ EIGEN_UNUSED_VARIABLE(block);
+ }
+};
+
+template<typename Index, typename Scalar, typename Packet, int n>
+struct PacketBlockManagement<Index, Scalar, Packet, n, -1, RowMajor>
+{
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(Scalar *to, const Index stride, Index i, Index j, const PacketBlock<Packet, n> &block) const {
+ EIGEN_UNUSED_VARIABLE(to);
+ EIGEN_UNUSED_VARIABLE(stride);
+ EIGEN_UNUSED_VARIABLE(i);
+ EIGEN_UNUSED_VARIABLE(j);
+ EIGEN_UNUSED_VARIABLE(block);
+ }
+};
+
+template<typename Scalar, typename Index, int StorageOrder, int AlignmentType>
+class blas_data_mapper<Scalar,Index,StorageOrder,AlignmentType,1>
+{
+public:
typedef BlasLinearMapper<Scalar, Index, AlignmentType> LinearMapper;
typedef BlasVectorMapper<Scalar, Index> VectorMapper;
- EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper(Scalar* data, Index stride) : m_data(data), m_stride(stride) {}
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper(Scalar* data, Index stride, Index incr=1)
+ : m_data(data), m_stride(stride)
+ {
+ EIGEN_ONLY_USED_FOR_DEBUG(incr);
+ eigen_assert(incr==1);
+ }
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType>
getSubMapper(Index i, Index j) const {
@@ -218,12 +193,14 @@
return m_data[StorageOrder==RowMajor ? j + i*m_stride : i + j*m_stride];
}
- EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, Index j) const {
- return ploadt<Packet, AlignmentType>(&operator()(i, j));
+ template<typename PacketType>
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketType loadPacket(Index i, Index j) const {
+ return ploadt<PacketType, AlignmentType>(&operator()(i, j));
}
- EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE HalfPacket loadHalfPacket(Index i, Index j) const {
- return ploadt<HalfPacket, AlignmentType>(&operator()(i, j));
+ template <typename PacketT, int AlignmentT>
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT load(Index i, Index j) const {
+ return ploadt<PacketT, AlignmentT>(&operator()(i, j));
}
template<typename SubPacket>
@@ -246,11 +223,167 @@
return internal::first_default_aligned(m_data, size);
}
- protected:
+ template<typename SubPacket, int n>
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacketBlock(Index i, Index j, const PacketBlock<SubPacket, n> &block) const {
+ PacketBlockManagement<Index, Scalar, SubPacket, n, n-1, StorageOrder> pbm;
+ pbm.store(m_data, m_stride, i, j, block);
+ }
+protected:
Scalar* EIGEN_RESTRICT m_data;
const Index m_stride;
};
+// Implementation of non-natural increment (i.e. inner-stride != 1)
+// The exposed API is not complete yet compared to the Incr==1 case
+// because some features makes less sense in this case.
+template<typename Scalar, typename Index, int AlignmentType, int Incr>
+class BlasLinearMapper
+{
+public:
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlasLinearMapper(Scalar *data,Index incr) : m_data(data), m_incr(incr) {}
+
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void prefetch(int i) const {
+ internal::prefetch(&operator()(i));
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar& operator()(Index i) const {
+ return m_data[i*m_incr.value()];
+ }
+
+ template<typename PacketType>
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketType loadPacket(Index i) const {
+ return pgather<Scalar,PacketType>(m_data + i*m_incr.value(), m_incr.value());
+ }
+
+ template<typename PacketType>
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, const PacketType &p) const {
+ pscatter<Scalar, PacketType>(m_data + i*m_incr.value(), p, m_incr.value());
+ }
+
+protected:
+ Scalar *m_data;
+ const internal::variable_if_dynamic<Index,Incr> m_incr;
+};
+
+template<typename Scalar, typename Index, int StorageOrder, int AlignmentType,int Incr>
+class blas_data_mapper
+{
+public:
+ typedef BlasLinearMapper<Scalar, Index, AlignmentType,Incr> LinearMapper;
+
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper(Scalar* data, Index stride, Index incr) : m_data(data), m_stride(stride), m_incr(incr) {}
+
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper
+ getSubMapper(Index i, Index j) const {
+ return blas_data_mapper(&operator()(i, j), m_stride, m_incr.value());
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
+ return LinearMapper(&operator()(i, j), m_incr.value());
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Scalar& operator()(Index i, Index j) const {
+ return m_data[StorageOrder==RowMajor ? j*m_incr.value() + i*m_stride : i*m_incr.value() + j*m_stride];
+ }
+
+ template<typename PacketType>
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketType loadPacket(Index i, Index j) const {
+ return pgather<Scalar,PacketType>(&operator()(i, j),m_incr.value());
+ }
+
+ template <typename PacketT, int AlignmentT>
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT load(Index i, Index j) const {
+ return pgather<Scalar,PacketT>(&operator()(i, j),m_incr.value());
+ }
+
+ template<typename SubPacket>
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void scatterPacket(Index i, Index j, const SubPacket &p) const {
+ pscatter<Scalar, SubPacket>(&operator()(i, j), p, m_stride);
+ }
+
+ template<typename SubPacket>
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SubPacket gatherPacket(Index i, Index j) const {
+ return pgather<Scalar, SubPacket>(&operator()(i, j), m_stride);
+ }
+
+ // storePacketBlock_helper defines a way to access values inside the PacketBlock, this is essentially required by the Complex types.
+ template<typename SubPacket, typename ScalarT, int n, int idx>
+ struct storePacketBlock_helper
+ {
+ storePacketBlock_helper<SubPacket, ScalarT, n, idx-1> spbh;
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(const blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType, Incr>* sup, Index i, Index j, const PacketBlock<SubPacket, n>& block) const {
+ spbh.store(sup, i,j,block);
+ for(int l = 0; l < unpacket_traits<SubPacket>::size; l++)
+ {
+ ScalarT *v = &sup->operator()(i+l, j+idx);
+ *v = block.packet[idx][l];
+ }
+ }
+ };
+
+ template<typename SubPacket, int n, int idx>
+ struct storePacketBlock_helper<SubPacket, std::complex<float>, n, idx>
+ {
+ storePacketBlock_helper<SubPacket, std::complex<float>, n, idx-1> spbh;
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(const blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType, Incr>* sup, Index i, Index j, const PacketBlock<SubPacket, n>& block) const {
+ spbh.store(sup,i,j,block);
+ for(int l = 0; l < unpacket_traits<SubPacket>::size; l++)
+ {
+ std::complex<float> *v = &sup->operator()(i+l, j+idx);
+ v->real(block.packet[idx].v[2*l+0]);
+ v->imag(block.packet[idx].v[2*l+1]);
+ }
+ }
+ };
+
+ template<typename SubPacket, int n, int idx>
+ struct storePacketBlock_helper<SubPacket, std::complex<double>, n, idx>
+ {
+ storePacketBlock_helper<SubPacket, std::complex<double>, n, idx-1> spbh;
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(const blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType, Incr>* sup, Index i, Index j, const PacketBlock<SubPacket, n>& block) const {
+ spbh.store(sup,i,j,block);
+ for(int l = 0; l < unpacket_traits<SubPacket>::size; l++)
+ {
+ std::complex<double> *v = &sup->operator()(i+l, j+idx);
+ v->real(block.packet[idx].v[2*l+0]);
+ v->imag(block.packet[idx].v[2*l+1]);
+ }
+ }
+ };
+
+ template<typename SubPacket, typename ScalarT, int n>
+ struct storePacketBlock_helper<SubPacket, ScalarT, n, -1>
+ {
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(const blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType, Incr>*, Index, Index, const PacketBlock<SubPacket, n>& ) const {
+ }
+ };
+
+ template<typename SubPacket, int n>
+ struct storePacketBlock_helper<SubPacket, std::complex<float>, n, -1>
+ {
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(const blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType, Incr>*, Index, Index, const PacketBlock<SubPacket, n>& ) const {
+ }
+ };
+
+ template<typename SubPacket, int n>
+ struct storePacketBlock_helper<SubPacket, std::complex<double>, n, -1>
+ {
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(const blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType, Incr>*, Index, Index, const PacketBlock<SubPacket, n>& ) const {
+ }
+ };
+ // This function stores a PacketBlock on m_data, this approach is really quite slow compare to Incr=1 and should be avoided when possible.
+ template<typename SubPacket, int n>
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacketBlock(Index i, Index j, const PacketBlock<SubPacket, n>&block) const {
+ storePacketBlock_helper<SubPacket, Scalar, n, n-1> spb;
+ spb.store(this, i,j,block);
+ }
+protected:
+ Scalar* EIGEN_RESTRICT m_data;
+ const Index m_stride;
+ const internal::variable_if_dynamic<Index,Incr> m_incr;
+};
+
// lightweight helper class to access matrix coefficients (const version)
template<typename Scalar, typename Index, int StorageOrder>
class const_blas_data_mapper : public blas_data_mapper<const Scalar, Index, StorageOrder> {
@@ -278,14 +411,15 @@
HasUsableDirectAccess = ( (int(XprType::Flags)&DirectAccessBit)
&& ( bool(XprType::IsVectorAtCompileTime)
|| int(inner_stride_at_compile_time<XprType>::ret) == 1)
- ) ? 1 : 0
+ ) ? 1 : 0,
+ HasScalarFactor = false
};
typedef typename conditional<bool(HasUsableDirectAccess),
ExtractType,
typename _ExtractType::PlainObject
>::type DirectLinearAccessType;
- static inline ExtractType extract(const XprType& x) { return x; }
- static inline const Scalar extractScalarFactor(const XprType&) { return Scalar(1); }
+ static inline EIGEN_DEVICE_FUNC ExtractType extract(const XprType& x) { return x; }
+ static inline EIGEN_DEVICE_FUNC const Scalar extractScalarFactor(const XprType&) { return Scalar(1); }
};
// pop conjugate
@@ -310,17 +444,23 @@
struct blas_traits<CwiseBinaryOp<scalar_product_op<Scalar>, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain>, NestedXpr> >
: blas_traits<NestedXpr>
{
+ enum {
+ HasScalarFactor = true
+ };
typedef blas_traits<NestedXpr> Base;
typedef CwiseBinaryOp<scalar_product_op<Scalar>, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain>, NestedXpr> XprType;
typedef typename Base::ExtractType ExtractType;
- static inline ExtractType extract(const XprType& x) { return Base::extract(x.rhs()); }
- static inline Scalar extractScalarFactor(const XprType& x)
+ static inline EIGEN_DEVICE_FUNC ExtractType extract(const XprType& x) { return Base::extract(x.rhs()); }
+ static inline EIGEN_DEVICE_FUNC Scalar extractScalarFactor(const XprType& x)
{ return x.lhs().functor().m_other * Base::extractScalarFactor(x.rhs()); }
};
template<typename Scalar, typename NestedXpr, typename Plain>
struct blas_traits<CwiseBinaryOp<scalar_product_op<Scalar>, NestedXpr, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain> > >
: blas_traits<NestedXpr>
{
+ enum {
+ HasScalarFactor = true
+ };
typedef blas_traits<NestedXpr> Base;
typedef CwiseBinaryOp<scalar_product_op<Scalar>, NestedXpr, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain> > XprType;
typedef typename Base::ExtractType ExtractType;
@@ -339,6 +479,9 @@
struct blas_traits<CwiseUnaryOp<scalar_opposite_op<Scalar>, NestedXpr> >
: blas_traits<NestedXpr>
{
+ enum {
+ HasScalarFactor = true
+ };
typedef blas_traits<NestedXpr> Base;
typedef CwiseUnaryOp<scalar_opposite_op<Scalar>, NestedXpr> XprType;
typedef typename Base::ExtractType ExtractType;
@@ -375,7 +518,7 @@
template<typename T, bool HasUsableDirectAccess=blas_traits<T>::HasUsableDirectAccess>
struct extract_data_selector {
- static const typename T::Scalar* run(const T& m)
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static const typename T::Scalar* run(const T& m)
{
return blas_traits<T>::extract(m).data();
}
@@ -386,11 +529,53 @@
static typename T::Scalar* run(const T&) { return 0; }
};
-template<typename T> const typename T::Scalar* extract_data(const T& m)
+template<typename T>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE const typename T::Scalar* extract_data(const T& m)
{
return extract_data_selector<T>::run(m);
}
+/**
+ * \c combine_scalar_factors extracts and multiplies factors from GEMM and GEMV products.
+ * There is a specialization for booleans
+ */
+template<typename ResScalar, typename Lhs, typename Rhs>
+struct combine_scalar_factors_impl
+{
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static ResScalar run(const Lhs& lhs, const Rhs& rhs)
+ {
+ return blas_traits<Lhs>::extractScalarFactor(lhs) * blas_traits<Rhs>::extractScalarFactor(rhs);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static ResScalar run(const ResScalar& alpha, const Lhs& lhs, const Rhs& rhs)
+ {
+ return alpha * blas_traits<Lhs>::extractScalarFactor(lhs) * blas_traits<Rhs>::extractScalarFactor(rhs);
+ }
+};
+template<typename Lhs, typename Rhs>
+struct combine_scalar_factors_impl<bool, Lhs, Rhs>
+{
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static bool run(const Lhs& lhs, const Rhs& rhs)
+ {
+ return blas_traits<Lhs>::extractScalarFactor(lhs) && blas_traits<Rhs>::extractScalarFactor(rhs);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static bool run(const bool& alpha, const Lhs& lhs, const Rhs& rhs)
+ {
+ return alpha && blas_traits<Lhs>::extractScalarFactor(lhs) && blas_traits<Rhs>::extractScalarFactor(rhs);
+ }
+};
+
+template<typename ResScalar, typename Lhs, typename Rhs>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE ResScalar combine_scalar_factors(const ResScalar& alpha, const Lhs& lhs, const Rhs& rhs)
+{
+ return combine_scalar_factors_impl<ResScalar,Lhs,Rhs>::run(alpha, lhs, rhs);
+}
+template<typename ResScalar, typename Lhs, typename Rhs>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE ResScalar combine_scalar_factors(const Lhs& lhs, const Rhs& rhs)
+{
+ return combine_scalar_factors_impl<ResScalar,Lhs,Rhs>::run(lhs, rhs);
+}
+
+
} // end namespace internal
} // end namespace Eigen
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/ConfigureVectorization.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/ConfigureVectorization.h
new file mode 100644
index 0000000..af4e696
--- /dev/null
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/ConfigureVectorization.h
@@ -0,0 +1,512 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2008-2018 Gael Guennebaud <gael.guennebaud@inria.fr>
+// Copyright (C) 2020, Arm Limited and Contributors
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_CONFIGURE_VECTORIZATION_H
+#define EIGEN_CONFIGURE_VECTORIZATION_H
+
+//------------------------------------------------------------------------------------------
+// Static and dynamic alignment control
+//
+// The main purpose of this section is to define EIGEN_MAX_ALIGN_BYTES and EIGEN_MAX_STATIC_ALIGN_BYTES
+// as the maximal boundary in bytes on which dynamically and statically allocated data may be alignment respectively.
+// The values of EIGEN_MAX_ALIGN_BYTES and EIGEN_MAX_STATIC_ALIGN_BYTES can be specified by the user. If not,
+// a default value is automatically computed based on architecture, compiler, and OS.
+//
+// This section also defines macros EIGEN_ALIGN_TO_BOUNDARY(N) and the shortcuts EIGEN_ALIGN{8,16,32,_MAX}
+// to be used to declare statically aligned buffers.
+//------------------------------------------------------------------------------------------
+
+
+/* EIGEN_ALIGN_TO_BOUNDARY(n) forces data to be n-byte aligned. This is used to satisfy SIMD requirements.
+ * However, we do that EVEN if vectorization (EIGEN_VECTORIZE) is disabled,
+ * so that vectorization doesn't affect binary compatibility.
+ *
+ * If we made alignment depend on whether or not EIGEN_VECTORIZE is defined, it would be impossible to link
+ * vectorized and non-vectorized code.
+ *
+ * FIXME: this code can be cleaned up once we switch to proper C++11 only.
+ */
+#if (defined EIGEN_CUDACC)
+ #define EIGEN_ALIGN_TO_BOUNDARY(n) __align__(n)
+ #define EIGEN_ALIGNOF(x) __alignof(x)
+#elif EIGEN_HAS_ALIGNAS
+ #define EIGEN_ALIGN_TO_BOUNDARY(n) alignas(n)
+ #define EIGEN_ALIGNOF(x) alignof(x)
+#elif EIGEN_COMP_GNUC || EIGEN_COMP_PGI || EIGEN_COMP_IBM || EIGEN_COMP_ARM
+ #define EIGEN_ALIGN_TO_BOUNDARY(n) __attribute__((aligned(n)))
+ #define EIGEN_ALIGNOF(x) __alignof(x)
+#elif EIGEN_COMP_MSVC
+ #define EIGEN_ALIGN_TO_BOUNDARY(n) __declspec(align(n))
+ #define EIGEN_ALIGNOF(x) __alignof(x)
+#elif EIGEN_COMP_SUNCC
+ // FIXME not sure about this one:
+ #define EIGEN_ALIGN_TO_BOUNDARY(n) __attribute__((aligned(n)))
+ #define EIGEN_ALIGNOF(x) __alignof(x)
+#else
+ #error Please tell me what is the equivalent of alignas(n) and alignof(x) for your compiler
+#endif
+
+// If the user explicitly disable vectorization, then we also disable alignment
+#if defined(EIGEN_DONT_VECTORIZE)
+ #if defined(EIGEN_GPUCC)
+ // GPU code is always vectorized and requires memory alignment for
+ // statically allocated buffers.
+ #define EIGEN_IDEAL_MAX_ALIGN_BYTES 16
+ #else
+ #define EIGEN_IDEAL_MAX_ALIGN_BYTES 0
+ #endif
+#elif defined(__AVX512F__)
+ // 64 bytes static alignment is preferred only if really required
+ #define EIGEN_IDEAL_MAX_ALIGN_BYTES 64
+#elif defined(__AVX__)
+ // 32 bytes static alignment is preferred only if really required
+ #define EIGEN_IDEAL_MAX_ALIGN_BYTES 32
+#else
+ #define EIGEN_IDEAL_MAX_ALIGN_BYTES 16
+#endif
+
+
+// EIGEN_MIN_ALIGN_BYTES defines the minimal value for which the notion of explicit alignment makes sense
+#define EIGEN_MIN_ALIGN_BYTES 16
+
+// Defined the boundary (in bytes) on which the data needs to be aligned. Note
+// that unless EIGEN_ALIGN is defined and not equal to 0, the data may not be
+// aligned at all regardless of the value of this #define.
+
+#if (defined(EIGEN_DONT_ALIGN_STATICALLY) || defined(EIGEN_DONT_ALIGN)) && defined(EIGEN_MAX_STATIC_ALIGN_BYTES) && EIGEN_MAX_STATIC_ALIGN_BYTES>0
+#error EIGEN_MAX_STATIC_ALIGN_BYTES and EIGEN_DONT_ALIGN[_STATICALLY] are both defined with EIGEN_MAX_STATIC_ALIGN_BYTES!=0. Use EIGEN_MAX_STATIC_ALIGN_BYTES=0 as a synonym of EIGEN_DONT_ALIGN_STATICALLY.
+#endif
+
+// EIGEN_DONT_ALIGN_STATICALLY and EIGEN_DONT_ALIGN are deprecated
+// They imply EIGEN_MAX_STATIC_ALIGN_BYTES=0
+#if defined(EIGEN_DONT_ALIGN_STATICALLY) || defined(EIGEN_DONT_ALIGN)
+ #ifdef EIGEN_MAX_STATIC_ALIGN_BYTES
+ #undef EIGEN_MAX_STATIC_ALIGN_BYTES
+ #endif
+ #define EIGEN_MAX_STATIC_ALIGN_BYTES 0
+#endif
+
+#ifndef EIGEN_MAX_STATIC_ALIGN_BYTES
+
+ // Try to automatically guess what is the best default value for EIGEN_MAX_STATIC_ALIGN_BYTES
+
+ // 16 byte alignment is only useful for vectorization. Since it affects the ABI, we need to enable
+ // 16 byte alignment on all platforms where vectorization might be enabled. In theory we could always
+ // enable alignment, but it can be a cause of problems on some platforms, so we just disable it in
+ // certain common platform (compiler+architecture combinations) to avoid these problems.
+ // Only static alignment is really problematic (relies on nonstandard compiler extensions),
+ // try to keep heap alignment even when we have to disable static alignment.
+ #if EIGEN_COMP_GNUC && !(EIGEN_ARCH_i386_OR_x86_64 || EIGEN_ARCH_ARM_OR_ARM64 || EIGEN_ARCH_PPC || EIGEN_ARCH_IA64 || EIGEN_ARCH_MIPS)
+ #define EIGEN_GCC_AND_ARCH_DOESNT_WANT_STACK_ALIGNMENT 1
+ #elif EIGEN_ARCH_ARM_OR_ARM64 && EIGEN_COMP_GNUC_STRICT && EIGEN_GNUC_AT_MOST(4, 6)
+ // Old versions of GCC on ARM, at least 4.4, were once seen to have buggy static alignment support.
+ // Not sure which version fixed it, hopefully it doesn't affect 4.7, which is still somewhat in use.
+ // 4.8 and newer seem definitely unaffected.
+ #define EIGEN_GCC_AND_ARCH_DOESNT_WANT_STACK_ALIGNMENT 1
+ #else
+ #define EIGEN_GCC_AND_ARCH_DOESNT_WANT_STACK_ALIGNMENT 0
+ #endif
+
+ // static alignment is completely disabled with GCC 3, Sun Studio, and QCC/QNX
+ #if !EIGEN_GCC_AND_ARCH_DOESNT_WANT_STACK_ALIGNMENT \
+ && !EIGEN_GCC3_OR_OLDER \
+ && !EIGEN_COMP_SUNCC \
+ && !EIGEN_OS_QNX
+ #define EIGEN_ARCH_WANTS_STACK_ALIGNMENT 1
+ #else
+ #define EIGEN_ARCH_WANTS_STACK_ALIGNMENT 0
+ #endif
+
+ #if EIGEN_ARCH_WANTS_STACK_ALIGNMENT
+ #define EIGEN_MAX_STATIC_ALIGN_BYTES EIGEN_IDEAL_MAX_ALIGN_BYTES
+ #else
+ #define EIGEN_MAX_STATIC_ALIGN_BYTES 0
+ #endif
+
+#endif
+
+// If EIGEN_MAX_ALIGN_BYTES is defined, then it is considered as an upper bound for EIGEN_MAX_STATIC_ALIGN_BYTES
+#if defined(EIGEN_MAX_ALIGN_BYTES) && EIGEN_MAX_ALIGN_BYTES<EIGEN_MAX_STATIC_ALIGN_BYTES
+#undef EIGEN_MAX_STATIC_ALIGN_BYTES
+#define EIGEN_MAX_STATIC_ALIGN_BYTES EIGEN_MAX_ALIGN_BYTES
+#endif
+
+#if EIGEN_MAX_STATIC_ALIGN_BYTES==0 && !defined(EIGEN_DISABLE_UNALIGNED_ARRAY_ASSERT)
+ #define EIGEN_DISABLE_UNALIGNED_ARRAY_ASSERT
+#endif
+
+// At this stage, EIGEN_MAX_STATIC_ALIGN_BYTES>0 is the true test whether we want to align arrays on the stack or not.
+// It takes into account both the user choice to explicitly enable/disable alignment (by setting EIGEN_MAX_STATIC_ALIGN_BYTES)
+// and the architecture config (EIGEN_ARCH_WANTS_STACK_ALIGNMENT).
+// Henceforth, only EIGEN_MAX_STATIC_ALIGN_BYTES should be used.
+
+
+// Shortcuts to EIGEN_ALIGN_TO_BOUNDARY
+#define EIGEN_ALIGN8 EIGEN_ALIGN_TO_BOUNDARY(8)
+#define EIGEN_ALIGN16 EIGEN_ALIGN_TO_BOUNDARY(16)
+#define EIGEN_ALIGN32 EIGEN_ALIGN_TO_BOUNDARY(32)
+#define EIGEN_ALIGN64 EIGEN_ALIGN_TO_BOUNDARY(64)
+#if EIGEN_MAX_STATIC_ALIGN_BYTES>0
+#define EIGEN_ALIGN_MAX EIGEN_ALIGN_TO_BOUNDARY(EIGEN_MAX_STATIC_ALIGN_BYTES)
+#else
+#define EIGEN_ALIGN_MAX
+#endif
+
+
+// Dynamic alignment control
+
+#if defined(EIGEN_DONT_ALIGN) && defined(EIGEN_MAX_ALIGN_BYTES) && EIGEN_MAX_ALIGN_BYTES>0
+#error EIGEN_MAX_ALIGN_BYTES and EIGEN_DONT_ALIGN are both defined with EIGEN_MAX_ALIGN_BYTES!=0. Use EIGEN_MAX_ALIGN_BYTES=0 as a synonym of EIGEN_DONT_ALIGN.
+#endif
+
+#ifdef EIGEN_DONT_ALIGN
+ #ifdef EIGEN_MAX_ALIGN_BYTES
+ #undef EIGEN_MAX_ALIGN_BYTES
+ #endif
+ #define EIGEN_MAX_ALIGN_BYTES 0
+#elif !defined(EIGEN_MAX_ALIGN_BYTES)
+ #define EIGEN_MAX_ALIGN_BYTES EIGEN_IDEAL_MAX_ALIGN_BYTES
+#endif
+
+#if EIGEN_IDEAL_MAX_ALIGN_BYTES > EIGEN_MAX_ALIGN_BYTES
+#define EIGEN_DEFAULT_ALIGN_BYTES EIGEN_IDEAL_MAX_ALIGN_BYTES
+#else
+#define EIGEN_DEFAULT_ALIGN_BYTES EIGEN_MAX_ALIGN_BYTES
+#endif
+
+
+#ifndef EIGEN_UNALIGNED_VECTORIZE
+#define EIGEN_UNALIGNED_VECTORIZE 1
+#endif
+
+//----------------------------------------------------------------------
+
+// if alignment is disabled, then disable vectorization. Note: EIGEN_MAX_ALIGN_BYTES is the proper check, it takes into
+// account both the user's will (EIGEN_MAX_ALIGN_BYTES,EIGEN_DONT_ALIGN) and our own platform checks
+#if EIGEN_MAX_ALIGN_BYTES==0
+ #ifndef EIGEN_DONT_VECTORIZE
+ #define EIGEN_DONT_VECTORIZE
+ #endif
+#endif
+
+
+// The following (except #include <malloc.h> and _M_IX86_FP ??) can likely be
+// removed as gcc 4.1 and msvc 2008 are not supported anyways.
+#if EIGEN_COMP_MSVC
+ #include <malloc.h> // for _aligned_malloc -- need it regardless of whether vectorization is enabled
+ #if (EIGEN_COMP_MSVC >= 1500) // 2008 or later
+ // a user reported that in 64-bit mode, MSVC doesn't care to define _M_IX86_FP.
+ #if (defined(_M_IX86_FP) && (_M_IX86_FP >= 2)) || EIGEN_ARCH_x86_64
+ #define EIGEN_SSE2_ON_MSVC_2008_OR_LATER
+ #endif
+ #endif
+#else
+ #if (defined __SSE2__) && ( (!EIGEN_COMP_GNUC) || EIGEN_COMP_ICC || EIGEN_GNUC_AT_LEAST(4,2) )
+ #define EIGEN_SSE2_ON_NON_MSVC_BUT_NOT_OLD_GCC
+ #endif
+#endif
+
+#if !(defined(EIGEN_DONT_VECTORIZE) || defined(EIGEN_GPUCC))
+
+ #if defined (EIGEN_SSE2_ON_NON_MSVC_BUT_NOT_OLD_GCC) || defined(EIGEN_SSE2_ON_MSVC_2008_OR_LATER)
+
+ // Defines symbols for compile-time detection of which instructions are
+ // used.
+ // EIGEN_VECTORIZE_YY is defined if and only if the instruction set YY is used
+ #define EIGEN_VECTORIZE
+ #define EIGEN_VECTORIZE_SSE
+ #define EIGEN_VECTORIZE_SSE2
+
+ // Detect sse3/ssse3/sse4:
+ // gcc and icc defines __SSE3__, ...
+ // there is no way to know about this on msvc. You can define EIGEN_VECTORIZE_SSE* if you
+ // want to force the use of those instructions with msvc.
+ #ifdef __SSE3__
+ #define EIGEN_VECTORIZE_SSE3
+ #endif
+ #ifdef __SSSE3__
+ #define EIGEN_VECTORIZE_SSSE3
+ #endif
+ #ifdef __SSE4_1__
+ #define EIGEN_VECTORIZE_SSE4_1
+ #endif
+ #ifdef __SSE4_2__
+ #define EIGEN_VECTORIZE_SSE4_2
+ #endif
+ #ifdef __AVX__
+ #ifndef EIGEN_USE_SYCL
+ #define EIGEN_VECTORIZE_AVX
+ #endif
+ #define EIGEN_VECTORIZE_SSE3
+ #define EIGEN_VECTORIZE_SSSE3
+ #define EIGEN_VECTORIZE_SSE4_1
+ #define EIGEN_VECTORIZE_SSE4_2
+ #endif
+ #ifdef __AVX2__
+ #ifndef EIGEN_USE_SYCL
+ #define EIGEN_VECTORIZE_AVX2
+ #define EIGEN_VECTORIZE_AVX
+ #endif
+ #define EIGEN_VECTORIZE_SSE3
+ #define EIGEN_VECTORIZE_SSSE3
+ #define EIGEN_VECTORIZE_SSE4_1
+ #define EIGEN_VECTORIZE_SSE4_2
+ #endif
+ #if defined(__FMA__) || (EIGEN_COMP_MSVC && defined(__AVX2__))
+ // MSVC does not expose a switch dedicated for FMA
+ // For MSVC, AVX2 => FMA
+ #define EIGEN_VECTORIZE_FMA
+ #endif
+ #if defined(__AVX512F__)
+ #ifndef EIGEN_VECTORIZE_FMA
+ #if EIGEN_COMP_GNUC
+ #error Please add -mfma to your compiler flags: compiling with -mavx512f alone without SSE/AVX FMA is not supported (bug 1638).
+ #else
+ #error Please enable FMA in your compiler flags (e.g. -mfma): compiling with AVX512 alone without SSE/AVX FMA is not supported (bug 1638).
+ #endif
+ #endif
+ #ifndef EIGEN_USE_SYCL
+ #define EIGEN_VECTORIZE_AVX512
+ #define EIGEN_VECTORIZE_AVX2
+ #define EIGEN_VECTORIZE_AVX
+ #endif
+ #define EIGEN_VECTORIZE_FMA
+ #define EIGEN_VECTORIZE_SSE3
+ #define EIGEN_VECTORIZE_SSSE3
+ #define EIGEN_VECTORIZE_SSE4_1
+ #define EIGEN_VECTORIZE_SSE4_2
+ #ifndef EIGEN_USE_SYCL
+ #ifdef __AVX512DQ__
+ #define EIGEN_VECTORIZE_AVX512DQ
+ #endif
+ #ifdef __AVX512ER__
+ #define EIGEN_VECTORIZE_AVX512ER
+ #endif
+ #ifdef __AVX512BF16__
+ #define EIGEN_VECTORIZE_AVX512BF16
+ #endif
+ #endif
+ #endif
+
+ // Disable AVX support on broken xcode versions
+ #if defined(__apple_build_version__) && (__apple_build_version__ == 11000033 ) && ( __MAC_OS_X_VERSION_MIN_REQUIRED == 101500 )
+ // A nasty bug in the clang compiler shipped with xcode in a common compilation situation
+ // when XCode 11.0 and Mac deployment target macOS 10.15 is https://trac.macports.org/ticket/58776#no1
+ #ifdef EIGEN_VECTORIZE_AVX
+ #undef EIGEN_VECTORIZE_AVX
+ #warning "Disabling AVX support: clang compiler shipped with XCode 11.[012] generates broken assembly with -macosx-version-min=10.15 and AVX enabled. "
+ #ifdef EIGEN_VECTORIZE_AVX2
+ #undef EIGEN_VECTORIZE_AVX2
+ #endif
+ #ifdef EIGEN_VECTORIZE_FMA
+ #undef EIGEN_VECTORIZE_FMA
+ #endif
+ #ifdef EIGEN_VECTORIZE_AVX512
+ #undef EIGEN_VECTORIZE_AVX512
+ #endif
+ #ifdef EIGEN_VECTORIZE_AVX512DQ
+ #undef EIGEN_VECTORIZE_AVX512DQ
+ #endif
+ #ifdef EIGEN_VECTORIZE_AVX512ER
+ #undef EIGEN_VECTORIZE_AVX512ER
+ #endif
+ #endif
+ // NOTE: Confirmed test failures in XCode 11.0, and XCode 11.2 with -macosx-version-min=10.15 and AVX
+ // NOTE using -macosx-version-min=10.15 with Xcode 11.0 results in runtime segmentation faults in many tests, 11.2 produce core dumps in 3 tests
+ // NOTE using -macosx-version-min=10.14 produces functioning and passing tests in all cases
+ // NOTE __clang_version__ "11.0.0 (clang-1100.0.33.8)" XCode 11.0 <- Produces many segfault and core dumping tests
+ // with -macosx-version-min=10.15 and AVX
+ // NOTE __clang_version__ "11.0.0 (clang-1100.0.33.12)" XCode 11.2 <- Produces 3 core dumping tests with
+ // -macosx-version-min=10.15 and AVX
+ #endif
+
+ // include files
+
+ // This extern "C" works around a MINGW-w64 compilation issue
+ // https://sourceforge.net/tracker/index.php?func=detail&aid=3018394&group_id=202880&atid=983354
+ // In essence, intrin.h is included by windows.h and also declares intrinsics (just as emmintrin.h etc. below do).
+ // However, intrin.h uses an extern "C" declaration, and g++ thus complains of duplicate declarations
+ // with conflicting linkage. The linkage for intrinsics doesn't matter, but at that stage the compiler doesn't know;
+ // so, to avoid compile errors when windows.h is included after Eigen/Core, ensure intrinsics are extern "C" here too.
+ // notice that since these are C headers, the extern "C" is theoretically needed anyways.
+ extern "C" {
+ // In theory we should only include immintrin.h and not the other *mmintrin.h header files directly.
+ // Doing so triggers some issues with ICC. However old gcc versions seems to not have this file, thus:
+ #if EIGEN_COMP_ICC >= 1110
+ #include <immintrin.h>
+ #else
+ #include <mmintrin.h>
+ #include <emmintrin.h>
+ #include <xmmintrin.h>
+ #ifdef EIGEN_VECTORIZE_SSE3
+ #include <pmmintrin.h>
+ #endif
+ #ifdef EIGEN_VECTORIZE_SSSE3
+ #include <tmmintrin.h>
+ #endif
+ #ifdef EIGEN_VECTORIZE_SSE4_1
+ #include <smmintrin.h>
+ #endif
+ #ifdef EIGEN_VECTORIZE_SSE4_2
+ #include <nmmintrin.h>
+ #endif
+ #if defined(EIGEN_VECTORIZE_AVX) || defined(EIGEN_VECTORIZE_AVX512)
+ #include <immintrin.h>
+ #endif
+ #endif
+ } // end extern "C"
+
+ #elif defined __VSX__
+
+ #define EIGEN_VECTORIZE
+ #define EIGEN_VECTORIZE_VSX
+ #include <altivec.h>
+ // We need to #undef all these ugly tokens defined in <altivec.h>
+ // => use __vector instead of vector
+ #undef bool
+ #undef vector
+ #undef pixel
+
+ #elif defined __ALTIVEC__
+
+ #define EIGEN_VECTORIZE
+ #define EIGEN_VECTORIZE_ALTIVEC
+ #include <altivec.h>
+ // We need to #undef all these ugly tokens defined in <altivec.h>
+ // => use __vector instead of vector
+ #undef bool
+ #undef vector
+ #undef pixel
+
+ #elif ((defined __ARM_NEON) || (defined __ARM_NEON__)) && !(defined EIGEN_ARM64_USE_SVE)
+
+ #define EIGEN_VECTORIZE
+ #define EIGEN_VECTORIZE_NEON
+ #include <arm_neon.h>
+
+ // We currently require SVE to be enabled explicitly via EIGEN_ARM64_USE_SVE and
+ // will not select the backend automatically
+ #elif (defined __ARM_FEATURE_SVE) && (defined EIGEN_ARM64_USE_SVE)
+
+ #define EIGEN_VECTORIZE
+ #define EIGEN_VECTORIZE_SVE
+ #include <arm_sve.h>
+
+ // Since we depend on knowing SVE vector lengths at compile-time, we need
+ // to ensure a fixed lengths is set
+ #if defined __ARM_FEATURE_SVE_BITS
+ #define EIGEN_ARM64_SVE_VL __ARM_FEATURE_SVE_BITS
+ #else
+#error "Eigen requires a fixed SVE lector length but EIGEN_ARM64_SVE_VL is not set."
+#endif
+
+#elif (defined __s390x__ && defined __VEC__)
+
+#define EIGEN_VECTORIZE
+#define EIGEN_VECTORIZE_ZVECTOR
+#include <vecintrin.h>
+
+#elif defined __mips_msa
+
+// Limit MSA optimizations to little-endian CPUs for now.
+// TODO: Perhaps, eventually support MSA optimizations on big-endian CPUs?
+#if defined(__BYTE_ORDER__) && (__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__)
+#if defined(__LP64__)
+#define EIGEN_MIPS_64
+#else
+#define EIGEN_MIPS_32
+#endif
+#define EIGEN_VECTORIZE
+#define EIGEN_VECTORIZE_MSA
+#include <msa.h>
+#endif
+
+#endif
+#endif
+
+// Following the Arm ACLE arm_neon.h should also include arm_fp16.h but not all
+// compilers seem to follow this. We therefore include it explicitly.
+// See also: https://bugs.llvm.org/show_bug.cgi?id=47955
+#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
+ #include <arm_fp16.h>
+#endif
+
+#if defined(__F16C__) && (!defined(EIGEN_GPUCC) && (!defined(EIGEN_COMP_CLANG) || EIGEN_COMP_CLANG>=380))
+ // We can use the optimized fp16 to float and float to fp16 conversion routines
+ #define EIGEN_HAS_FP16_C
+
+ #if defined(EIGEN_COMP_CLANG)
+ // Workaround for clang: The FP16C intrinsics for clang are included by
+ // immintrin.h, as opposed to emmintrin.h as suggested by Intel:
+ // https://software.intel.com/sites/landingpage/IntrinsicsGuide/#othertechs=FP16C&expand=1711
+ #include <immintrin.h>
+ #endif
+#endif
+
+#if defined EIGEN_CUDACC
+ #define EIGEN_VECTORIZE_GPU
+ #include <vector_types.h>
+ #if EIGEN_CUDA_SDK_VER >= 70500
+ #define EIGEN_HAS_CUDA_FP16
+ #endif
+#endif
+
+#if defined(EIGEN_HAS_CUDA_FP16)
+ #include <cuda_runtime_api.h>
+ #include <cuda_fp16.h>
+#endif
+
+#if defined(EIGEN_HIPCC)
+ #define EIGEN_VECTORIZE_GPU
+ #include <hip/hip_vector_types.h>
+ #define EIGEN_HAS_HIP_FP16
+ #include <hip/hip_fp16.h>
+#endif
+
+
+/** \brief Namespace containing all symbols from the %Eigen library. */
+namespace Eigen {
+
+inline static const char *SimdInstructionSetsInUse(void) {
+#if defined(EIGEN_VECTORIZE_AVX512)
+ return "AVX512, FMA, AVX2, AVX, SSE, SSE2, SSE3, SSSE3, SSE4.1, SSE4.2";
+#elif defined(EIGEN_VECTORIZE_AVX)
+ return "AVX SSE, SSE2, SSE3, SSSE3, SSE4.1, SSE4.2";
+#elif defined(EIGEN_VECTORIZE_SSE4_2)
+ return "SSE, SSE2, SSE3, SSSE3, SSE4.1, SSE4.2";
+#elif defined(EIGEN_VECTORIZE_SSE4_1)
+ return "SSE, SSE2, SSE3, SSSE3, SSE4.1";
+#elif defined(EIGEN_VECTORIZE_SSSE3)
+ return "SSE, SSE2, SSE3, SSSE3";
+#elif defined(EIGEN_VECTORIZE_SSE3)
+ return "SSE, SSE2, SSE3";
+#elif defined(EIGEN_VECTORIZE_SSE2)
+ return "SSE, SSE2";
+#elif defined(EIGEN_VECTORIZE_ALTIVEC)
+ return "AltiVec";
+#elif defined(EIGEN_VECTORIZE_VSX)
+ return "VSX";
+#elif defined(EIGEN_VECTORIZE_NEON)
+ return "ARM NEON";
+#elif defined(EIGEN_VECTORIZE_SVE)
+ return "ARM SVE";
+#elif defined(EIGEN_VECTORIZE_ZVECTOR)
+ return "S390X ZVECTOR";
+#elif defined(EIGEN_VECTORIZE_MSA)
+ return "MIPS MSA";
+#else
+ return "None";
+#endif
+}
+
+} // end namespace Eigen
+
+
+#endif // EIGEN_CONFIGURE_VECTORIZATION_H
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/Constants.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/Constants.h
index 7587d68..35dcaa7 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/Constants.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/Constants.h
@@ -3,6 +3,7 @@
//
// Copyright (C) 2008-2015 Gael Guennebaud <gael.guennebaud@inria.fr>
// Copyright (C) 2007-2009 Benoit Jacob <jacob.benoit.1@gmail.com>
+// Copyright (C) 2020, Arm Limited and Contributors
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
@@ -25,6 +26,10 @@
*/
const int DynamicIndex = 0xffffff;
+/** This value means that the increment to go from one value to another in a sequence is not constant for each step.
+ */
+const int UndefinedIncr = 0xfffffe;
+
/** This value means +Infinity; it is currently used only as the p parameter to MatrixBase::lpNorm<int>().
* The value Infinity there means the L-infinity norm.
*/
@@ -152,7 +157,7 @@
/** \deprecated \ingroup flags
*
* means the first coefficient packet is guaranteed to be aligned.
- * An expression cannot has the AlignedBit without the PacketAccessBit flag.
+ * An expression cannot have the AlignedBit without the PacketAccessBit flag.
* In other words, this means we are allow to perform an aligned packet access to the first element regardless
* of the expression kind:
* \code
@@ -251,12 +256,6 @@
};
/** \ingroup enums
- * Enum used by DenseBase::corner() in Eigen2 compatibility mode. */
-// FIXME after the corner() API change, this was not needed anymore, except by AlignedBox
-// TODO: find out what to do with that. Adapt the AlignedBox API ?
-enum CornerType { TopLeft, TopRight, BottomLeft, BottomRight };
-
-/** \ingroup enums
* Enum containing possible values for the \p Direction parameter of
* Reverse, PartialReduxExpr and VectorwiseOp. */
enum DirectionType {
@@ -330,9 +329,20 @@
* Enum for specifying whether to apply or solve on the left or right. */
enum SideType {
/** Apply transformation on the left. */
- OnTheLeft = 1,
+ OnTheLeft = 1,
/** Apply transformation on the right. */
- OnTheRight = 2
+ OnTheRight = 2
+};
+
+/** \ingroup enums
+ * Enum for specifying NaN-propagation behavior, e.g. for coeff-wise min/max. */
+enum NaNPropagationOptions {
+ /** Implementation defined behavior if NaNs are present. */
+ PropagateFast = 0,
+ /** Always propagate NaNs. */
+ PropagateNaN,
+ /** Always propagate not-NaNs. */
+ PropagateNumbers
};
/* the following used to be written as:
@@ -464,6 +474,8 @@
AltiVec = 0x2,
VSX = 0x3,
NEON = 0x4,
+ MSA = 0x5,
+ SVE = 0x6,
#if defined EIGEN_VECTORIZE_SSE
Target = SSE
#elif defined EIGEN_VECTORIZE_ALTIVEC
@@ -472,6 +484,10 @@
Target = VSX
#elif defined EIGEN_VECTORIZE_NEON
Target = NEON
+#elif defined EIGEN_VECTORIZE_SVE
+ Target = SVE
+#elif defined EIGEN_VECTORIZE_MSA
+ Target = MSA
#else
Target = Generic
#endif
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/DisableStupidWarnings.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/DisableStupidWarnings.h
index ce573a8..3bec072 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/DisableStupidWarnings.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/DisableStupidWarnings.h
@@ -4,7 +4,6 @@
#ifdef _MSC_VER
// 4100 - unreferenced formal parameter (occurred e.g. in aligned_allocator::destroy(pointer p))
// 4101 - unreferenced local variable
- // 4127 - conditional expression is constant
// 4181 - qualifier applied to reference type ignored
// 4211 - nonstandard extension used : redefined extern to static
// 4244 - 'argument' : conversion from 'type1' to 'type2', possible loss of data
@@ -20,7 +19,7 @@
#ifndef EIGEN_PERMANENTLY_DISABLE_STUPID_WARNINGS
#pragma warning( push )
#endif
- #pragma warning( disable : 4100 4101 4127 4181 4211 4244 4273 4324 4503 4512 4522 4700 4714 4717 4800)
+ #pragma warning( disable : 4100 4101 4181 4211 4244 4273 4324 4503 4512 4522 4700 4714 4717 4800)
#elif defined __INTEL_COMPILER
// 2196 - routine is both "inline" and "noinline" ("noinline" assumed)
@@ -42,6 +41,17 @@
#pragma clang diagnostic push
#endif
#pragma clang diagnostic ignored "-Wconstant-logical-operand"
+ #if __clang_major__ > 3 || (__clang_major__ == 3 && __clang_minor__ >= 5) // clang >= 3.5 (compare major first, then minor)
+ #pragma clang diagnostic ignored "-Wabsolute-value"
+ #endif
+ #if __clang_major__ >= 10
+ #pragma clang diagnostic ignored "-Wimplicit-int-float-conversion"
+ #endif
+ #if ( defined(__ALTIVEC__) || defined(__VSX__) ) && __cplusplus < 201103L
+ // warning: generic selections are a C11-specific feature
+ // ignoring warnings thrown at vec_ctf in Altivec/PacketMath.h
+ #pragma clang diagnostic ignored "-Wc11-extensions"
+ #endif
#elif defined __GNUC__
@@ -57,13 +67,18 @@
#if __GNUC__>=6
#pragma GCC diagnostic ignored "-Wignored-attributes"
#endif
- #if __GNUC__>=9
- #pragma GCC diagnostic ignored "-Wdeprecated-copy"
+ #if __GNUC__==7
+ // See: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=89325
+ #pragma GCC diagnostic ignored "-Wattributes"
#endif
-
+ #if __GNUC__==11
+ // This warning is a false positive
+ #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
+ #endif
#endif
#if defined __NVCC__
+ #pragma diag_suppress boolean_controlling_expr_is_constant
// Disable the "statement is unreachable" message
#pragma diag_suppress code_is_unreachable
// Disable the "dynamic initialization in unreachable code" message
@@ -81,6 +96,15 @@
#pragma diag_suppress 2671
#pragma diag_suppress 2735
#pragma diag_suppress 2737
+ #pragma diag_suppress 2739
#endif
+#else
+// warnings already disabled:
+# ifndef EIGEN_WARNINGS_DISABLED_2
+# define EIGEN_WARNINGS_DISABLED_2
+# elif defined(EIGEN_INTERNAL_DEBUGGING)
+# error "Do not include \"DisableStupidWarnings.h\" recursively more than twice!"
+# endif
+
#endif // not EIGEN_WARNINGS_DISABLED
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/ForwardDeclarations.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/ForwardDeclarations.h
index ea10739..2f9cc44 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/ForwardDeclarations.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/ForwardDeclarations.h
@@ -47,11 +47,7 @@
template<typename Derived> struct EigenBase;
template<typename Derived> class DenseBase;
template<typename Derived> class PlainObjectBase;
-
-
-template<typename Derived,
- int Level = internal::accessors_level<Derived>::value >
-class DenseCoeffsBase;
+template<typename Derived, int Level> class DenseCoeffsBase;
template<typename _Scalar, int _Rows, int _Cols,
int _Options = AutoAlign |
@@ -83,6 +79,8 @@
template<typename ExpressionType> class SwapWrapper;
template<typename XprType, int BlockRows=Dynamic, int BlockCols=Dynamic, bool InnerPanel = false> class Block;
+template<typename XprType, typename RowIndices, typename ColIndices> class IndexedView;
+template<typename XprType, int Rows=Dynamic, int Cols=Dynamic, int Order=0> class Reshaped;
template<typename MatrixType, int Size=Dynamic> class VectorBlock;
template<typename MatrixType> class Transpose;
@@ -112,7 +110,7 @@
template<typename Derived,
int Level = internal::accessors_level<Derived>::has_write_access ? WriteAccessors : ReadOnlyAccessors
> class MapBase;
-template<int InnerStrideAtCompileTime, int OuterStrideAtCompileTime> class Stride;
+template<int OuterStrideAtCompileTime, int InnerStrideAtCompileTime> class Stride;
template<int Value = Dynamic> class InnerStride;
template<int Value = Dynamic> class OuterStride;
template<typename MatrixType, int MapOptions=Unaligned, typename StrideType = Stride<0,0> > class Map;
@@ -133,6 +131,10 @@
template<typename XprType> class InnerIterator;
namespace internal {
+template<typename XprType> class generic_randaccess_stl_iterator;
+template<typename XprType> class pointer_based_stl_iterator;
+template<typename XprType, DirectionType Direction> class subvector_stl_iterator;
+template<typename XprType, DirectionType Direction> class subvector_stl_reverse_iterator;
template<typename DecompositionType> struct kernel_retval_base;
template<typename DecompositionType> struct kernel_retval;
template<typename DecompositionType> struct image_retval_base;
@@ -178,14 +180,15 @@
template<typename LhsScalar,typename RhsScalar=LhsScalar> struct scalar_sum_op;
template<typename LhsScalar,typename RhsScalar=LhsScalar> struct scalar_difference_op;
template<typename LhsScalar,typename RhsScalar=LhsScalar> struct scalar_conj_product_op;
-template<typename LhsScalar,typename RhsScalar=LhsScalar> struct scalar_min_op;
-template<typename LhsScalar,typename RhsScalar=LhsScalar> struct scalar_max_op;
+template<typename LhsScalar,typename RhsScalar=LhsScalar, int NaNPropagation=PropagateFast> struct scalar_min_op;
+template<typename LhsScalar,typename RhsScalar=LhsScalar, int NaNPropagation=PropagateFast> struct scalar_max_op;
template<typename Scalar> struct scalar_opposite_op;
template<typename Scalar> struct scalar_conjugate_op;
template<typename Scalar> struct scalar_real_op;
template<typename Scalar> struct scalar_imag_op;
template<typename Scalar> struct scalar_abs_op;
template<typename Scalar> struct scalar_abs2_op;
+template<typename LhsScalar,typename RhsScalar=LhsScalar> struct scalar_absolute_difference_op;
template<typename Scalar> struct scalar_sqrt_op;
template<typename Scalar> struct scalar_rsqrt_op;
template<typename Scalar> struct scalar_exp_op;
@@ -202,7 +205,7 @@
template<typename Scalar> struct scalar_random_op;
template<typename Scalar> struct scalar_constant_op;
template<typename Scalar> struct scalar_identity_op;
-template<typename Scalar,bool iscpx> struct scalar_sign_op;
+template<typename Scalar,bool is_complex, bool is_integer> struct scalar_sign_op;
template<typename Scalar,typename ScalarExponent> struct scalar_pow_op;
template<typename LhsScalar,typename RhsScalar=LhsScalar> struct scalar_hypot_op;
template<typename LhsScalar,typename RhsScalar=LhsScalar> struct scalar_product_op;
@@ -213,11 +216,27 @@
template<typename Scalar> struct scalar_digamma_op;
template<typename Scalar> struct scalar_erf_op;
template<typename Scalar> struct scalar_erfc_op;
+template<typename Scalar> struct scalar_ndtri_op;
template<typename Scalar> struct scalar_igamma_op;
template<typename Scalar> struct scalar_igammac_op;
template<typename Scalar> struct scalar_zeta_op;
template<typename Scalar> struct scalar_betainc_op;
+// Bessel functions in SpecialFunctions module
+template<typename Scalar> struct scalar_bessel_i0_op;
+template<typename Scalar> struct scalar_bessel_i0e_op;
+template<typename Scalar> struct scalar_bessel_i1_op;
+template<typename Scalar> struct scalar_bessel_i1e_op;
+template<typename Scalar> struct scalar_bessel_j0_op;
+template<typename Scalar> struct scalar_bessel_y0_op;
+template<typename Scalar> struct scalar_bessel_j1_op;
+template<typename Scalar> struct scalar_bessel_y1_op;
+template<typename Scalar> struct scalar_bessel_k0_op;
+template<typename Scalar> struct scalar_bessel_k0e_op;
+template<typename Scalar> struct scalar_bessel_k1_op;
+template<typename Scalar> struct scalar_bessel_k1e_op;
+
+
} // end namespace internal
struct IOFormat;
@@ -255,6 +274,7 @@
template<typename MatrixType> class ColPivHouseholderQR;
template<typename MatrixType> class FullPivHouseholderQR;
template<typename MatrixType> class CompleteOrthogonalDecomposition;
+template<typename MatrixType> class SVDBase;
template<typename MatrixType, int QRPreconditioner = ColPivHouseholderQRPreconditioner> class JacobiSVD;
template<typename MatrixType> class BDCSVD;
template<typename MatrixType, int UpLo = Lower> class LLT;
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/IndexedViewHelper.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/IndexedViewHelper.h
new file mode 100644
index 0000000..f85de30
--- /dev/null
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/IndexedViewHelper.h
@@ -0,0 +1,186 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2017 Gael Guennebaud <gael.guennebaud@inria.fr>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+
+#ifndef EIGEN_INDEXED_VIEW_HELPER_H
+#define EIGEN_INDEXED_VIEW_HELPER_H
+
+namespace Eigen {
+
+namespace internal {
+struct symbolic_last_tag {};
+}
+
+/** \var last
+ * \ingroup Core_Module
+ *
+ * Can be used as a parameter to Eigen::seq and Eigen::seqN functions to symbolically reference the last element/row/columns
+ * of the underlying vector or matrix once passed to DenseBase::operator()(const RowIndices&, const ColIndices&).
+ *
+ * This symbolic placeholder supports standard arithmetic operations.
+ *
+ * A typical usage example would be:
+ * \code
+ * using namespace Eigen;
+ * using Eigen::last;
+ * VectorXd v(n);
+ * v(seq(2,last-2)).setOnes();
+ * \endcode
+ *
+ * \sa end
+ */
+static const symbolic::SymbolExpr<internal::symbolic_last_tag> last; // PLEASE use Eigen::last instead of Eigen::placeholders::last
+
+/** \var lastp1
+ * \ingroup Core_Module
+ *
+ * Can be used as a parameter to Eigen::seq and Eigen::seqN functions to symbolically
+ * reference the last+1 element/row/columns of the underlying vector or matrix once
+ * passed to DenseBase::operator()(const RowIndices&, const ColIndices&).
+ *
+ * This symbolic placeholder supports standard arithmetic operations.
+ * It is essentially an alias to last+fix<1>.
+ *
+ * \sa last
+ */
+#ifdef EIGEN_PARSED_BY_DOXYGEN
+static const auto lastp1 = last+fix<1>;
+#else
+// Using a FixedExpr<1> expression is important here to make sure the compiler
+// can fully optimize the computation starting indices with zero overhead.
+static const symbolic::AddExpr<symbolic::SymbolExpr<internal::symbolic_last_tag>,symbolic::ValueExpr<Eigen::internal::FixedInt<1> > > lastp1(last+fix<1>());
+#endif
+
+namespace internal {
+
+ // Replace symbolic last/end "keywords" by their true runtime value
+inline Index eval_expr_given_size(Index x, Index /* size */) { return x; }
+
+template<int N>
+FixedInt<N> eval_expr_given_size(FixedInt<N> x, Index /*size*/) { return x; }
+
+template<typename Derived>
+Index eval_expr_given_size(const symbolic::BaseExpr<Derived> &x, Index size)
+{
+ return x.derived().eval(last=size-1);
+}
+
+// Extract increment/step at compile time
+template<typename T, typename EnableIf = void> struct get_compile_time_incr {
+ enum { value = UndefinedIncr };
+};
+
+// Analogue of std::get<0>(x), but tailored for our needs.
+template<typename T>
+EIGEN_CONSTEXPR Index first(const T& x) EIGEN_NOEXCEPT { return x.first(); }
+
+// IndexedViewCompatibleType/makeIndexedViewCompatible turn an arbitrary object of type T into something usable by MatrixSlice
+// The generic implementation is a no-op
+template<typename T,int XprSize,typename EnableIf=void>
+struct IndexedViewCompatibleType {
+ typedef T type;
+};
+
+template<typename T,typename Q>
+const T& makeIndexedViewCompatible(const T& x, Index /*size*/, Q) { return x; }
+
+//--------------------------------------------------------------------------------
+// Handling of a single Index
+//--------------------------------------------------------------------------------
+
+struct SingleRange {
+ enum {
+ SizeAtCompileTime = 1
+ };
+ SingleRange(Index val) : m_value(val) {}
+ Index operator[](Index) const { return m_value; }
+ static EIGEN_CONSTEXPR Index size() EIGEN_NOEXCEPT { return 1; }
+ Index first() const EIGEN_NOEXCEPT { return m_value; }
+ Index m_value;
+};
+
+template<> struct get_compile_time_incr<SingleRange> {
+ enum { value = 1 }; // 1 or 0 ??
+};
+
+// Turn a single index into something that looks like an array (i.e., that exposes a .size(), and operator[](int) methods)
+template<typename T, int XprSize>
+struct IndexedViewCompatibleType<T,XprSize,typename internal::enable_if<internal::is_integral<T>::value>::type> {
+ // Here we could simply use Array, but maybe it's less work for the compiler to use
+ // a simpler wrapper as SingleRange
+ //typedef Eigen::Array<Index,1,1> type;
+ typedef SingleRange type;
+};
+
+template<typename T, int XprSize>
+struct IndexedViewCompatibleType<T, XprSize, typename enable_if<symbolic::is_symbolic<T>::value>::type> {
+ typedef SingleRange type;
+};
+
+
+template<typename T>
+typename enable_if<symbolic::is_symbolic<T>::value,SingleRange>::type
+makeIndexedViewCompatible(const T& id, Index size, SpecializedType) {
+ return eval_expr_given_size(id,size);
+}
+
+//--------------------------------------------------------------------------------
+// Handling of all
+//--------------------------------------------------------------------------------
+
+struct all_t { all_t() {} };
+
+// Convert a symbolic 'all' into a usable range type
+template<int XprSize>
+struct AllRange {
+ enum { SizeAtCompileTime = XprSize };
+ AllRange(Index size = XprSize) : m_size(size) {}
+ EIGEN_CONSTEXPR Index operator[](Index i) const EIGEN_NOEXCEPT { return i; }
+ EIGEN_CONSTEXPR Index size() const EIGEN_NOEXCEPT { return m_size.value(); }
+ EIGEN_CONSTEXPR Index first() const EIGEN_NOEXCEPT { return 0; }
+ variable_if_dynamic<Index,XprSize> m_size;
+};
+
+template<int XprSize>
+struct IndexedViewCompatibleType<all_t,XprSize> {
+ typedef AllRange<XprSize> type;
+};
+
+template<typename XprSizeType>
+inline AllRange<get_fixed_value<XprSizeType>::value> makeIndexedViewCompatible(all_t , XprSizeType size, SpecializedType) {
+ return AllRange<get_fixed_value<XprSizeType>::value>(size);
+}
+
+template<int Size> struct get_compile_time_incr<AllRange<Size> > {
+ enum { value = 1 };
+};
+
+} // end namespace internal
+
+
+/** \var all
+ * \ingroup Core_Module
+ * Can be used as a parameter to DenseBase::operator()(const RowIndices&, const ColIndices&) to index all rows or columns
+ */
+static const Eigen::internal::all_t all; // PLEASE use Eigen::all instead of Eigen::placeholders::all
+
+
+namespace placeholders {
+ typedef symbolic::SymbolExpr<internal::symbolic_last_tag> last_t;
+ typedef symbolic::AddExpr<symbolic::SymbolExpr<internal::symbolic_last_tag>,symbolic::ValueExpr<Eigen::internal::FixedInt<1> > > end_t;
+ typedef Eigen::internal::all_t all_t;
+
+ EIGEN_DEPRECATED static const all_t all = Eigen::all; // PLEASE use Eigen::all instead of Eigen::placeholders::all
+ EIGEN_DEPRECATED static const last_t last = Eigen::last; // PLEASE use Eigen::last instead of Eigen::placeholders::last
+ EIGEN_DEPRECATED static const end_t end = Eigen::lastp1; // PLEASE use Eigen::lastp1 instead of Eigen::placeholders::end
+}
+
+} // end namespace Eigen
+
+#endif // EIGEN_INDEXED_VIEW_HELPER_H
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/IntegralConstant.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/IntegralConstant.h
new file mode 100644
index 0000000..945d426
--- /dev/null
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/IntegralConstant.h
@@ -0,0 +1,272 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2017 Gael Guennebaud <gael.guennebaud@inria.fr>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+
+#ifndef EIGEN_INTEGRAL_CONSTANT_H
+#define EIGEN_INTEGRAL_CONSTANT_H
+
+namespace Eigen {
+
+namespace internal {
+
+template<int N> class FixedInt;
+template<int N> class VariableAndFixedInt;
+
+/** \internal
+ * \class FixedInt
+ *
+ * This class embeds a compile-time integer \c N.
+ *
+ * It is similar to c++11 std::integral_constant<int,N> but with some additional features
+ * such as:
+ * - implicit conversion to int
+ * - arithmetic and some bitwise operators: -, +, *, /, %, &, |
+ * - c++98/14 compatibility with fix<N> and fix<N>() syntax to define integral constants.
+ *
+ * It is strongly discouraged to directly deal with this class FixedInt. Instances are expected to
+ * be created by the user using Eigen::fix<N> or Eigen::fix<N>(). In C++98-11, the former syntax does
+ * not create a FixedInt<N> instance but rather a pointer to function that needs to be \em cleaned-up
+ * using the generic helper:
+ * \code
+ * internal::cleanup_index_type<T>::type
+ * internal::cleanup_index_type<T,DynamicKey>::type
+ * \endcode
+ * where T can be a FixedInt<N>, a pointer to function FixedInt<N> (*)(), or numerous other integer-like representations.
+ * \c DynamicKey is either Dynamic (default) or DynamicIndex and used to identify true compile-time values.
+ *
+ * For convenience, you can extract the compile-time value \c N in a generic way using the following helper:
+ * \code
+ * internal::get_fixed_value<T,DefaultVal>::value
+ * \endcode
+ * that will give you \c N if T equals FixedInt<N> or FixedInt<N> (*)(), and \c DefaultVal if T does not embed any compile-time value (e.g., T==int).
+ *
+ * \sa fix<N>, class VariableAndFixedInt
+ */
+template<int N> class FixedInt
+{
+public:
+ static const int value = N;
+ EIGEN_CONSTEXPR operator int() const { return value; }
+ FixedInt() {}
+ FixedInt( VariableAndFixedInt<N> other) {
+ #ifndef EIGEN_INTERNAL_DEBUGGING
+ EIGEN_UNUSED_VARIABLE(other);
+ #endif
+ eigen_internal_assert(int(other)==N);
+ }
+
+ FixedInt<-N> operator-() const { return FixedInt<-N>(); }
+ template<int M>
+ FixedInt<N+M> operator+( FixedInt<M>) const { return FixedInt<N+M>(); }
+ template<int M>
+ FixedInt<N-M> operator-( FixedInt<M>) const { return FixedInt<N-M>(); }
+ template<int M>
+ FixedInt<N*M> operator*( FixedInt<M>) const { return FixedInt<N*M>(); }
+ template<int M>
+ FixedInt<N/M> operator/( FixedInt<M>) const { return FixedInt<N/M>(); }
+ template<int M>
+ FixedInt<N%M> operator%( FixedInt<M>) const { return FixedInt<N%M>(); }
+ template<int M>
+ FixedInt<N|M> operator|( FixedInt<M>) const { return FixedInt<N|M>(); }
+ template<int M>
+ FixedInt<N&M> operator&( FixedInt<M>) const { return FixedInt<N&M>(); }
+
+#if EIGEN_HAS_CXX14_VARIABLE_TEMPLATES
+ // Needed in C++14 to allow fix<N>():
+ FixedInt operator() () const { return *this; }
+
+ VariableAndFixedInt<N> operator() (int val) const { return VariableAndFixedInt<N>(val); }
+#else
+ FixedInt ( FixedInt<N> (*)() ) {}
+#endif
+
+#if EIGEN_HAS_CXX11
+ FixedInt(std::integral_constant<int,N>) {}
+#endif
+};
+
+/** \internal
+ * \class VariableAndFixedInt
+ *
+ * This class embeds both a compile-time integer \c N and a runtime integer.
+ * Both values are supposed to be equal unless the compile-time value \c N has a special
+ * value meaning that the runtime-value should be used. Depending on the context, this special
+ * value can be either Eigen::Dynamic (for positive quantities) or Eigen::DynamicIndex (for
+ * quantities that can be negative).
+ *
+ * It is the return-type of the function Eigen::fix<N>(int), and most of the time this is the only
+ * way it is used. It is strongly discouraged to directly deal with instances of VariableAndFixedInt.
+ * Indeed, in order to write generic code, it is the responsibility of the callee to properly convert
+ * it to either a true compile-time quantity (i.e. a FixedInt<N>), or to a runtime quantity (e.g., an Index)
+ * using the following generic helper:
+ * \code
+ * internal::cleanup_index_type<T>::type
+ * internal::cleanup_index_type<T,DynamicKey>::type
+ * \endcode
+ * where T can be a template instantiation of VariableAndFixedInt or numerous other integer-like representations.
+ * \c DynamicKey is either Dynamic (default) or DynamicIndex and used to identify true compile-time values.
+ *
+ * For convenience, you can also extract the compile-time value \c N using the following helper:
+ * \code
+ * internal::get_fixed_value<T,DefaultVal>::value
+ * \endcode
+ * that will give you \c N if T equals VariableAndFixedInt<N>, and \c DefaultVal if T does not embed any compile-time value (e.g., T==int).
+ *
+ * \sa fix<N>(int), class FixedInt
+ */
+template<int N> class VariableAndFixedInt
+{
+public:
+ static const int value = N;
+ operator int() const { return m_value; }
+ VariableAndFixedInt(int val) { m_value = val; }
+protected:
+ int m_value;
+};
+
+template<typename T, int Default=Dynamic> struct get_fixed_value {
+ static const int value = Default;
+};
+
+template<int N,int Default> struct get_fixed_value<FixedInt<N>,Default> {
+ static const int value = N;
+};
+
+#if !EIGEN_HAS_CXX14
+template<int N,int Default> struct get_fixed_value<FixedInt<N> (*)(),Default> {
+ static const int value = N;
+};
+#endif
+
+template<int N,int Default> struct get_fixed_value<VariableAndFixedInt<N>,Default> {
+ static const int value = N ;
+};
+
+template<typename T, int N, int Default>
+struct get_fixed_value<variable_if_dynamic<T,N>,Default> {
+ static const int value = N;
+};
+
+template<typename T> EIGEN_DEVICE_FUNC Index get_runtime_value(const T &x) { return x; }
+#if !EIGEN_HAS_CXX14
+template<int N> EIGEN_DEVICE_FUNC Index get_runtime_value(FixedInt<N> (*)()) { return N; }
+#endif
+
+// Cleanup integer/FixedInt/VariableAndFixedInt/etc types:
+
+// By default, no cleanup:
+template<typename T, int DynamicKey=Dynamic, typename EnableIf=void> struct cleanup_index_type { typedef T type; };
+
+// Convert any integral type (e.g., short, int, unsigned int, etc.) to Eigen::Index
+template<typename T, int DynamicKey> struct cleanup_index_type<T,DynamicKey,typename internal::enable_if<internal::is_integral<T>::value>::type> { typedef Index type; };
+
+#if !EIGEN_HAS_CXX14
+// In c++98/c++11, fix<N> is a pointer to function that we better cleanup to a true FixedInt<N>:
+template<int N, int DynamicKey> struct cleanup_index_type<FixedInt<N> (*)(), DynamicKey> { typedef FixedInt<N> type; };
+#endif
+
+// If VariableAndFixedInt does not match DynamicKey, then we turn it to a pure compile-time value:
+template<int N, int DynamicKey> struct cleanup_index_type<VariableAndFixedInt<N>, DynamicKey> { typedef FixedInt<N> type; };
+// If VariableAndFixedInt matches DynamicKey, then we turn it to a pure runtime-value (aka Index):
+template<int DynamicKey> struct cleanup_index_type<VariableAndFixedInt<DynamicKey>, DynamicKey> { typedef Index type; };
+
+#if EIGEN_HAS_CXX11
+template<int N, int DynamicKey> struct cleanup_index_type<std::integral_constant<int,N>, DynamicKey> { typedef FixedInt<N> type; };
+#endif
+
+} // end namespace internal
+
+#ifndef EIGEN_PARSED_BY_DOXYGEN
+
+#if EIGEN_HAS_CXX14_VARIABLE_TEMPLATES
+template<int N>
+static const internal::FixedInt<N> fix{};
+#else
+template<int N>
+inline internal::FixedInt<N> fix() { return internal::FixedInt<N>(); }
+
+// The generic typename T is mandatory. Otherwise, a code like fix<N> could refer to either the function above or this next overload.
+// This way a code like fix<N> can only refer to the previous function.
+template<int N,typename T>
+inline internal::VariableAndFixedInt<N> fix(T val) { return internal::VariableAndFixedInt<N>(internal::convert_index<int>(val)); }
+#endif
+
+#else // EIGEN_PARSED_BY_DOXYGEN
+
+/** \var fix<N>()
+ * \ingroup Core_Module
+ *
+ * This \em identifier permits to construct an object embedding a compile-time integer \c N.
+ *
+ * \tparam N the compile-time integer value
+ *
+ * It is typically used in conjunction with the Eigen::seq and Eigen::seqN functions to pass compile-time values to them:
+ * \code
+ * seqN(10,fix<4>,fix<-3>) // <=> [10 7 4 1]
+ * \endcode
+ *
+ * See also the function fix(int) to pass both a compile-time and runtime value.
+ *
+ * In c++14, it is implemented as:
+ * \code
+ * template<int N> static const internal::FixedInt<N> fix{};
+ * \endcode
+ * where internal::FixedInt<N> is an internal template class similar to
+ * <a href="http://en.cppreference.com/w/cpp/types/integral_constant">\c std::integral_constant </a><tt> <int,N> </tt>
+ * Here, \c fix<N> is thus an object of type \c internal::FixedInt<N>.
+ *
+ * In c++98/11, it is implemented as a function:
+ * \code
+ * template<int N> inline internal::FixedInt<N> fix();
+ * \endcode
+ * Here fix<N> is thus a pointer to function.
+ *
+ * If for some reason you want a true object in c++98 then you can write: \code fix<N>() \endcode which is also valid in c++14.
+ *
+ * \sa fix<N>(int), seq, seqN
+ */
+template<int N>
+static const auto fix();
+
+/** \fn fix<N>(int)
+ * \ingroup Core_Module
+ *
+ * This function returns an object embedding both a compile-time integer \c N, and a fallback runtime value \a val.
+ *
+ * \tparam N the compile-time integer value
+ * \param val the fallback runtime integer value
+ *
+ * This function is a more general version of the \ref fix identifier/function that can be used in template code
+ * where the compile-time value could turn out to actually mean "undefined at compile-time". For positive integers
+ * such as a size or a dimension, this case is identified by Eigen::Dynamic, whereas runtime signed integers
+ * (e.g., an increment/stride) are identified as Eigen::DynamicIndex. In such a case, the runtime value \a val
+ * will be used as a fallback.
+ *
+ * A typical use case would be:
+ * \code
+ * template<typename Derived> void foo(const MatrixBase<Derived> &mat) {
+ * const int N = Derived::RowsAtCompileTime==Dynamic ? Dynamic : Derived::RowsAtCompileTime/2;
+ * const int n = mat.rows()/2;
+ * ... mat( seqN(0,fix<N>(n) ) ...;
+ * }
+ * \endcode
+ * In this example, the function Eigen::seqN knows that the second argument is expected to be a size.
+ * If the passed compile-time value N equals Eigen::Dynamic, then the proxy object returned by fix will be dismissed, and converted to an Eigen::Index of value \c n.
+ * Otherwise, the runtime-value \c n will be dismissed, and the returned ArithmeticSequence will be of the exact same type as <tt> seqN(0,fix<N>) </tt>.
+ *
+ * \sa fix, seqN, class ArithmeticSequence
+ */
+template<int N>
+static const auto fix(int val);
+
+#endif // EIGEN_PARSED_BY_DOXYGEN
+
+} // end namespace Eigen
+
+#endif // EIGEN_INTEGRAL_CONSTANT_H
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/Macros.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/Macros.h
index b2bb0c2..986c3d4 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/Macros.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/Macros.h
@@ -11,19 +11,56 @@
#ifndef EIGEN_MACROS_H
#define EIGEN_MACROS_H
+//------------------------------------------------------------------------------------------
+// Eigen version and basic defaults
+//------------------------------------------------------------------------------------------
+
#define EIGEN_WORLD_VERSION 3
-#define EIGEN_MAJOR_VERSION 3
-#define EIGEN_MINOR_VERSION 7
+#define EIGEN_MAJOR_VERSION 4
+#define EIGEN_MINOR_VERSION 0
#define EIGEN_VERSION_AT_LEAST(x,y,z) (EIGEN_WORLD_VERSION>x || (EIGEN_WORLD_VERSION>=x && \
(EIGEN_MAJOR_VERSION>y || (EIGEN_MAJOR_VERSION>=y && \
EIGEN_MINOR_VERSION>=z))))
+#ifdef EIGEN_DEFAULT_TO_ROW_MAJOR
+#define EIGEN_DEFAULT_MATRIX_STORAGE_ORDER_OPTION Eigen::RowMajor
+#else
+#define EIGEN_DEFAULT_MATRIX_STORAGE_ORDER_OPTION Eigen::ColMajor
+#endif
+
+#ifndef EIGEN_DEFAULT_DENSE_INDEX_TYPE
+#define EIGEN_DEFAULT_DENSE_INDEX_TYPE std::ptrdiff_t
+#endif
+
+// Upperbound on the C++ version to use.
+// Expected values are 03, 11, 14, 17, etc.
+// By default, let's use an arbitrarily large C++ version.
+#ifndef EIGEN_MAX_CPP_VER
+#define EIGEN_MAX_CPP_VER 99
+#endif
+
+/** Allows disabling some optimizations which might affect the accuracy of the result.
+ * Such optimizations are enabled by default; set EIGEN_FAST_MATH to 0 to disable them.
+ * They currently include:
+ * - single precision ArrayBase::sin() and ArrayBase::cos() for SSE and AVX vectorization.
+ */
+#ifndef EIGEN_FAST_MATH
+#define EIGEN_FAST_MATH 1
+#endif
+
+#ifndef EIGEN_STACK_ALLOCATION_LIMIT
+// 131072 == 128 KB
+#define EIGEN_STACK_ALLOCATION_LIMIT 131072
+#endif
+
+//------------------------------------------------------------------------------------------
// Compiler identification, EIGEN_COMP_*
+//------------------------------------------------------------------------------------------
/// \internal EIGEN_COMP_GNUC set to 1 for all compilers compatible with GCC
#ifdef __GNUC__
- #define EIGEN_COMP_GNUC 1
+ #define EIGEN_COMP_GNUC (__GNUC__*10+__GNUC_MINOR__)
#else
#define EIGEN_COMP_GNUC 0
#endif
@@ -35,6 +72,12 @@
#define EIGEN_COMP_CLANG 0
#endif
+/// \internal EIGEN_COMP_CASTXML set to 1 if being preprocessed by CastXML
+#if defined(__castxml__)
+ #define EIGEN_COMP_CASTXML 1
+#else
+ #define EIGEN_COMP_CASTXML 0
+#endif
/// \internal EIGEN_COMP_LLVM set to 1 if the compiler backend is llvm
#if defined(__llvm__)
@@ -71,14 +114,44 @@
#define EIGEN_COMP_MSVC 0
#endif
+#if defined(__NVCC__)
+#if defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 9)
+ #define EIGEN_COMP_NVCC ((__CUDACC_VER_MAJOR__ * 10000) + (__CUDACC_VER_MINOR__ * 100))
+#elif defined(__CUDACC_VER__)
+ #define EIGEN_COMP_NVCC __CUDACC_VER__
+#else
+ #error "NVCC did not define compiler version."
+#endif
+#else
+ #define EIGEN_COMP_NVCC 0
+#endif
+
// For the record, here is a table summarizing the possible values for EIGEN_COMP_MSVC:
-// name ver MSC_VER
-// 2008 9 1500
-// 2010 10 1600
-// 2012 11 1700
-// 2013 12 1800
-// 2015 14 1900
-// "15" 15 1900
+// name ver MSC_VER
+// 2008 9 1500
+// 2010 10 1600
+// 2012 11 1700
+// 2013 12 1800
+// 2015 14 1900
+// "15" 15 1900
+// 2017-14.1 15.0 1910
+// 2017-14.11 15.3 1911
+// 2017-14.12 15.5 1912
+// 2017-14.13 15.6 1913
+// 2017-14.14 15.7 1914
+
+/// \internal EIGEN_COMP_MSVC_LANG set to _MSVC_LANG if the compiler is Microsoft Visual C++, 0 otherwise.
+#if defined(_MSVC_LANG)
+ #define EIGEN_COMP_MSVC_LANG _MSVC_LANG
+#else
+ #define EIGEN_COMP_MSVC_LANG 0
+#endif
+
+// For the record, here is a table summarizing the possible values for EIGEN_COMP_MSVC_LANG:
+// MSVC option Standard MSVC_LANG
+// /std:c++14 (default as of VS 2019) C++14 201402L
+// /std:c++17 C++17 201703L
+// /std:c++latest >C++17 >201703L
/// \internal EIGEN_COMP_MSVC_STRICT set to 1 if the compiler is really Microsoft Visual C++ and not ,e.g., ICC or clang-cl
#if EIGEN_COMP_MSVC && !(EIGEN_COMP_ICC || EIGEN_COMP_LLVM || EIGEN_COMP_CLANG)
@@ -87,16 +160,21 @@
#define EIGEN_COMP_MSVC_STRICT 0
#endif
-/// \internal EIGEN_COMP_IBM set to 1 if the compiler is IBM XL C++
-#if defined(__IBMCPP__) || defined(__xlc__)
- #define EIGEN_COMP_IBM 1
+/// \internal EIGEN_COMP_IBM set to xlc version if the compiler is IBM XL C++
+// XLC version
+// 3.1 0x0301
+// 4.5 0x0405
+// 5.0 0x0500
+// 12.1 0x0C01
+#if defined(__IBMCPP__) || defined(__xlc__) || defined(__ibmxl__)
+ #define EIGEN_COMP_IBM __xlC__
#else
#define EIGEN_COMP_IBM 0
#endif
-/// \internal EIGEN_COMP_PGI set to 1 if the compiler is Portland Group Compiler
+/// \internal EIGEN_COMP_PGI set to PGI version if the compiler is Portland Group Compiler
#if defined(__PGI)
- #define EIGEN_COMP_PGI 1
+ #define EIGEN_COMP_PGI (__PGIC__*100+__PGIC_MINOR__)
#else
#define EIGEN_COMP_PGI 0
#endif
@@ -108,7 +186,7 @@
#define EIGEN_COMP_ARM 0
#endif
-/// \internal EIGEN_COMP_ARM set to 1 if the compiler is ARM Compiler
+/// \internal EIGEN_COMP_EMSCRIPTEN set to 1 if the compiler is Emscripten Compiler
#if defined(__EMSCRIPTEN__)
#define EIGEN_COMP_EMSCRIPTEN 1
#else
@@ -142,9 +220,13 @@
#endif
-// Architecture identification, EIGEN_ARCH_*
-#if defined(__x86_64__) || defined(_M_X64) || defined(__amd64)
+//------------------------------------------------------------------------------------------
+// Architecture identification, EIGEN_ARCH_*
+//------------------------------------------------------------------------------------------
+
+
+#if defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC)) || defined(__amd64)
#define EIGEN_ARCH_x86_64 1
#else
#define EIGEN_ARCH_x86_64 0
@@ -170,18 +252,61 @@
#endif
/// \internal EIGEN_ARCH_ARM64 set to 1 if the architecture is ARM64
-#if defined(__aarch64__)
+#if defined(__aarch64__) || defined(_M_ARM64) || defined(_M_ARM64EC)
#define EIGEN_ARCH_ARM64 1
#else
#define EIGEN_ARCH_ARM64 0
#endif
+/// \internal EIGEN_ARCH_ARM_OR_ARM64 set to 1 if the architecture is ARM or ARM64
#if EIGEN_ARCH_ARM || EIGEN_ARCH_ARM64
#define EIGEN_ARCH_ARM_OR_ARM64 1
#else
#define EIGEN_ARCH_ARM_OR_ARM64 0
#endif
+/// \internal EIGEN_ARCH_ARMV8 set to 1 if the architecture is armv8 or greater.
+#if EIGEN_ARCH_ARM_OR_ARM64 && defined(__ARM_ARCH) && __ARM_ARCH >= 8
+#define EIGEN_ARCH_ARMV8 1
+#else
+#define EIGEN_ARCH_ARMV8 0
+#endif
+
+
+/// \internal EIGEN_HAS_ARM64_FP16 set to 1 if the architecture provides an IEEE
+/// compliant Arm fp16 type
+#if EIGEN_ARCH_ARM64
+ #ifndef EIGEN_HAS_ARM64_FP16
+ #if defined(__ARM_FP16_FORMAT_IEEE)
+ #define EIGEN_HAS_ARM64_FP16 1
+ #else
+ #define EIGEN_HAS_ARM64_FP16 0
+ #endif
+ #endif
+#endif
+
+/// \internal EIGEN_HAS_ARM64_FP16_VECTOR_ARITHMETIC set to 1 if the architecture
+/// supports Neon vector intrinsics for fp16.
+#if EIGEN_ARCH_ARM64
+ #ifndef EIGEN_HAS_ARM64_FP16_VECTOR_ARITHMETIC
+ #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+ #define EIGEN_HAS_ARM64_FP16_VECTOR_ARITHMETIC 1
+ #else
+ #define EIGEN_HAS_ARM64_FP16_VECTOR_ARITHMETIC 0
+ #endif
+ #endif
+#endif
+
+/// \internal EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC set to 1 if the architecture
+/// supports Neon scalar intrinsics for fp16.
+#if EIGEN_ARCH_ARM64
+ #ifndef EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC
+ #if defined(__ARM_FEATURE_FP16_SCALAR_ARITHMETIC)
+ #define EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC 1
+ #endif
+ #endif
+#endif
+
/// \internal EIGEN_ARCH_MIPS set to 1 if the architecture is MIPS
#if defined(__mips__) || defined(__mips)
#define EIGEN_ARCH_MIPS 1
@@ -212,7 +337,9 @@
+//------------------------------------------------------------------------------------------
// Operating system identification, EIGEN_OS_*
+//------------------------------------------------------------------------------------------
/// \internal EIGEN_OS_UNIX set to 1 if the OS is a unix variant
#if defined(__unix__) || defined(__unix)
@@ -299,9 +426,17 @@
#define EIGEN_OS_WIN_STRICT 0
#endif
-/// \internal EIGEN_OS_SUN set to 1 if the OS is SUN
+/// \internal EIGEN_OS_SUN set to __SUNPRO_C if the OS is SUN
+// compiler solaris __SUNPRO_C
+// version studio
+// 5.7 10 0x570
+// 5.8 11 0x580
+// 5.9 12 0x590
+// 5.10 12.1 0x5100
+// 5.11 12.2 0x5110
+// 5.12 12.3 0x5120
#if (defined(sun) || defined(__sun)) && !(defined(__SVR4) || defined(__svr4__))
- #define EIGEN_OS_SUN 1
+ #define EIGEN_OS_SUN __SUNPRO_C
#else
#define EIGEN_OS_SUN 0
#endif
@@ -314,6 +449,131 @@
#endif
+//------------------------------------------------------------------------------------------
+// Detect GPU compilers and architectures
+//------------------------------------------------------------------------------------------
+
+// NVCC is not supported as the target platform for HIPCC
+// Note that this also makes EIGEN_CUDACC and EIGEN_HIPCC mutually exclusive
+#if defined(__NVCC__) && defined(__HIPCC__)
+ #error "NVCC as the target platform for HIPCC is currently not supported."
+#endif
+
+#if defined(__CUDACC__) && !defined(EIGEN_NO_CUDA)
+ // Means the compiler is either nvcc or clang with CUDA enabled
+ #define EIGEN_CUDACC __CUDACC__
+#endif
+
+#if defined(__CUDA_ARCH__) && !defined(EIGEN_NO_CUDA)
+ // Means we are generating code for the device
+ #define EIGEN_CUDA_ARCH __CUDA_ARCH__
+#endif
+
+#if defined(EIGEN_CUDACC)
+#include <cuda.h>
+ #define EIGEN_CUDA_SDK_VER (CUDA_VERSION * 10)
+#else
+ #define EIGEN_CUDA_SDK_VER 0
+#endif
+
+#if defined(__HIPCC__) && !defined(EIGEN_NO_HIP)
+ // Means the compiler is HIPCC (analogous to EIGEN_CUDACC, but for HIP)
+ #define EIGEN_HIPCC __HIPCC__
+
+ // We need to include hip_runtime.h here because it pulls in
+ // ++ hip_common.h which contains the define for __HIP_DEVICE_COMPILE__
+ // ++ host_defines.h which contains the defines for the __host__ and __device__ macros
+ #include <hip/hip_runtime.h>
+
+ #if defined(__HIP_DEVICE_COMPILE__)
+ // analogous to EIGEN_CUDA_ARCH, but for HIP
+ #define EIGEN_HIP_DEVICE_COMPILE __HIP_DEVICE_COMPILE__
+ #endif
+
+ // For HIP (ROCm 3.5 and higher), we need to explicitly set the launch_bounds attribute
+ // value to 1024. The compiler assigns a default value of 256 when the attribute is not
+ // specified. This results in failures on the HIP platform, for cases when a GPU kernel
+ // without an explicit launch_bounds attribute is called with a threads_per_block value
+ // greater than 256.
+ //
+ // This is a regression in functionality and is expected to be fixed within the next
+ // couple of ROCm releases (compiler will go back to using 1024 as the default value)
+ //
+ // In the meantime, we will use an "only enabled for HIP" macro to set the launch_bounds
+ // attribute.
+
+ #define EIGEN_HIP_LAUNCH_BOUNDS_1024 __launch_bounds__(1024)
+
+#endif
+
+#if !defined(EIGEN_HIP_LAUNCH_BOUNDS_1024)
+#define EIGEN_HIP_LAUNCH_BOUNDS_1024
+#endif // !defined(EIGEN_HIP_LAUNCH_BOUNDS_1024)
+
+// Unify CUDA/HIPCC
+
+#if defined(EIGEN_CUDACC) || defined(EIGEN_HIPCC)
+//
+// If either EIGEN_CUDACC or EIGEN_HIPCC is defined, then define EIGEN_GPUCC
+//
+#define EIGEN_GPUCC
+//
+// EIGEN_HIPCC implies the HIP compiler and is used to tweak Eigen code for use in HIP kernels
+// EIGEN_CUDACC implies the CUDA compiler and is used to tweak Eigen code for use in CUDA kernels
+//
+// In most cases the same tweaks are required to the Eigen code to enable in both the HIP and CUDA kernels.
+// For those cases, the corresponding code should be guarded with
+// #if defined(EIGEN_GPUCC)
+// instead of
+// #if defined(EIGEN_CUDACC) || defined(EIGEN_HIPCC)
+//
+// For cases where the tweak is specific to HIP, the code should be guarded with
+// #if defined(EIGEN_HIPCC)
+//
+// For cases where the tweak is specific to CUDA, the code should be guarded with
+// #if defined(EIGEN_CUDACC)
+//
+#endif
+
+#if defined(EIGEN_CUDA_ARCH) || defined(EIGEN_HIP_DEVICE_COMPILE)
+//
+// If either EIGEN_CUDA_ARCH or EIGEN_HIP_DEVICE_COMPILE is defined, then define EIGEN_GPU_COMPILE_PHASE
+//
+#define EIGEN_GPU_COMPILE_PHASE
+//
+// GPU compilers (HIPCC, NVCC) typically do two passes over the source code,
+// + one to compile the source for the "host" (i.e. CPU)
+// + another to compile the source for the "device" (i.e. GPU)
+//
+// Code that needs to be enabled only during either the "host" or "device" compilation phase
+// needs to be guarded with a macro that indicates the current compilation phase
+//
+// EIGEN_HIP_DEVICE_COMPILE implies the device compilation phase in HIP
+// EIGEN_CUDA_ARCH implies the device compilation phase in CUDA
+//
+// In most cases, the "host" / "device" specific code is the same for both HIP and CUDA
+// For those cases, the code should be guarded with
+// #if defined(EIGEN_GPU_COMPILE_PHASE)
+// instead of
+// #if defined(EIGEN_CUDA_ARCH) || defined(EIGEN_HIP_DEVICE_COMPILE)
+//
+// For cases where the tweak is specific to HIP, the code should be guarded with
+// #if defined(EIGEN_HIP_DEVICE_COMPILE)
+//
+// For cases where the tweak is specific to CUDA, the code should be guarded with
+// #if defined(EIGEN_CUDA_ARCH)
+//
+#endif
+
+#if defined(EIGEN_USE_SYCL) && defined(__SYCL_DEVICE_ONLY__)
+// EIGEN_USE_SYCL is a user-defined macro while __SYCL_DEVICE_ONLY__ is a compiler-defined macro.
+// In most cases we want to check if both macros are defined which can be done using the define below.
+#define SYCL_DEVICE_ONLY
+#endif
+
+//------------------------------------------------------------------------------------------
+// Detect Compiler/Architecture/OS specific features
+//------------------------------------------------------------------------------------------
#if EIGEN_GNUC_AT_MOST(4,3) && !EIGEN_COMP_CLANG
// see bug 89
@@ -322,20 +582,6 @@
#define EIGEN_SAFE_TO_USE_STANDARD_ASSERT_MACRO 1
#endif
-// This macro can be used to prevent from macro expansion, e.g.:
-// std::max EIGEN_NOT_A_MACRO(a,b)
-#define EIGEN_NOT_A_MACRO
-
-#ifdef EIGEN_DEFAULT_TO_ROW_MAJOR
-#define EIGEN_DEFAULT_MATRIX_STORAGE_ORDER_OPTION Eigen::RowMajor
-#else
-#define EIGEN_DEFAULT_MATRIX_STORAGE_ORDER_OPTION Eigen::ColMajor
-#endif
-
-#ifndef EIGEN_DEFAULT_DENSE_INDEX_TYPE
-#define EIGEN_DEFAULT_DENSE_INDEX_TYPE std::ptrdiff_t
-#endif
-
// Cross compiler wrapper around LLVM's __has_builtin
#ifdef __has_builtin
# define EIGEN_HAS_BUILTIN(x) __has_builtin(x)
@@ -349,26 +595,79 @@
# define __has_feature(x) 0
#endif
-// Upperbound on the C++ version to use.
-// Expected values are 03, 11, 14, 17, etc.
-// By default, let's use an arbitrarily large C++ version.
-#ifndef EIGEN_MAX_CPP_VER
-#define EIGEN_MAX_CPP_VER 99
+// Some old compilers do not support template specializations like:
+// template<typename T,int N> void foo(const T x[N]);
+#if !( EIGEN_COMP_CLANG && ( (EIGEN_COMP_CLANG<309) \
+ || (defined(__apple_build_version__) && (__apple_build_version__ < 9000000))) \
+ || EIGEN_COMP_GNUC_STRICT && EIGEN_COMP_GNUC<49)
+#define EIGEN_HAS_STATIC_ARRAY_TEMPLATE 1
+#else
+#define EIGEN_HAS_STATIC_ARRAY_TEMPLATE 0
#endif
-#if EIGEN_MAX_CPP_VER>=11 && (defined(__cplusplus) && (__cplusplus >= 201103L) || EIGEN_COMP_MSVC >= 1900)
+// The macro EIGEN_CPLUSPLUS is a replacement for __cplusplus/_MSVC_LANG that
+// works for both platforms, indicating the C++ standard version number.
+//
+// With MSVC, without defining /Zc:__cplusplus, the __cplusplus macro will
+// report 199711L regardless of the language standard specified via /std.
+// We need to rely on _MSVC_LANG instead, which is only available after
+// VS2015.3.
+#if EIGEN_COMP_MSVC_LANG > 0
+#define EIGEN_CPLUSPLUS EIGEN_COMP_MSVC_LANG
+#elif EIGEN_COMP_MSVC >= 1900
+#define EIGEN_CPLUSPLUS 201103L
+#elif defined(__cplusplus)
+#define EIGEN_CPLUSPLUS __cplusplus
+#else
+#define EIGEN_CPLUSPLUS 0
+#endif
+
+// The macro EIGEN_COMP_CXXVER defines the C++ version expected by the compiler.
+// For instance, if compiling with gcc and -std=c++17, then EIGEN_COMP_CXXVER
+// is defined to 17.
+#if EIGEN_CPLUSPLUS > 201703L
+ #define EIGEN_COMP_CXXVER 20
+#elif EIGEN_CPLUSPLUS > 201402L
+ #define EIGEN_COMP_CXXVER 17
+#elif EIGEN_CPLUSPLUS > 201103L
+ #define EIGEN_COMP_CXXVER 14
+#elif EIGEN_CPLUSPLUS >= 201103L
+ #define EIGEN_COMP_CXXVER 11
+#else
+ #define EIGEN_COMP_CXXVER 03
+#endif
+
+#ifndef EIGEN_HAS_CXX14_VARIABLE_TEMPLATES
+ #if defined(__cpp_variable_templates) && __cpp_variable_templates >= 201304 && EIGEN_MAX_CPP_VER>=14
+ #define EIGEN_HAS_CXX14_VARIABLE_TEMPLATES 1
+ #else
+ #define EIGEN_HAS_CXX14_VARIABLE_TEMPLATES 0
+ #endif
+#endif
+
+
+// The macros EIGEN_HAS_CXX?? define a rough estimate of available C++ features
+// but in practice we should not rely on them but rather on the availability of
+// individual features as defined later.
+// This is why there is no EIGEN_HAS_CXX17.
+// FIXME: get rid of EIGEN_HAS_CXX14 and maybe even EIGEN_HAS_CXX11.
+#if EIGEN_MAX_CPP_VER>=11 && EIGEN_COMP_CXXVER>=11
#define EIGEN_HAS_CXX11 1
#else
#define EIGEN_HAS_CXX11 0
#endif
+#if EIGEN_MAX_CPP_VER>=14 && EIGEN_COMP_CXXVER>=14
+#define EIGEN_HAS_CXX14 1
+#else
+#define EIGEN_HAS_CXX14 0
+#endif
// Do we support r-value references?
#ifndef EIGEN_HAS_RVALUE_REFERENCES
#if EIGEN_MAX_CPP_VER>=11 && \
(__has_feature(cxx_rvalue_references) || \
- (defined(__cplusplus) && __cplusplus >= 201103L) || \
- (EIGEN_COMP_MSVC >= 1600))
+ (EIGEN_COMP_CXXVER >= 11) || (EIGEN_COMP_MSVC >= 1600))
#define EIGEN_HAS_RVALUE_REFERENCES 1
#else
#define EIGEN_HAS_RVALUE_REFERENCES 0
@@ -376,11 +675,14 @@
#endif
// Does the compiler support C99?
+// Need to include <cmath> to make sure _GLIBCXX_USE_C99 gets defined
+#include <cmath>
#ifndef EIGEN_HAS_C99_MATH
#if EIGEN_MAX_CPP_VER>=11 && \
((defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901)) \
|| (defined(__GNUC__) && defined(_GLIBCXX_USE_C99)) \
- || (defined(_LIBCPP_VERSION) && !defined(_MSC_VER)))
+ || (defined(_LIBCPP_VERSION) && !defined(_MSC_VER)) \
+ || (EIGEN_COMP_MSVC >= 1900) || defined(SYCL_DEVICE_ONLY))
#define EIGEN_HAS_C99_MATH 1
#else
#define EIGEN_HAS_C99_MATH 0
@@ -388,21 +690,73 @@
#endif
// Does the compiler support result_of?
+// result_of was deprecated in c++17 and removed in c++ 20
#ifndef EIGEN_HAS_STD_RESULT_OF
-#if EIGEN_MAX_CPP_VER>=11 && ((__has_feature(cxx_lambdas) || (defined(__cplusplus) && __cplusplus >= 201103L)))
+#if EIGEN_HAS_CXX11 && EIGEN_COMP_CXXVER < 17
#define EIGEN_HAS_STD_RESULT_OF 1
#else
#define EIGEN_HAS_STD_RESULT_OF 0
#endif
#endif
+// Does the compiler support std::hash?
+#ifndef EIGEN_HAS_STD_HASH
+// The std::hash struct is defined in C++11 but is not labelled as a __device__
+// function and is not constexpr, so cannot be used on device.
+#if EIGEN_HAS_CXX11 && !defined(EIGEN_GPU_COMPILE_PHASE)
+#define EIGEN_HAS_STD_HASH 1
+#else
+#define EIGEN_HAS_STD_HASH 0
+#endif
+#endif // EIGEN_HAS_STD_HASH
+
+#ifndef EIGEN_HAS_STD_INVOKE_RESULT
+#if EIGEN_MAX_CPP_VER >= 17 && EIGEN_COMP_CXXVER >= 17
+#define EIGEN_HAS_STD_INVOKE_RESULT 1
+#else
+#define EIGEN_HAS_STD_INVOKE_RESULT 0
+#endif
+#endif
+
+#ifndef EIGEN_HAS_ALIGNAS
+#if EIGEN_MAX_CPP_VER>=11 && EIGEN_HAS_CXX11 && \
+ ( __has_feature(cxx_alignas) \
+ || EIGEN_HAS_CXX14 \
+ || (EIGEN_COMP_MSVC >= 1800) \
+ || (EIGEN_GNUC_AT_LEAST(4,8)) \
+ || (EIGEN_COMP_CLANG>=305) \
+ || (EIGEN_COMP_ICC>=1500) \
+ || (EIGEN_COMP_PGI>=1500) \
+ || (EIGEN_COMP_SUNCC>=0x5130))
+#define EIGEN_HAS_ALIGNAS 1
+#else
+#define EIGEN_HAS_ALIGNAS 0
+#endif
+#endif
+
+// Does the compiler support type_traits?
+// - full support of type traits was added only to GCC 5.1.0.
+// - 20150626 corresponds to the last release of 4.x libstdc++
+#ifndef EIGEN_HAS_TYPE_TRAITS
+#if EIGEN_MAX_CPP_VER>=11 && (EIGEN_HAS_CXX11 || EIGEN_COMP_MSVC >= 1700) \
+ && ((!EIGEN_COMP_GNUC_STRICT) || EIGEN_GNUC_AT_LEAST(5, 1)) \
+ && ((!defined(__GLIBCXX__)) || __GLIBCXX__ > 20150626)
+#define EIGEN_HAS_TYPE_TRAITS 1
+#define EIGEN_INCLUDE_TYPE_TRAITS
+#else
+#define EIGEN_HAS_TYPE_TRAITS 0
+#endif
+#endif
+
// Does the compiler support variadic templates?
#ifndef EIGEN_HAS_VARIADIC_TEMPLATES
-#if EIGEN_MAX_CPP_VER>=11 && (__cplusplus > 199711L || EIGEN_COMP_MSVC >= 1900) \
- && (!defined(__NVCC__) || !EIGEN_ARCH_ARM_OR_ARM64 || (EIGEN_CUDACC_VER >= 80000) )
+#if EIGEN_MAX_CPP_VER>=11 && (EIGEN_COMP_CXXVER >= 11) \
+ && (!defined(__NVCC__) || !EIGEN_ARCH_ARM_OR_ARM64 || (EIGEN_COMP_NVCC >= 80000) )
// ^^ Disable the use of variadic templates when compiling with versions of nvcc older than 8.0 on ARM devices:
// this prevents nvcc from crashing when compiling Eigen on Tegra X1
#define EIGEN_HAS_VARIADIC_TEMPLATES 1
+#elif EIGEN_MAX_CPP_VER>=11 && (EIGEN_COMP_CXXVER >= 11) && defined(SYCL_DEVICE_ONLY)
+#define EIGEN_HAS_VARIADIC_TEMPLATES 1
#else
#define EIGEN_HAS_VARIADIC_TEMPLATES 0
#endif
@@ -410,27 +764,33 @@
// Does the compiler fully support const expressions? (as in c++14)
#ifndef EIGEN_HAS_CONSTEXPR
+ #if defined(EIGEN_CUDACC)
+ // Const expressions are supported provided that c++11 is enabled and we're using either clang or nvcc 7.5 or above
+ #if EIGEN_MAX_CPP_VER>=14 && (EIGEN_COMP_CXXVER >= 11 && (EIGEN_COMP_CLANG || EIGEN_COMP_NVCC >= 70500))
+ #define EIGEN_HAS_CONSTEXPR 1
+ #endif
+ #elif EIGEN_MAX_CPP_VER>=14 && (__has_feature(cxx_relaxed_constexpr) || (EIGEN_COMP_CXXVER >= 14) || \
+ (EIGEN_GNUC_AT_LEAST(4,8) && (EIGEN_COMP_CXXVER >= 11)) || \
+ (EIGEN_COMP_CLANG >= 306 && (EIGEN_COMP_CXXVER >= 11)))
+ #define EIGEN_HAS_CONSTEXPR 1
+ #endif
-#ifdef __CUDACC__
-// Const expressions are supported provided that c++11 is enabled and we're using either clang or nvcc 7.5 or above
-#if EIGEN_MAX_CPP_VER>=14 && (__cplusplus > 199711L && (EIGEN_COMP_CLANG || EIGEN_CUDACC_VER >= 70500))
- #define EIGEN_HAS_CONSTEXPR 1
-#endif
-#elif EIGEN_MAX_CPP_VER>=14 && (__has_feature(cxx_relaxed_constexpr) || (defined(__cplusplus) && __cplusplus >= 201402L) || \
- (EIGEN_GNUC_AT_LEAST(4,8) && (__cplusplus > 199711L)))
-#define EIGEN_HAS_CONSTEXPR 1
-#endif
+ #ifndef EIGEN_HAS_CONSTEXPR
+ #define EIGEN_HAS_CONSTEXPR 0
+ #endif
-#ifndef EIGEN_HAS_CONSTEXPR
-#define EIGEN_HAS_CONSTEXPR 0
-#endif
+#endif // EIGEN_HAS_CONSTEXPR
+#if EIGEN_HAS_CONSTEXPR
+#define EIGEN_CONSTEXPR constexpr
+#else
+#define EIGEN_CONSTEXPR
#endif
// Does the compiler support C++11 math?
// Let's be conservative and enable the default C++11 implementation only if we are sure it exists
#ifndef EIGEN_HAS_CXX11_MATH
- #if EIGEN_MAX_CPP_VER>=11 && ((__cplusplus > 201103L) || (__cplusplus >= 201103L) && (EIGEN_COMP_GNUC_STRICT || EIGEN_COMP_CLANG || EIGEN_COMP_MSVC || EIGEN_COMP_ICC) \
+ #if EIGEN_MAX_CPP_VER>=11 && ((EIGEN_COMP_CXXVER > 11) || (EIGEN_COMP_CXXVER == 11) && (EIGEN_COMP_GNUC_STRICT || EIGEN_COMP_CLANG || EIGEN_COMP_MSVC || EIGEN_COMP_ICC) \
&& (EIGEN_ARCH_i386_OR_x86_64) && (EIGEN_OS_GNULINUX || EIGEN_OS_WIN_STRICT || EIGEN_OS_MAC))
#define EIGEN_HAS_CXX11_MATH 1
#else
@@ -441,9 +801,8 @@
// Does the compiler support proper C++11 containers?
#ifndef EIGEN_HAS_CXX11_CONTAINERS
#if EIGEN_MAX_CPP_VER>=11 && \
- ((__cplusplus > 201103L) \
- || ((__cplusplus >= 201103L) && (EIGEN_COMP_GNUC_STRICT || EIGEN_COMP_CLANG || EIGEN_COMP_ICC>=1400)) \
- || EIGEN_COMP_MSVC >= 1900)
+ ((EIGEN_COMP_CXXVER > 11) \
+ || ((EIGEN_COMP_CXXVER == 11) && (EIGEN_COMP_GNUC_STRICT || EIGEN_COMP_CLANG || EIGEN_COMP_MSVC || EIGEN_COMP_ICC>=1400)))
#define EIGEN_HAS_CXX11_CONTAINERS 1
#else
#define EIGEN_HAS_CXX11_CONTAINERS 0
@@ -454,24 +813,88 @@
#ifndef EIGEN_HAS_CXX11_NOEXCEPT
#if EIGEN_MAX_CPP_VER>=11 && \
(__has_feature(cxx_noexcept) \
- || (__cplusplus > 201103L) \
- || ((__cplusplus >= 201103L) && (EIGEN_COMP_GNUC_STRICT || EIGEN_COMP_CLANG || EIGEN_COMP_ICC>=1400)) \
- || EIGEN_COMP_MSVC >= 1900)
+ || (EIGEN_COMP_CXXVER > 11) \
+ || ((EIGEN_COMP_CXXVER == 11) && (EIGEN_COMP_GNUC_STRICT || EIGEN_COMP_CLANG || EIGEN_COMP_MSVC || EIGEN_COMP_ICC>=1400)))
#define EIGEN_HAS_CXX11_NOEXCEPT 1
#else
#define EIGEN_HAS_CXX11_NOEXCEPT 0
#endif
#endif
-/** Allows to disable some optimizations which might affect the accuracy of the result.
- * Such optimization are enabled by default, and set EIGEN_FAST_MATH to 0 to disable them.
- * They currently include:
- * - single precision ArrayBase::sin() and ArrayBase::cos() for SSE and AVX vectorization.
- */
-#ifndef EIGEN_FAST_MATH
-#define EIGEN_FAST_MATH 1
+#ifndef EIGEN_HAS_CXX11_ATOMIC
+ #if EIGEN_MAX_CPP_VER>=11 && \
+ (__has_feature(cxx_atomic) \
+ || (EIGEN_COMP_CXXVER > 11) \
+ || ((EIGEN_COMP_CXXVER == 11) && (EIGEN_COMP_MSVC==0 || EIGEN_COMP_MSVC >= 1700)))
+ #define EIGEN_HAS_CXX11_ATOMIC 1
+ #else
+ #define EIGEN_HAS_CXX11_ATOMIC 0
+ #endif
#endif
+#ifndef EIGEN_HAS_CXX11_OVERRIDE_FINAL
+ #if EIGEN_MAX_CPP_VER>=11 && \
+ (EIGEN_COMP_CXXVER >= 11 || EIGEN_COMP_MSVC >= 1700)
+ #define EIGEN_HAS_CXX11_OVERRIDE_FINAL 1
+ #else
+ #define EIGEN_HAS_CXX11_OVERRIDE_FINAL 0
+ #endif
+#endif
+
+// NOTE: the required Apple's clang version is very conservative
+// and it could be that Xcode 9 works just fine.
+// NOTE: the MSVC version is based on https://en.cppreference.com/w/cpp/compiler_support
+// and not tested.
+#ifndef EIGEN_HAS_CXX17_OVERALIGN
+#if EIGEN_MAX_CPP_VER>=17 && EIGEN_COMP_CXXVER>=17 && ( \
+ (EIGEN_COMP_MSVC >= 1912) \
+ || (EIGEN_GNUC_AT_LEAST(7,0)) \
+ || ((!defined(__apple_build_version__)) && (EIGEN_COMP_CLANG>=500)) \
+ || (( defined(__apple_build_version__)) && (__apple_build_version__>=10000000)) \
+ )
+#define EIGEN_HAS_CXX17_OVERALIGN 1
+#else
+#define EIGEN_HAS_CXX17_OVERALIGN 0
+#endif
+#endif
+
+#if defined(EIGEN_CUDACC) && EIGEN_HAS_CONSTEXPR
+ // While available already with c++11, this is useful mostly starting with c++14 and relaxed constexpr rules
+ #if defined(__NVCC__)
+ // nvcc considers constexpr functions as __host__ __device__ with the option --expt-relaxed-constexpr
+ #ifdef __CUDACC_RELAXED_CONSTEXPR__
+ #define EIGEN_CONSTEXPR_ARE_DEVICE_FUNC
+ #endif
+ #elif defined(__clang__) && defined(__CUDA__) && __has_feature(cxx_relaxed_constexpr)
+ // clang++ always considers constexpr functions as implicitly __host__ __device__
+ #define EIGEN_CONSTEXPR_ARE_DEVICE_FUNC
+ #endif
+#endif
+
+// Does the compiler support the __int128 and __uint128_t extensions for 128-bit
+// integer arithmetic?
+//
+// Clang and GCC define __SIZEOF_INT128__ when these extensions are supported,
+// but we avoid using them in certain cases:
+//
+// * Building using Clang for Windows, where the Clang runtime library has
+// 128-bit support only on LP64 architectures, but Windows is LLP64.
+#ifndef EIGEN_HAS_BUILTIN_INT128
+#if defined(__SIZEOF_INT128__) && !(EIGEN_OS_WIN && EIGEN_COMP_CLANG)
+#define EIGEN_HAS_BUILTIN_INT128 1
+#else
+#define EIGEN_HAS_BUILTIN_INT128 0
+#endif
+#endif
+
+//------------------------------------------------------------------------------------------
+// Preprocessor programming helpers
+//------------------------------------------------------------------------------------------
+
+// This macro can be used to prevent from macro expansion, e.g.:
+// std::max EIGEN_NOT_A_MACRO(a,b)
+#define EIGEN_NOT_A_MACRO
+
#define EIGEN_DEBUG_VAR(x) std::cerr << #x << " = " << x << std::endl;
// concatenate two tokens
@@ -488,7 +911,7 @@
// but it still doesn't use GCC's always_inline. This is useful in (common) situations where MSVC needs forceinline
// but GCC is still doing fine with just inline.
#ifndef EIGEN_STRONG_INLINE
-#if EIGEN_COMP_MSVC || EIGEN_COMP_ICC
+#if (EIGEN_COMP_MSVC || EIGEN_COMP_ICC) && !defined(EIGEN_GPUCC)
#define EIGEN_STRONG_INLINE __forceinline
#else
#define EIGEN_STRONG_INLINE inline
@@ -503,7 +926,7 @@
// Eval.h:91: sorry, unimplemented: inlining failed in call to 'const Eigen::Eval<Derived> Eigen::MatrixBase<Scalar, Derived>::eval() const'
// : function body not available
// See also bug 1367
-#if EIGEN_GNUC_AT_LEAST(4,2)
+#if EIGEN_GNUC_AT_LEAST(4,2) && !defined(SYCL_DEVICE_ONLY)
#define EIGEN_ALWAYS_INLINE __attribute__((always_inline)) inline
#else
#define EIGEN_ALWAYS_INLINE EIGEN_STRONG_INLINE
@@ -523,12 +946,43 @@
#define EIGEN_PERMISSIVE_EXPR
#endif
+// GPU stuff
+
+// Disable some features when compiling with GPU compilers (NVCC/clang-cuda/SYCL/HIPCC)
+#if defined(EIGEN_CUDACC) || defined(SYCL_DEVICE_ONLY) || defined(EIGEN_HIPCC)
+ // Do not try asserts on device code
+ #ifndef EIGEN_NO_DEBUG
+ #define EIGEN_NO_DEBUG
+ #endif
+
+ #ifdef EIGEN_INTERNAL_DEBUGGING
+ #undef EIGEN_INTERNAL_DEBUGGING
+ #endif
+
+ #ifdef EIGEN_EXCEPTIONS
+ #undef EIGEN_EXCEPTIONS
+ #endif
+#endif
+
+#if defined(SYCL_DEVICE_ONLY)
+ #ifndef EIGEN_DONT_VECTORIZE
+ #define EIGEN_DONT_VECTORIZE
+ #endif
+ #define EIGEN_DEVICE_FUNC __attribute__((flatten)) __attribute__((always_inline))
+// All functions callable from CUDA/HIP code must be qualified with __device__
+#elif defined(EIGEN_GPUCC)
+ #define EIGEN_DEVICE_FUNC __host__ __device__
+#else
+ #define EIGEN_DEVICE_FUNC
+#endif
+
+
// this macro allows to get rid of linking errors about multiply defined functions.
// - static is not very good because it prevents definitions from different object files to be merged.
// So static causes the resulting linked executable to be bloated with multiple copies of the same function.
// - inline is not perfect either as it unwantedly hints the compiler toward inlining the function.
-#define EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
-#define EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS inline
+#define EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_DEVICE_FUNC
+#define EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_DEVICE_FUNC inline
#ifdef NDEBUG
# ifndef EIGEN_NO_DEBUG
@@ -538,7 +992,11 @@
// eigen_plain_assert is where we implement the workaround for the assert() bug in GCC <= 4.3, see bug 89
#ifdef EIGEN_NO_DEBUG
- #define eigen_plain_assert(x)
+ #ifdef SYCL_DEVICE_ONLY // used to silence the warning on SYCL device
+ #define eigen_plain_assert(x) EIGEN_UNUSED_VARIABLE(x)
+ #else
+ #define eigen_plain_assert(x)
+ #endif
#else
#if EIGEN_SAFE_TO_USE_STANDARD_ASSERT_MACRO
namespace Eigen {
@@ -612,7 +1070,7 @@
// Suppresses 'unused variable' warnings.
namespace Eigen {
namespace internal {
- template<typename T> EIGEN_DEVICE_FUNC void ignore_unused_variable(const T&) {}
+ template<typename T> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ignore_unused_variable(const T&) {}
}
}
#define EIGEN_UNUSED_VARIABLE(var) Eigen::internal::ignore_unused_variable(var);
@@ -626,169 +1084,75 @@
#endif
-//------------------------------------------------------------------------------------------
-// Static and dynamic alignment control
+// Acts as a barrier preventing operations involving `X` from crossing. This
+// occurs, for example, in the fast rounding trick where a magic constant is
+// added then subtracted, which is otherwise compiled away with -ffast-math.
//
-// The main purpose of this section is to define EIGEN_MAX_ALIGN_BYTES and EIGEN_MAX_STATIC_ALIGN_BYTES
-// as the maximal boundary in bytes on which dynamically and statically allocated data may be alignment respectively.
-// The values of EIGEN_MAX_ALIGN_BYTES and EIGEN_MAX_STATIC_ALIGN_BYTES can be specified by the user. If not,
-// a default value is automatically computed based on architecture, compiler, and OS.
-//
-// This section also defines macros EIGEN_ALIGN_TO_BOUNDARY(N) and the shortcuts EIGEN_ALIGN{8,16,32,_MAX}
-// to be used to declare statically aligned buffers.
-//------------------------------------------------------------------------------------------
-
-
-/* EIGEN_ALIGN_TO_BOUNDARY(n) forces data to be n-byte aligned. This is used to satisfy SIMD requirements.
- * However, we do that EVEN if vectorization (EIGEN_VECTORIZE) is disabled,
- * so that vectorization doesn't affect binary compatibility.
- *
- * If we made alignment depend on whether or not EIGEN_VECTORIZE is defined, it would be impossible to link
- * vectorized and non-vectorized code.
- */
-#if (defined __CUDACC__)
- #define EIGEN_ALIGN_TO_BOUNDARY(n) __align__(n)
-#elif EIGEN_COMP_GNUC || EIGEN_COMP_PGI || EIGEN_COMP_IBM || EIGEN_COMP_ARM
- #define EIGEN_ALIGN_TO_BOUNDARY(n) __attribute__((aligned(n)))
-#elif EIGEN_COMP_MSVC
- #define EIGEN_ALIGN_TO_BOUNDARY(n) __declspec(align(n))
-#elif EIGEN_COMP_SUNCC
- // FIXME not sure about this one:
- #define EIGEN_ALIGN_TO_BOUNDARY(n) __attribute__((aligned(n)))
-#else
- #error Please tell me what is the equivalent of __attribute__((aligned(n))) for your compiler
-#endif
-
-// If the user explicitly disable vectorization, then we also disable alignment
-#if defined(EIGEN_DONT_VECTORIZE)
- #define EIGEN_IDEAL_MAX_ALIGN_BYTES 0
-#elif defined(EIGEN_VECTORIZE_AVX512)
- // 64 bytes static alignmeent is preferred only if really required
- #define EIGEN_IDEAL_MAX_ALIGN_BYTES 64
-#elif defined(__AVX__)
- // 32 bytes static alignmeent is preferred only if really required
- #define EIGEN_IDEAL_MAX_ALIGN_BYTES 32
-#else
- #define EIGEN_IDEAL_MAX_ALIGN_BYTES 16
-#endif
-
-
-// EIGEN_MIN_ALIGN_BYTES defines the minimal value for which the notion of explicit alignment makes sense
-#define EIGEN_MIN_ALIGN_BYTES 16
-
-// Defined the boundary (in bytes) on which the data needs to be aligned. Note
-// that unless EIGEN_ALIGN is defined and not equal to 0, the data may not be
-// aligned at all regardless of the value of this #define.
-
-#if (defined(EIGEN_DONT_ALIGN_STATICALLY) || defined(EIGEN_DONT_ALIGN)) && defined(EIGEN_MAX_STATIC_ALIGN_BYTES) && EIGEN_MAX_STATIC_ALIGN_BYTES>0
-#error EIGEN_MAX_STATIC_ALIGN_BYTES and EIGEN_DONT_ALIGN[_STATICALLY] are both defined with EIGEN_MAX_STATIC_ALIGN_BYTES!=0. Use EIGEN_MAX_STATIC_ALIGN_BYTES=0 as a synonym of EIGEN_DONT_ALIGN_STATICALLY.
-#endif
-
-// EIGEN_DONT_ALIGN_STATICALLY and EIGEN_DONT_ALIGN are deprectated
-// They imply EIGEN_MAX_STATIC_ALIGN_BYTES=0
-#if defined(EIGEN_DONT_ALIGN_STATICALLY) || defined(EIGEN_DONT_ALIGN)
- #ifdef EIGEN_MAX_STATIC_ALIGN_BYTES
- #undef EIGEN_MAX_STATIC_ALIGN_BYTES
- #endif
- #define EIGEN_MAX_STATIC_ALIGN_BYTES 0
-#endif
-
-#ifndef EIGEN_MAX_STATIC_ALIGN_BYTES
-
- // Try to automatically guess what is the best default value for EIGEN_MAX_STATIC_ALIGN_BYTES
-
- // 16 byte alignment is only useful for vectorization. Since it affects the ABI, we need to enable
- // 16 byte alignment on all platforms where vectorization might be enabled. In theory we could always
- // enable alignment, but it can be a cause of problems on some platforms, so we just disable it in
- // certain common platform (compiler+architecture combinations) to avoid these problems.
- // Only static alignment is really problematic (relies on nonstandard compiler extensions),
- // try to keep heap alignment even when we have to disable static alignment.
- #if EIGEN_COMP_GNUC && !(EIGEN_ARCH_i386_OR_x86_64 || EIGEN_ARCH_ARM_OR_ARM64 || EIGEN_ARCH_PPC || EIGEN_ARCH_IA64)
- #define EIGEN_GCC_AND_ARCH_DOESNT_WANT_STACK_ALIGNMENT 1
- #elif EIGEN_ARCH_ARM_OR_ARM64 && EIGEN_COMP_GNUC_STRICT && EIGEN_GNUC_AT_MOST(4, 6)
- // Old versions of GCC on ARM, at least 4.4, were once seen to have buggy static alignment support.
- // Not sure which version fixed it, hopefully it doesn't affect 4.7, which is still somewhat in use.
- // 4.8 and newer seem definitely unaffected.
- #define EIGEN_GCC_AND_ARCH_DOESNT_WANT_STACK_ALIGNMENT 1
+// See bug 1674
+#if !defined(EIGEN_OPTIMIZATION_BARRIER)
+ #if EIGEN_COMP_GNUC
+ // According to https://gcc.gnu.org/onlinedocs/gcc/Constraints.html:
+ // X: Any operand whatsoever.
+ // r: A register operand is allowed provided that it is in a general
+ // register.
+ // g: Any register, memory or immediate integer operand is allowed, except
+ // for registers that are not general registers.
+ // w: (AArch32/AArch64) Floating point register, Advanced SIMD vector
+ // register or SVE vector register.
+ // x: (SSE) Any SSE register.
+ // (AArch64) Like w, but restricted to registers 0 to 15 inclusive.
+ // v: (PowerPC) An Altivec vector register.
+ // wa:(PowerPC) A VSX register.
+ //
+ // "X" (uppercase) should work for all cases, though this seems to fail for
+ // some versions of GCC for arm/aarch64 with
+ // "error: inconsistent operand constraints in an 'asm'"
+ // Clang x86_64/arm/aarch64 seems to require "g" to support both scalars and
+ // vectors, otherwise
+ // "error: non-trivial scalar-to-vector conversion, possible invalid
+ // constraint for vector type"
+ //
+ // GCC for ppc64le generates an internal compiler error with x/X/g.
+ // GCC for AVX generates an internal compiler error with X.
+ //
+ // Tested on icc/gcc/clang for sse, avx, avx2, avx512dq
+ // gcc for arm, aarch64,
+ // gcc for ppc64le,
+ // both vectors and scalars.
+ //
+ // Note that this is restricted to plain types - this will not work
+ // directly for std::complex<T>, Eigen::half, Eigen::bfloat16. For these,
+ // you will need to apply to the underlying POD type.
+ #if EIGEN_ARCH_PPC && EIGEN_COMP_GNUC_STRICT
+ // This seems to be broken on clang. Packet4f is loaded into a single
+ // register rather than a vector, zeroing out some entries. Integer
+ // types also generate a compile error.
+ // General, Altivec, VSX.
+ #define EIGEN_OPTIMIZATION_BARRIER(X) __asm__ ("" : "+r,v,wa" (X));
+ #elif EIGEN_ARCH_ARM_OR_ARM64
+ // General, NEON.
+ #define EIGEN_OPTIMIZATION_BARRIER(X) __asm__ ("" : "+g,w" (X));
+ #elif EIGEN_ARCH_i386_OR_x86_64
+ // General, SSE.
+ #define EIGEN_OPTIMIZATION_BARRIER(X) __asm__ ("" : "+g,x" (X));
+ #else
+ // Not implemented for other architectures.
+ #define EIGEN_OPTIMIZATION_BARRIER(X)
+ #endif
#else
- #define EIGEN_GCC_AND_ARCH_DOESNT_WANT_STACK_ALIGNMENT 0
+ // Not implemented for other compilers.
+ #define EIGEN_OPTIMIZATION_BARRIER(X)
#endif
-
- // static alignment is completely disabled with GCC 3, Sun Studio, and QCC/QNX
- #if !EIGEN_GCC_AND_ARCH_DOESNT_WANT_STACK_ALIGNMENT \
- && !EIGEN_GCC3_OR_OLDER \
- && !EIGEN_COMP_SUNCC \
- && !EIGEN_OS_QNX
- #define EIGEN_ARCH_WANTS_STACK_ALIGNMENT 1
- #else
- #define EIGEN_ARCH_WANTS_STACK_ALIGNMENT 0
- #endif
-
- #if EIGEN_ARCH_WANTS_STACK_ALIGNMENT
- #define EIGEN_MAX_STATIC_ALIGN_BYTES EIGEN_IDEAL_MAX_ALIGN_BYTES
- #else
- #define EIGEN_MAX_STATIC_ALIGN_BYTES 0
- #endif
-
#endif
-// If EIGEN_MAX_ALIGN_BYTES is defined, then it is considered as an upper bound for EIGEN_MAX_ALIGN_BYTES
-#if defined(EIGEN_MAX_ALIGN_BYTES) && EIGEN_MAX_ALIGN_BYTES<EIGEN_MAX_STATIC_ALIGN_BYTES
-#undef EIGEN_MAX_STATIC_ALIGN_BYTES
-#define EIGEN_MAX_STATIC_ALIGN_BYTES EIGEN_MAX_ALIGN_BYTES
-#endif
-
-#if EIGEN_MAX_STATIC_ALIGN_BYTES==0 && !defined(EIGEN_DISABLE_UNALIGNED_ARRAY_ASSERT)
- #define EIGEN_DISABLE_UNALIGNED_ARRAY_ASSERT
-#endif
-
-// At this stage, EIGEN_MAX_STATIC_ALIGN_BYTES>0 is the true test whether we want to align arrays on the stack or not.
-// It takes into account both the user choice to explicitly enable/disable alignment (by settting EIGEN_MAX_STATIC_ALIGN_BYTES)
-// and the architecture config (EIGEN_ARCH_WANTS_STACK_ALIGNMENT).
-// Henceforth, only EIGEN_MAX_STATIC_ALIGN_BYTES should be used.
-
-
-// Shortcuts to EIGEN_ALIGN_TO_BOUNDARY
-#define EIGEN_ALIGN8 EIGEN_ALIGN_TO_BOUNDARY(8)
-#define EIGEN_ALIGN16 EIGEN_ALIGN_TO_BOUNDARY(16)
-#define EIGEN_ALIGN32 EIGEN_ALIGN_TO_BOUNDARY(32)
-#define EIGEN_ALIGN64 EIGEN_ALIGN_TO_BOUNDARY(64)
-#if EIGEN_MAX_STATIC_ALIGN_BYTES>0
-#define EIGEN_ALIGN_MAX EIGEN_ALIGN_TO_BOUNDARY(EIGEN_MAX_STATIC_ALIGN_BYTES)
+#if EIGEN_COMP_MSVC
+ // NOTE MSVC often gives C4127 warnings with compiletime if statements. See bug 1362.
+ // This workaround is ugly, but it does the job.
+# define EIGEN_CONST_CONDITIONAL(cond) (void)0, cond
#else
-#define EIGEN_ALIGN_MAX
+# define EIGEN_CONST_CONDITIONAL(cond) cond
#endif
-
-// Dynamic alignment control
-
-#if defined(EIGEN_DONT_ALIGN) && defined(EIGEN_MAX_ALIGN_BYTES) && EIGEN_MAX_ALIGN_BYTES>0
-#error EIGEN_MAX_ALIGN_BYTES and EIGEN_DONT_ALIGN are both defined with EIGEN_MAX_ALIGN_BYTES!=0. Use EIGEN_MAX_ALIGN_BYTES=0 as a synonym of EIGEN_DONT_ALIGN.
-#endif
-
-#ifdef EIGEN_DONT_ALIGN
- #ifdef EIGEN_MAX_ALIGN_BYTES
- #undef EIGEN_MAX_ALIGN_BYTES
- #endif
- #define EIGEN_MAX_ALIGN_BYTES 0
-#elif !defined(EIGEN_MAX_ALIGN_BYTES)
- #define EIGEN_MAX_ALIGN_BYTES EIGEN_IDEAL_MAX_ALIGN_BYTES
-#endif
-
-#if EIGEN_IDEAL_MAX_ALIGN_BYTES > EIGEN_MAX_ALIGN_BYTES
-#define EIGEN_DEFAULT_ALIGN_BYTES EIGEN_IDEAL_MAX_ALIGN_BYTES
-#else
-#define EIGEN_DEFAULT_ALIGN_BYTES EIGEN_MAX_ALIGN_BYTES
-#endif
-
-
-#ifndef EIGEN_UNALIGNED_VECTORIZE
-#define EIGEN_UNALIGNED_VECTORIZE 1
-#endif
-
-//----------------------------------------------------------------------
-
-
#ifdef EIGEN_DONT_USE_RESTRICT_KEYWORD
#define EIGEN_RESTRICT
#endif
@@ -796,10 +1160,6 @@
#define EIGEN_RESTRICT __restrict
#endif
-#ifndef EIGEN_STACK_ALLOCATION_LIMIT
-// 131072 == 128 KB
-#define EIGEN_STACK_ALLOCATION_LIMIT 131072
-#endif
#ifndef EIGEN_DEFAULT_IO_FORMAT
#ifdef EIGEN_MAKING_DOCS
@@ -814,8 +1174,23 @@
// just an empty macro !
#define EIGEN_EMPTY
-#if EIGEN_COMP_MSVC_STRICT && (EIGEN_COMP_MSVC < 1900 || EIGEN_CUDACC_VER>0)
- // for older MSVC versions, as well as 1900 && CUDA 8, using the base operator is sufficient (cf Bugs 1000, 1324)
+
+// When compiling CUDA/HIP device code with NVCC or HIPCC
+// pull in math functions from the global namespace.
+// In host mode, and when device code is compiled with clang,
+// use the std versions.
+#if (defined(EIGEN_CUDA_ARCH) && defined(__NVCC__)) || defined(EIGEN_HIP_DEVICE_COMPILE)
+ #define EIGEN_USING_STD(FUNC) using ::FUNC;
+#else
+ #define EIGEN_USING_STD(FUNC) using std::FUNC;
+#endif
+
+#if EIGEN_COMP_MSVC_STRICT && (EIGEN_COMP_MSVC < 1900 || (EIGEN_COMP_MSVC == 1900 && EIGEN_COMP_NVCC))
+ // For older MSVC versions, as well as 1900 && CUDA 8, using the base operator is necessary,
+ // otherwise we get duplicate definition errors
+ // For later MSVC versions, we require explicit operator= definition, otherwise we get
+ // use of implicitly deleted operator errors.
+ // (cf Bugs 920, 1000, 1324, 2291)
#define EIGEN_INHERIT_ASSIGNMENT_EQUAL_OPERATOR(Derived) \
using Base::operator =;
#elif EIGEN_COMP_CLANG // workaround clang bug (see http://forum.kde.org/viewtopic.php?f=74&t=102653)
@@ -841,12 +1216,13 @@
* This is necessary, because the implicit definition is deprecated if the copy-assignment is overridden.
*/
#if EIGEN_HAS_CXX11
-#define EIGEN_DEFAULT_COPY_CONSTRUCTOR(CLASS) EIGEN_DEVICE_FUNC CLASS(const CLASS&) = default;
+#define EIGEN_DEFAULT_COPY_CONSTRUCTOR(CLASS) CLASS(const CLASS&) = default;
#else
#define EIGEN_DEFAULT_COPY_CONSTRUCTOR(CLASS)
#endif
+
/** \internal
* \brief Macro to manually inherit assignment operators.
* This is necessary, because the implicitly defined assignment operator gets deleted when a custom operator= is defined.
@@ -865,15 +1241,18 @@
*/
#if EIGEN_HAS_CXX11
#define EIGEN_DEFAULT_EMPTY_CONSTRUCTOR_AND_DESTRUCTOR(Derived) \
- EIGEN_DEVICE_FUNC Derived() = default; \
- EIGEN_DEVICE_FUNC ~Derived() = default;
+ Derived() = default; \
+ ~Derived() = default;
#else
#define EIGEN_DEFAULT_EMPTY_CONSTRUCTOR_AND_DESTRUCTOR(Derived) \
- EIGEN_DEVICE_FUNC Derived() {}; \
- /* EIGEN_DEVICE_FUNC ~Derived() {}; */
+ Derived() {}; \
+ /* ~Derived() {}; */
#endif
+
+
+
/**
* Just a side note. Commenting within defines works only by documenting
* behind the object (via '!<'). Comments cannot be multi-line and thus
@@ -889,7 +1268,8 @@
typedef typename Eigen::internal::ref_selector<Derived>::type Nested; \
typedef typename Eigen::internal::traits<Derived>::StorageKind StorageKind; \
typedef typename Eigen::internal::traits<Derived>::StorageIndex StorageIndex; \
- enum { RowsAtCompileTime = Eigen::internal::traits<Derived>::RowsAtCompileTime, \
+ enum CompileTimeTraits \
+ { RowsAtCompileTime = Eigen::internal::traits<Derived>::RowsAtCompileTime, \
ColsAtCompileTime = Eigen::internal::traits<Derived>::ColsAtCompileTime, \
Flags = Eigen::internal::traits<Derived>::Flags, \
SizeAtCompileTime = Base::SizeAtCompileTime, \
@@ -934,6 +1314,14 @@
#define EIGEN_IMPLIES(a,b) (!(a) || (b))
+#if EIGEN_HAS_BUILTIN(__builtin_expect) || EIGEN_COMP_GNUC
+#define EIGEN_PREDICT_FALSE(x) (__builtin_expect(x, false))
+#define EIGEN_PREDICT_TRUE(x) (__builtin_expect(false || (x), true))
+#else
+#define EIGEN_PREDICT_FALSE(x) (x)
+#define EIGEN_PREDICT_TRUE(x) (x)
+#endif
+
// the expression type of a standard coefficient wise binary operation
#define EIGEN_CWISE_BINARY_RETURN_TYPE(LHS,RHS,OPNAME) \
CwiseBinaryOp< \
@@ -965,14 +1353,14 @@
const typename internal::plain_constant_type<EXPR,SCALAR>::type, const EXPR>
// Workaround for MSVC 2010 (see ML thread "patch with compile for for MSVC 2010")
-#if EIGEN_COMP_MSVC_STRICT<=1600
+#if EIGEN_COMP_MSVC_STRICT && (EIGEN_COMP_MSVC_STRICT<=1600)
#define EIGEN_MSVC10_WORKAROUND_BINARYOP_RETURN_TYPE(X) typename internal::enable_if<true,X>::type
#else
#define EIGEN_MSVC10_WORKAROUND_BINARYOP_RETURN_TYPE(X) X
#endif
#define EIGEN_MAKE_SCALAR_BINARY_OP_ONTHERIGHT(METHOD,OPNAME) \
- template <typename T> EIGEN_DEVICE_FUNC inline \
+ template <typename T> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \
EIGEN_MSVC10_WORKAROUND_BINARYOP_RETURN_TYPE(const EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(Derived,typename internal::promote_scalar_arg<Scalar EIGEN_COMMA T EIGEN_COMMA EIGEN_SCALAR_BINARY_SUPPORTED(OPNAME,Scalar,T)>::type,OPNAME))\
(METHOD)(const T& scalar) const { \
typedef typename internal::promote_scalar_arg<Scalar,T,EIGEN_SCALAR_BINARY_SUPPORTED(OPNAME,Scalar,T)>::type PromotedT; \
@@ -981,7 +1369,7 @@
}
#define EIGEN_MAKE_SCALAR_BINARY_OP_ONTHELEFT(METHOD,OPNAME) \
- template <typename T> EIGEN_DEVICE_FUNC inline friend \
+ template <typename T> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE friend \
EIGEN_MSVC10_WORKAROUND_BINARYOP_RETURN_TYPE(const EIGEN_SCALAR_BINARYOP_EXPR_RETURN_TYPE(typename internal::promote_scalar_arg<Scalar EIGEN_COMMA T EIGEN_COMMA EIGEN_SCALAR_BINARY_SUPPORTED(OPNAME,T,Scalar)>::type,Derived,OPNAME)) \
(METHOD)(const T& scalar, const StorageBaseType& matrix) { \
typedef typename internal::promote_scalar_arg<Scalar,T,EIGEN_SCALAR_BINARY_SUPPORTED(OPNAME,T,Scalar)>::type PromotedT; \
@@ -994,15 +1382,23 @@
EIGEN_MAKE_SCALAR_BINARY_OP_ONTHERIGHT(METHOD,OPNAME)
+#if (defined(_CPPUNWIND) || defined(__EXCEPTIONS)) && !defined(EIGEN_CUDA_ARCH) && !defined(EIGEN_EXCEPTIONS) && !defined(EIGEN_USE_SYCL) && !defined(EIGEN_HIP_DEVICE_COMPILE)
+ #define EIGEN_EXCEPTIONS
+#endif
+
+
#ifdef EIGEN_EXCEPTIONS
# define EIGEN_THROW_X(X) throw X
# define EIGEN_THROW throw
# define EIGEN_TRY try
# define EIGEN_CATCH(X) catch (X)
#else
-# ifdef __CUDA_ARCH__
+# if defined(EIGEN_CUDA_ARCH)
# define EIGEN_THROW_X(X) asm("trap;")
# define EIGEN_THROW asm("trap;")
+# elif defined(EIGEN_HIP_DEVICE_COMPILE)
+# define EIGEN_THROW_X(X) asm("s_trap 0")
+# define EIGEN_THROW asm("s_trap 0")
# else
# define EIGEN_THROW_X(X) std::abort()
# define EIGEN_THROW std::abort()
@@ -1022,13 +1418,47 @@
# define EIGEN_NOEXCEPT
# define EIGEN_NOEXCEPT_IF(x)
# define EIGEN_NO_THROW throw()
-# if EIGEN_COMP_MSVC
+# if EIGEN_COMP_MSVC || EIGEN_COMP_CXXVER>=17
// MSVC does not support exception specifications (warning C4290),
- // and they are deprecated in c++11 anyway.
+ // and they are deprecated in c++11 anyway. This is even an error in c++17.
# define EIGEN_EXCEPTION_SPEC(X) throw()
# else
# define EIGEN_EXCEPTION_SPEC(X) throw(X)
# endif
#endif
+#if EIGEN_HAS_VARIADIC_TEMPLATES
+// The all function is used to enable a variadic version of eigen_assert which can take a parameter pack as its input.
+namespace Eigen {
+namespace internal {
+
+inline bool all(){ return true; }
+
+template<typename T, typename ...Ts>
+bool all(T t, Ts ... ts){ return t && all(ts...); }
+
+}
+}
+#endif
+
+#if EIGEN_HAS_CXX11_OVERRIDE_FINAL
+// provide override and final specifiers if they are available:
+# define EIGEN_OVERRIDE override
+# define EIGEN_FINAL final
+#else
+# define EIGEN_OVERRIDE
+# define EIGEN_FINAL
+#endif
+
+// Wrapping #pragma unroll in a macro since it is required for SYCL
+#if defined(SYCL_DEVICE_ONLY)
+ #if defined(_MSC_VER)
+ #define EIGEN_UNROLL_LOOP __pragma(unroll)
+ #else
+ #define EIGEN_UNROLL_LOOP _Pragma("unroll")
+ #endif
+#else
+ #define EIGEN_UNROLL_LOOP
+#endif
+
#endif // EIGEN_MACROS_H
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/Memory.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/Memory.h
index 291383c..875318c 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/Memory.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/Memory.h
@@ -63,14 +63,28 @@
namespace internal {
-EIGEN_DEVICE_FUNC
+EIGEN_DEVICE_FUNC
inline void throw_std_bad_alloc()
{
#ifdef EIGEN_EXCEPTIONS
throw std::bad_alloc();
#else
std::size_t huge = static_cast<std::size_t>(-1);
- ::operator new(huge);
+ #if defined(EIGEN_HIPCC)
+ //
+ // calls to "::operator new" are to be treated as opaque function calls (i.e no inlining),
+ // and as a consequence the code in the #else block triggers the hipcc warning :
+ // "no overloaded function has restriction specifiers that are compatible with the ambient context"
+ //
+ // "throw_std_bad_alloc" has the EIGEN_DEVICE_FUNC attribute, so it seems that hipcc expects
+ // the same on "operator new"
+ // Reverting code back to the old version in this #if block for the hipcc compiler
+ //
+ new int[huge];
+ #else
+ void* unused = ::operator new(huge);
+ EIGEN_UNUSED_VARIABLE(unused);
+ #endif
#endif
}
@@ -83,19 +97,26 @@
/** \internal Like malloc, but the returned pointer is guaranteed to be 16-byte aligned.
* Fast, but wastes 16 additional bytes of memory. Does not throw any exception.
*/
-inline void* handmade_aligned_malloc(std::size_t size)
+EIGEN_DEVICE_FUNC inline void* handmade_aligned_malloc(std::size_t size, std::size_t alignment = EIGEN_DEFAULT_ALIGN_BYTES)
{
- void *original = std::malloc(size+EIGEN_DEFAULT_ALIGN_BYTES);
+ eigen_assert(alignment >= sizeof(void*) && (alignment & (alignment-1)) == 0 && "Alignment must be at least sizeof(void*) and a power of 2");
+
+ EIGEN_USING_STD(malloc)
+ void *original = malloc(size+alignment);
+
if (original == 0) return 0;
- void *aligned = reinterpret_cast<void*>((reinterpret_cast<std::size_t>(original) & ~(std::size_t(EIGEN_DEFAULT_ALIGN_BYTES-1))) + EIGEN_DEFAULT_ALIGN_BYTES);
+ void *aligned = reinterpret_cast<void*>((reinterpret_cast<std::size_t>(original) & ~(std::size_t(alignment-1))) + alignment);
*(reinterpret_cast<void**>(aligned) - 1) = original;
return aligned;
}
/** \internal Frees memory allocated with handmade_aligned_malloc */
-inline void handmade_aligned_free(void *ptr)
+EIGEN_DEVICE_FUNC inline void handmade_aligned_free(void *ptr)
{
- if (ptr) std::free(*(reinterpret_cast<void**>(ptr) - 1));
+ if (ptr) {
+ EIGEN_USING_STD(free)
+ free(*(reinterpret_cast<void**>(ptr) - 1));
+ }
}
/** \internal
@@ -114,7 +135,7 @@
void *previous_aligned = static_cast<char *>(original)+previous_offset;
if(aligned!=previous_aligned)
std::memmove(aligned, previous_aligned, size);
-
+
*(reinterpret_cast<void**>(aligned) - 1) = original;
return aligned;
}
@@ -142,7 +163,7 @@
{
eigen_assert(is_malloc_allowed() && "heap allocation is forbidden (EIGEN_RUNTIME_NO_MALLOC is defined and g_is_malloc_allowed is false)");
}
-#else
+#else
EIGEN_DEVICE_FUNC inline void check_that_malloc_is_allowed()
{}
#endif
@@ -156,9 +177,12 @@
void *result;
#if (EIGEN_DEFAULT_ALIGN_BYTES==0) || EIGEN_MALLOC_ALREADY_ALIGNED
- result = std::malloc(size);
+
+ EIGEN_USING_STD(malloc)
+ result = malloc(size);
+
#if EIGEN_DEFAULT_ALIGN_BYTES==16
- eigen_assert((size<16 || (std::size_t(result)%16)==0) && "System's malloc returned an unaligned pointer. Compile with EIGEN_MALLOC_ALREADY_ALIGNED=0 to fallback to handmade alignd memory allocator.");
+ eigen_assert((size<16 || (std::size_t(result)%16)==0) && "System's malloc returned an unaligned pointer. Compile with EIGEN_MALLOC_ALREADY_ALIGNED=0 to fallback to handmade aligned memory allocator.");
#endif
#else
result = handmade_aligned_malloc(size);
@@ -174,7 +198,10 @@
EIGEN_DEVICE_FUNC inline void aligned_free(void *ptr)
{
#if (EIGEN_DEFAULT_ALIGN_BYTES==0) || EIGEN_MALLOC_ALREADY_ALIGNED
- std::free(ptr);
+
+ EIGEN_USING_STD(free)
+ free(ptr);
+
#else
handmade_aligned_free(ptr);
#endif
@@ -187,7 +214,7 @@
*/
inline void* aligned_realloc(void *ptr, std::size_t new_size, std::size_t old_size)
{
- EIGEN_UNUSED_VARIABLE(old_size);
+ EIGEN_UNUSED_VARIABLE(old_size)
void *result;
#if (EIGEN_DEFAULT_ALIGN_BYTES==0) || EIGEN_MALLOC_ALREADY_ALIGNED
@@ -218,7 +245,9 @@
{
check_that_malloc_is_allowed();
- void *result = std::malloc(size);
+ EIGEN_USING_STD(malloc)
+ void *result = malloc(size);
+
if(!result && size)
throw_std_bad_alloc();
return result;
@@ -232,7 +261,8 @@
template<> EIGEN_DEVICE_FUNC inline void conditional_aligned_free<false>(void *ptr)
{
- std::free(ptr);
+ EIGEN_USING_STD(free)
+ free(ptr);
}
template<bool Align> inline void* conditional_aligned_realloc(void* ptr, std::size_t new_size, std::size_t old_size)
@@ -331,7 +361,7 @@
template<typename T> EIGEN_DEVICE_FUNC inline void aligned_delete(T *ptr, std::size_t size)
{
destruct_elements_of_array<T>(ptr, size);
- aligned_free(ptr);
+ Eigen::internal::aligned_free(ptr);
}
/** \internal Deletes objects constructed with conditional_aligned_new
@@ -471,8 +501,8 @@
}
/** \internal Returns the smallest integer multiple of \a base and greater or equal to \a size
- */
-template<typename Index>
+ */
+template<typename Index>
inline Index first_multiple(Index size, Index base)
{
return ((size+base-1)/base)*base;
@@ -493,7 +523,8 @@
IntPtr size = IntPtr(end)-IntPtr(start);
if(size==0) return;
eigen_internal_assert(start!=0 && end!=0 && target!=0);
- std::memcpy(target, start, size);
+ EIGEN_USING_STD(memcpy)
+ memcpy(target, start, size);
}
};
@@ -502,7 +533,7 @@
{ std::copy(start, end, target); }
};
-// intelligent memmove. falls back to std::memmove for POD types, uses std::copy otherwise.
+// intelligent memmove. falls back to std::memmove for POD types, uses std::copy otherwise.
template<typename T, bool UseMemmove> struct smart_memmove_helper;
template<typename T> void smart_memmove(const T* start, const T* end, T* target)
@@ -522,19 +553,30 @@
template<typename T> struct smart_memmove_helper<T,false> {
static inline void run(const T* start, const T* end, T* target)
- {
+ {
if (UIntPtr(target) < UIntPtr(start))
{
std::copy(start, end, target);
}
- else
+ else
{
std::ptrdiff_t count = (std::ptrdiff_t(end)-std::ptrdiff_t(start)) / sizeof(T);
- std::copy_backward(start, end, target + count);
+ std::copy_backward(start, end, target + count);
}
}
};
+#if EIGEN_HAS_RVALUE_REFERENCES
+template<typename T> EIGEN_DEVICE_FUNC T* smart_move(T* start, T* end, T* target)
+{
+ return std::move(start, end, target);
+}
+#else
+template<typename T> EIGEN_DEVICE_FUNC T* smart_move(T* start, T* end, T* target)
+{
+ return std::copy(start, end, target);
+}
+#endif
/*****************************************************************************
*** Implementation of runtime stack allocation (falling back to malloc) ***
@@ -542,7 +584,7 @@
// you can overwrite Eigen's default behavior regarding alloca by defining EIGEN_ALLOCA
// to the appropriate stack allocation function
-#ifndef EIGEN_ALLOCA
+#if ! defined EIGEN_ALLOCA && ! defined EIGEN_GPU_COMPILE_PHASE
#if EIGEN_OS_LINUX || EIGEN_OS_MAC || (defined alloca)
#define EIGEN_ALLOCA alloca
#elif EIGEN_COMP_MSVC
@@ -550,6 +592,15 @@
#endif
#endif
+// With clang -Oz -mthumb, alloca changes the stack pointer in a way that is
+// not allowed in Thumb2. -DEIGEN_STACK_ALLOCATION_LIMIT=0 doesn't work because
+// the compiler still emits bad code because stack allocation checks use "<=".
+// TODO: Eliminate after https://bugs.llvm.org/show_bug.cgi?id=23772
+// is fixed.
+#if defined(__clang__) && defined(__thumb__)
+ #undef EIGEN_ALLOCA
+#endif
+
// This helper class construct the allocated memory, and takes care of destructing and freeing the handled data
// at destruction time. In practice this helper class is mainly useful to avoid memory leak in case of exceptions.
template<typename T> class aligned_stack_memory_handler : noncopyable
@@ -561,12 +612,14 @@
* In this case, the buffer elements will also be destructed when this handler will be destructed.
* Finally, if \a dealloc is true, then the pointer \a ptr is freed.
**/
+ EIGEN_DEVICE_FUNC
aligned_stack_memory_handler(T* ptr, std::size_t size, bool dealloc)
: m_ptr(ptr), m_size(size), m_deallocate(dealloc)
{
if(NumTraits<T>::RequireInitialization && m_ptr)
Eigen::internal::construct_elements_of_array(m_ptr, size);
}
+ EIGEN_DEVICE_FUNC
~aligned_stack_memory_handler()
{
if(NumTraits<T>::RequireInitialization && m_ptr)
@@ -580,6 +633,60 @@
bool m_deallocate;
};
+#ifdef EIGEN_ALLOCA
+
+template<typename Xpr, int NbEvaluations,
+ bool MapExternalBuffer = nested_eval<Xpr,NbEvaluations>::Evaluate && Xpr::MaxSizeAtCompileTime==Dynamic
+ >
+struct local_nested_eval_wrapper
+{
+ static const bool NeedExternalBuffer = false;
+ typedef typename Xpr::Scalar Scalar;
+ typedef typename nested_eval<Xpr,NbEvaluations>::type ObjectType;
+ ObjectType object;
+
+ EIGEN_DEVICE_FUNC
+ local_nested_eval_wrapper(const Xpr& xpr, Scalar* ptr) : object(xpr)
+ {
+ EIGEN_UNUSED_VARIABLE(ptr);
+ eigen_internal_assert(ptr==0);
+ }
+};
+
+template<typename Xpr, int NbEvaluations>
+struct local_nested_eval_wrapper<Xpr,NbEvaluations,true>
+{
+ static const bool NeedExternalBuffer = true;
+ typedef typename Xpr::Scalar Scalar;
+ typedef typename plain_object_eval<Xpr>::type PlainObject;
+ typedef Map<PlainObject,EIGEN_DEFAULT_ALIGN_BYTES> ObjectType;
+ ObjectType object;
+
+ EIGEN_DEVICE_FUNC
+ local_nested_eval_wrapper(const Xpr& xpr, Scalar* ptr)
+ : object(ptr==0 ? reinterpret_cast<Scalar*>(Eigen::internal::aligned_malloc(sizeof(Scalar)*xpr.size())) : ptr, xpr.rows(), xpr.cols()),
+ m_deallocate(ptr==0)
+ {
+ if(NumTraits<Scalar>::RequireInitialization && object.data())
+ Eigen::internal::construct_elements_of_array(object.data(), object.size());
+ object = xpr;
+ }
+
+ EIGEN_DEVICE_FUNC
+ ~local_nested_eval_wrapper()
+ {
+ if(NumTraits<Scalar>::RequireInitialization && object.data())
+ Eigen::internal::destruct_elements_of_array(object.data(), object.size());
+ if(m_deallocate)
+ Eigen::internal::aligned_free(object.data());
+ }
+
+private:
+ bool m_deallocate;
+};
+
+#endif // EIGEN_ALLOCA
+
template<typename T> class scoped_array : noncopyable
{
T* m_ptr;
@@ -603,13 +710,15 @@
{
std::swap(a.ptr(),b.ptr());
}
-
+
} // end namespace internal
/** \internal
- * Declares, allocates and construct an aligned buffer named NAME of SIZE elements of type TYPE on the stack
- * if SIZE is smaller than EIGEN_STACK_ALLOCATION_LIMIT, and if stack allocation is supported by the platform
- * (currently, this is Linux and Visual Studio only). Otherwise the memory is allocated on the heap.
+ *
+ * The macro ei_declare_aligned_stack_constructed_variable(TYPE,NAME,SIZE,BUFFER) declares, allocates,
+ * and construct an aligned buffer named NAME of SIZE elements of type TYPE on the stack
+ * if the size in bytes is smaller than EIGEN_STACK_ALLOCATION_LIMIT, and if stack allocation is supported by the platform
+ * (currently, this is Linux, OSX and Visual Studio only). Otherwise the memory is allocated on the heap.
* The allocated buffer is automatically deleted when exiting the scope of this declaration.
* If BUFFER is non null, then the declared variable is simply an alias for BUFFER, and no allocation/deletion occurs.
* Here is an example:
@@ -620,9 +729,17 @@
* }
* \endcode
* The underlying stack allocation function can controlled with the EIGEN_ALLOCA preprocessor token.
+ *
+ * The macro ei_declare_local_nested_eval(XPR_T,XPR,N,NAME) is analogue to
+ * \code
+ * typename internal::nested_eval<XPRT_T,N>::type NAME(XPR);
+ * \endcode
+ * with the advantage of using aligned stack allocation even if the maximal size of XPR at compile time is unknown.
+ * This is accomplished through alloca if this later is supported and if the required number of bytes
+ * is below EIGEN_STACK_ALLOCATION_LIMIT.
*/
#ifdef EIGEN_ALLOCA
-
+
#if EIGEN_DEFAULT_ALIGN_BYTES>0
// We always manually re-align the result of EIGEN_ALLOCA.
// If alloca is already aligned, the compiler should be smart enough to optimize away the re-alignment.
@@ -639,13 +756,23 @@
: Eigen::internal::aligned_malloc(sizeof(TYPE)*SIZE) ); \
Eigen::internal::aligned_stack_memory_handler<TYPE> EIGEN_CAT(NAME,_stack_memory_destructor)((BUFFER)==0 ? NAME : 0,SIZE,sizeof(TYPE)*SIZE>EIGEN_STACK_ALLOCATION_LIMIT)
+
+ #define ei_declare_local_nested_eval(XPR_T,XPR,N,NAME) \
+ Eigen::internal::local_nested_eval_wrapper<XPR_T,N> EIGEN_CAT(NAME,_wrapper)(XPR, reinterpret_cast<typename XPR_T::Scalar*>( \
+ ( (Eigen::internal::local_nested_eval_wrapper<XPR_T,N>::NeedExternalBuffer) && ((sizeof(typename XPR_T::Scalar)*XPR.size())<=EIGEN_STACK_ALLOCATION_LIMIT) ) \
+ ? EIGEN_ALIGNED_ALLOCA( sizeof(typename XPR_T::Scalar)*XPR.size() ) : 0 ) ) ; \
+ typename Eigen::internal::local_nested_eval_wrapper<XPR_T,N>::ObjectType NAME(EIGEN_CAT(NAME,_wrapper).object)
+
#else
#define ei_declare_aligned_stack_constructed_variable(TYPE,NAME,SIZE,BUFFER) \
Eigen::internal::check_size_for_overflow<TYPE>(SIZE); \
TYPE* NAME = (BUFFER)!=0 ? BUFFER : reinterpret_cast<TYPE*>(Eigen::internal::aligned_malloc(sizeof(TYPE)*SIZE)); \
Eigen::internal::aligned_stack_memory_handler<TYPE> EIGEN_CAT(NAME,_stack_memory_destructor)((BUFFER)==0 ? NAME : 0,SIZE,true)
-
+
+
+#define ei_declare_local_nested_eval(XPR_T,XPR,N,NAME) typename Eigen::internal::nested_eval<XPR_T,N>::type NAME(XPR)
+
#endif
@@ -653,32 +780,56 @@
*** Implementation of EIGEN_MAKE_ALIGNED_OPERATOR_NEW [_IF] ***
*****************************************************************************/
-#if EIGEN_MAX_ALIGN_BYTES!=0
+#if EIGEN_HAS_CXX17_OVERALIGN
+
+// C++17 -> no need to bother about alignment anymore :)
+
+#define EIGEN_MAKE_ALIGNED_OPERATOR_NEW_NOTHROW(NeedsToAlign)
+#define EIGEN_MAKE_ALIGNED_OPERATOR_NEW_IF(NeedsToAlign)
+#define EIGEN_MAKE_ALIGNED_OPERATOR_NEW
+#define EIGEN_MAKE_ALIGNED_OPERATOR_NEW_IF_VECTORIZABLE_FIXED_SIZE(Scalar,Size)
+
+#else
+
+// HIP does not support new/delete on device.
+#if EIGEN_MAX_ALIGN_BYTES!=0 && !defined(EIGEN_HIP_DEVICE_COMPILE)
#define EIGEN_MAKE_ALIGNED_OPERATOR_NEW_NOTHROW(NeedsToAlign) \
+ EIGEN_DEVICE_FUNC \
void* operator new(std::size_t size, const std::nothrow_t&) EIGEN_NO_THROW { \
EIGEN_TRY { return Eigen::internal::conditional_aligned_malloc<NeedsToAlign>(size); } \
EIGEN_CATCH (...) { return 0; } \
}
#define EIGEN_MAKE_ALIGNED_OPERATOR_NEW_IF(NeedsToAlign) \
+ EIGEN_DEVICE_FUNC \
void *operator new(std::size_t size) { \
return Eigen::internal::conditional_aligned_malloc<NeedsToAlign>(size); \
} \
+ EIGEN_DEVICE_FUNC \
void *operator new[](std::size_t size) { \
return Eigen::internal::conditional_aligned_malloc<NeedsToAlign>(size); \
} \
+ EIGEN_DEVICE_FUNC \
void operator delete(void * ptr) EIGEN_NO_THROW { Eigen::internal::conditional_aligned_free<NeedsToAlign>(ptr); } \
+ EIGEN_DEVICE_FUNC \
void operator delete[](void * ptr) EIGEN_NO_THROW { Eigen::internal::conditional_aligned_free<NeedsToAlign>(ptr); } \
+ EIGEN_DEVICE_FUNC \
void operator delete(void * ptr, std::size_t /* sz */) EIGEN_NO_THROW { Eigen::internal::conditional_aligned_free<NeedsToAlign>(ptr); } \
+ EIGEN_DEVICE_FUNC \
void operator delete[](void * ptr, std::size_t /* sz */) EIGEN_NO_THROW { Eigen::internal::conditional_aligned_free<NeedsToAlign>(ptr); } \
/* in-place new and delete. since (at least afaik) there is no actual */ \
/* memory allocated we can safely let the default implementation handle */ \
/* this particular case. */ \
+ EIGEN_DEVICE_FUNC \
static void *operator new(std::size_t size, void *ptr) { return ::operator new(size,ptr); } \
+ EIGEN_DEVICE_FUNC \
static void *operator new[](std::size_t size, void* ptr) { return ::operator new[](size,ptr); } \
+ EIGEN_DEVICE_FUNC \
void operator delete(void * memory, void *ptr) EIGEN_NO_THROW { return ::operator delete(memory,ptr); } \
+ EIGEN_DEVICE_FUNC \
void operator delete[](void * memory, void *ptr) EIGEN_NO_THROW { return ::operator delete[](memory,ptr); } \
/* nothrow-new (returns zero instead of std::bad_alloc) */ \
EIGEN_MAKE_ALIGNED_OPERATOR_NEW_NOTHROW(NeedsToAlign) \
+ EIGEN_DEVICE_FUNC \
void operator delete(void *ptr, const std::nothrow_t&) EIGEN_NO_THROW { \
Eigen::internal::conditional_aligned_free<NeedsToAlign>(ptr); \
} \
@@ -688,8 +839,14 @@
#endif
#define EIGEN_MAKE_ALIGNED_OPERATOR_NEW EIGEN_MAKE_ALIGNED_OPERATOR_NEW_IF(true)
-#define EIGEN_MAKE_ALIGNED_OPERATOR_NEW_IF_VECTORIZABLE_FIXED_SIZE(Scalar,Size) \
- EIGEN_MAKE_ALIGNED_OPERATOR_NEW_IF(bool(((Size)!=Eigen::Dynamic) && ((sizeof(Scalar)*(Size))%EIGEN_MAX_ALIGN_BYTES==0)))
+#define EIGEN_MAKE_ALIGNED_OPERATOR_NEW_IF_VECTORIZABLE_FIXED_SIZE(Scalar,Size) \
+ EIGEN_MAKE_ALIGNED_OPERATOR_NEW_IF(bool( \
+ ((Size)!=Eigen::Dynamic) && \
+ (((EIGEN_MAX_ALIGN_BYTES>=16) && ((sizeof(Scalar)*(Size))%(EIGEN_MAX_ALIGN_BYTES )==0)) || \
+ ((EIGEN_MAX_ALIGN_BYTES>=32) && ((sizeof(Scalar)*(Size))%(EIGEN_MAX_ALIGN_BYTES/2)==0)) || \
+ ((EIGEN_MAX_ALIGN_BYTES>=64) && ((sizeof(Scalar)*(Size))%(EIGEN_MAX_ALIGN_BYTES/4)==0)) )))
+
+#endif
/****************************************************************************/
@@ -703,13 +860,13 @@
* - 32 bytes alignment if AVX is enabled.
* - 64 bytes alignment if AVX512 is enabled.
*
-* This can be controled using the \c EIGEN_MAX_ALIGN_BYTES macro as documented
+* This can be controlled using the \c EIGEN_MAX_ALIGN_BYTES macro as documented
* \link TopicPreprocessorDirectivesPerformance there \endlink.
*
* Example:
* \code
* // Matrix4f requires 16 bytes alignment:
-* std::map< int, Matrix4f, std::less<int>,
+* std::map< int, Matrix4f, std::less<int>,
* aligned_allocator<std::pair<const int, Matrix4f> > > my_map_mat4;
* // Vector3f does not require 16 bytes alignment, no need to use Eigen's allocator:
* std::map< int, Vector3f > my_map_vec3;
@@ -744,18 +901,19 @@
~aligned_allocator() {}
+ #if EIGEN_COMP_GNUC_STRICT && EIGEN_GNUC_AT_LEAST(7,0)
+ // In gcc std::allocator::max_size() is bugged making gcc triggers a warning:
+ // eigen/Eigen/src/Core/util/Memory.h:189:12: warning: argument 1 value '18446744073709551612' exceeds maximum object size 9223372036854775807
+ // See https://gcc.gnu.org/bugzilla/show_bug.cgi?id=87544
+ size_type max_size() const {
+ return (std::numeric_limits<std::ptrdiff_t>::max)()/sizeof(T);
+ }
+ #endif
+
pointer allocate(size_type num, const void* /*hint*/ = 0)
{
internal::check_size_for_overflow<T>(num);
- size_type size = num * sizeof(T);
-#if EIGEN_COMP_GNUC_STRICT && EIGEN_GNUC_AT_LEAST(7,0)
- // workaround gcc bug https://gcc.gnu.org/bugzilla/show_bug.cgi?id=87544
- // It triggered eigen/Eigen/src/Core/util/Memory.h:189:12: warning: argument 1 value '18446744073709551612' exceeds maximum object size 9223372036854775807
- if(size>=std::size_t((std::numeric_limits<std::ptrdiff_t>::max)()))
- return 0;
- else
-#endif
- return static_cast<pointer>( internal::aligned_malloc(size) );
+ return static_cast<pointer>( internal::aligned_malloc(num * sizeof(T)) );
}
void deallocate(pointer p, size_type /*num*/)
@@ -914,20 +1072,32 @@
{
if(max_std_funcs>=4)
queryCacheSizes_intel_direct(l1,l2,l3);
- else
+ else if(max_std_funcs>=2)
queryCacheSizes_intel_codes(l1,l2,l3);
+ else
+ l1 = l2 = l3 = 0;
}
inline void queryCacheSizes_amd(int& l1, int& l2, int& l3)
{
int abcd[4];
abcd[0] = abcd[1] = abcd[2] = abcd[3] = 0;
- EIGEN_CPUID(abcd,0x80000005,0);
- l1 = (abcd[2] >> 24) * 1024; // C[31:24] = L1 size in KB
- abcd[0] = abcd[1] = abcd[2] = abcd[3] = 0;
- EIGEN_CPUID(abcd,0x80000006,0);
- l2 = (abcd[2] >> 16) * 1024; // C[31;16] = l2 cache size in KB
- l3 = ((abcd[3] & 0xFFFC000) >> 18) * 512 * 1024; // D[31;18] = l3 cache size in 512KB
+
+ // First query the max supported function.
+ EIGEN_CPUID(abcd,0x80000000,0);
+ if(static_cast<numext::uint32_t>(abcd[0]) >= static_cast<numext::uint32_t>(0x80000006))
+ {
+ EIGEN_CPUID(abcd,0x80000005,0);
+ l1 = (abcd[2] >> 24) * 1024; // C[31:24] = L1 size in KB
+ abcd[0] = abcd[1] = abcd[2] = abcd[3] = 0;
+ EIGEN_CPUID(abcd,0x80000006,0);
+ l2 = (abcd[2] >> 16) * 1024; // C[31;16] = l2 cache size in KB
+ l3 = ((abcd[3] & 0xFFFC000) >> 18) * 512 * 1024; // D[31;18] = l3 cache size in 512KB
+ }
+ else
+ {
+ l1 = l2 = l3 = 0;
+ }
}
#endif
@@ -943,7 +1113,7 @@
// identify the CPU vendor
EIGEN_CPUID(abcd,0x0,0);
- int max_std_funcs = abcd[1];
+ int max_std_funcs = abcd[0];
if(cpuid_is_vendor(abcd,GenuineIntel))
queryCacheSizes_intel(l1,l2,l3,max_std_funcs);
else if(cpuid_is_vendor(abcd,AuthenticAMD) || cpuid_is_vendor(abcd,AMDisbetter_))
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/Meta.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/Meta.h
index d31e954..7badfdc 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/Meta.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/Meta.h
@@ -11,13 +11,54 @@
#ifndef EIGEN_META_H
#define EIGEN_META_H
-#if defined(__CUDA_ARCH__)
-#include <cfloat>
-#include <math_constants.h>
+#if defined(EIGEN_GPU_COMPILE_PHASE)
+
+ #include <cfloat>
+
+ #if defined(EIGEN_CUDA_ARCH)
+ #include <math_constants.h>
+ #endif
+
+ #if defined(EIGEN_HIP_DEVICE_COMPILE)
+ // #include "Eigen/src/Core/arch/HIP/hcc/math_constants.h"
+ #endif
+
#endif
-#if EIGEN_COMP_ICC>=1600 && __cplusplus >= 201103L
+// Recent versions of ICC require <cstdint> for pointer types below.
+#define EIGEN_ICC_NEEDS_CSTDINT (EIGEN_COMP_ICC>=1600 && EIGEN_COMP_CXXVER >= 11)
+
+// Define portable (u)int{32,64} types
+#if EIGEN_HAS_CXX11 || EIGEN_ICC_NEEDS_CSTDINT
#include <cstdint>
+namespace Eigen {
+namespace numext {
+typedef std::uint8_t uint8_t;
+typedef std::int8_t int8_t;
+typedef std::uint16_t uint16_t;
+typedef std::int16_t int16_t;
+typedef std::uint32_t uint32_t;
+typedef std::int32_t int32_t;
+typedef std::uint64_t uint64_t;
+typedef std::int64_t int64_t;
+}
+}
+#else
+// Without c++11, all compilers able to compile Eigen also
+// provide the C99 stdint.h header file.
+#include <stdint.h>
+namespace Eigen {
+namespace numext {
+typedef ::uint8_t uint8_t;
+typedef ::int8_t int8_t;
+typedef ::uint16_t uint16_t;
+typedef ::int16_t int16_t;
+typedef ::uint32_t uint32_t;
+typedef ::int32_t int32_t;
+typedef ::uint64_t uint64_t;
+typedef ::int64_t int64_t;
+}
+}
#endif
namespace Eigen {
@@ -43,26 +84,33 @@
// Only recent versions of ICC complain about using ptrdiff_t to hold pointers,
// and older versions do not provide *intptr_t types.
-#if EIGEN_COMP_ICC>=1600 && __cplusplus >= 201103L
+#if EIGEN_ICC_NEEDS_CSTDINT
typedef std::intptr_t IntPtr;
typedef std::uintptr_t UIntPtr;
#else
typedef std::ptrdiff_t IntPtr;
typedef std::size_t UIntPtr;
#endif
+#undef EIGEN_ICC_NEEDS_CSTDINT
struct true_type { enum { value = 1 }; };
struct false_type { enum { value = 0 }; };
+template<bool Condition>
+struct bool_constant;
+
+template<>
+struct bool_constant<true> : true_type {};
+
+template<>
+struct bool_constant<false> : false_type {};
+
template<bool Condition, typename Then, typename Else>
struct conditional { typedef Then type; };
template<typename Then, typename Else>
struct conditional <false, Then, Else> { typedef Else type; };
-template<typename T, typename U> struct is_same { enum { value = 0 }; };
-template<typename T> struct is_same<T,T> { enum { value = 1 }; };
-
template<typename T> struct remove_reference { typedef T type; };
template<typename T> struct remove_reference<T&> { typedef T type; };
@@ -97,17 +145,33 @@
template<> struct is_arithmetic<signed long> { enum { value = true }; };
template<> struct is_arithmetic<unsigned long> { enum { value = true }; };
-template<typename T> struct is_integral { enum { value = false }; };
-template<> struct is_integral<bool> { enum { value = true }; };
-template<> struct is_integral<char> { enum { value = true }; };
-template<> struct is_integral<signed char> { enum { value = true }; };
-template<> struct is_integral<unsigned char> { enum { value = true }; };
-template<> struct is_integral<signed short> { enum { value = true }; };
-template<> struct is_integral<unsigned short> { enum { value = true }; };
-template<> struct is_integral<signed int> { enum { value = true }; };
-template<> struct is_integral<unsigned int> { enum { value = true }; };
-template<> struct is_integral<signed long> { enum { value = true }; };
-template<> struct is_integral<unsigned long> { enum { value = true }; };
+template<typename T, typename U> struct is_same { enum { value = 0 }; };
+template<typename T> struct is_same<T,T> { enum { value = 1 }; };
+
+template< class T >
+struct is_void : is_same<void, typename remove_const<T>::type> {};
+
+#if EIGEN_HAS_CXX11
+template<> struct is_arithmetic<signed long long> { enum { value = true }; };
+template<> struct is_arithmetic<unsigned long long> { enum { value = true }; };
+using std::is_integral;
+#else
+template<typename T> struct is_integral { enum { value = false }; };
+template<> struct is_integral<bool> { enum { value = true }; };
+template<> struct is_integral<char> { enum { value = true }; };
+template<> struct is_integral<signed char> { enum { value = true }; };
+template<> struct is_integral<unsigned char> { enum { value = true }; };
+template<> struct is_integral<signed short> { enum { value = true }; };
+template<> struct is_integral<unsigned short> { enum { value = true }; };
+template<> struct is_integral<signed int> { enum { value = true }; };
+template<> struct is_integral<unsigned int> { enum { value = true }; };
+template<> struct is_integral<signed long> { enum { value = true }; };
+template<> struct is_integral<unsigned long> { enum { value = true }; };
+#if EIGEN_COMP_MSVC
+template<> struct is_integral<signed __int64> { enum { value = true }; };
+template<> struct is_integral<unsigned __int64> { enum { value = true }; };
+#endif
+#endif
#if EIGEN_HAS_CXX11
using std::make_unsigned;
@@ -129,6 +193,16 @@
template<> struct make_unsigned<signed __int64> { typedef unsigned __int64 type; };
template<> struct make_unsigned<unsigned __int64> { typedef unsigned __int64 type; };
#endif
+
+// Some platforms define int64_t as `long long` even for C++03, where
+// `long long` is not guaranteed by the standard. In this case we are missing
+// the definition for make_unsigned. If we just define it, we run into issues
+// where `long long` doesn't exist in some compilers for C++03. We therefore add
+// the specialization for these platforms only.
+#if EIGEN_OS_MAC || EIGEN_COMP_MINGW
+template<> struct make_unsigned<unsigned long long> { typedef unsigned long long type; };
+template<> struct make_unsigned<long long> { typedef unsigned long long type; };
+#endif
#endif
template <typename T> struct add_const { typedef const T type; };
@@ -143,6 +217,11 @@
template<typename T> struct add_const_on_value_type<T* const> { typedef T const* const type; };
template<typename T> struct add_const_on_value_type<T const* const> { typedef T const* const type; };
+#if EIGEN_HAS_CXX11
+
+using std::is_convertible;
+
+#else
template<typename From, typename To>
struct is_convertible_impl
@@ -156,16 +235,19 @@
struct yes {int a[1];};
struct no {int a[2];};
- static yes test(const To&, int);
+ template<typename T>
+ static yes test(T, int);
+
+ template<typename T>
static no test(any_conversion, ...);
public:
- static From ms_from;
+ static typename internal::remove_reference<From>::type* ms_from;
#ifdef __INTEL_COMPILER
#pragma warning push
#pragma warning ( disable : 2259 )
#endif
- enum { value = sizeof(test(ms_from, 0))==sizeof(yes) };
+ enum { value = sizeof(test<To>(*ms_from, 0))==sizeof(yes) };
#ifdef __INTEL_COMPILER
#pragma warning pop
#endif
@@ -174,10 +256,17 @@
template<typename From, typename To>
struct is_convertible
{
- enum { value = is_convertible_impl<typename remove_all<From>::type,
- typename remove_all<To >::type>::value };
+ enum { value = is_convertible_impl<From,To>::value };
};
+template<typename T>
+struct is_convertible<T,T&> { enum { value = false }; };
+
+template<typename T>
+struct is_convertible<const T,const T&> { enum { value = true }; };
+
+#endif
+
/** \internal Allows to enable/disable an overload
* according to a compile time condition.
*/
@@ -186,7 +275,7 @@
template<typename T> struct enable_if<true,T>
{ typedef T type; };
-#if defined(__CUDA_ARCH__)
+#if defined(EIGEN_GPU_COMPILE_PHASE) && !EIGEN_HAS_CXX11
#if !defined(__FLT_EPSILON__)
#define __FLT_EPSILON__ FLT_EPSILON
#define __DBL_EPSILON__ DBL_EPSILON
@@ -197,7 +286,7 @@
template<typename T> struct numeric_limits
{
EIGEN_DEVICE_FUNC
- static T epsilon() { return 0; }
+ static EIGEN_CONSTEXPR T epsilon() { return 0; }
static T (max)() { assert(false && "Highest not supported for this type"); }
static T (min)() { assert(false && "Lowest not supported for this type"); }
static T infinity() { assert(false && "Infinity not supported for this type"); }
@@ -205,91 +294,130 @@
};
template<> struct numeric_limits<float>
{
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static float epsilon() { return __FLT_EPSILON__; }
EIGEN_DEVICE_FUNC
- static float (max)() { return CUDART_MAX_NORMAL_F; }
- EIGEN_DEVICE_FUNC
+ static float (max)() {
+ #if defined(EIGEN_CUDA_ARCH)
+ return CUDART_MAX_NORMAL_F;
+ #else
+ return HIPRT_MAX_NORMAL_F;
+ #endif
+ }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static float (min)() { return FLT_MIN; }
EIGEN_DEVICE_FUNC
- static float infinity() { return CUDART_INF_F; }
+ static float infinity() {
+ #if defined(EIGEN_CUDA_ARCH)
+ return CUDART_INF_F;
+ #else
+ return HIPRT_INF_F;
+ #endif
+ }
EIGEN_DEVICE_FUNC
- static float quiet_NaN() { return CUDART_NAN_F; }
+ static float quiet_NaN() {
+ #if defined(EIGEN_CUDA_ARCH)
+ return CUDART_NAN_F;
+ #else
+ return HIPRT_NAN_F;
+ #endif
+ }
};
template<> struct numeric_limits<double>
{
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static double epsilon() { return __DBL_EPSILON__; }
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static double (max)() { return DBL_MAX; }
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static double (min)() { return DBL_MIN; }
EIGEN_DEVICE_FUNC
- static double infinity() { return CUDART_INF; }
+ static double infinity() {
+ #if defined(EIGEN_CUDA_ARCH)
+ return CUDART_INF;
+ #else
+ return HIPRT_INF;
+ #endif
+ }
EIGEN_DEVICE_FUNC
- static double quiet_NaN() { return CUDART_NAN; }
+ static double quiet_NaN() {
+ #if defined(EIGEN_CUDA_ARCH)
+ return CUDART_NAN;
+ #else
+ return HIPRT_NAN;
+ #endif
+ }
};
template<> struct numeric_limits<int>
{
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static int epsilon() { return 0; }
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static int (max)() { return INT_MAX; }
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static int (min)() { return INT_MIN; }
};
template<> struct numeric_limits<unsigned int>
{
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static unsigned int epsilon() { return 0; }
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static unsigned int (max)() { return UINT_MAX; }
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static unsigned int (min)() { return 0; }
};
template<> struct numeric_limits<long>
{
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static long epsilon() { return 0; }
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static long (max)() { return LONG_MAX; }
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static long (min)() { return LONG_MIN; }
};
template<> struct numeric_limits<unsigned long>
{
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static unsigned long epsilon() { return 0; }
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static unsigned long (max)() { return ULONG_MAX; }
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static unsigned long (min)() { return 0; }
};
template<> struct numeric_limits<long long>
{
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static long long epsilon() { return 0; }
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static long long (max)() { return LLONG_MAX; }
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static long long (min)() { return LLONG_MIN; }
};
template<> struct numeric_limits<unsigned long long>
{
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static unsigned long long epsilon() { return 0; }
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static unsigned long long (max)() { return ULLONG_MAX; }
- EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static unsigned long long (min)() { return 0; }
};
+template<> struct numeric_limits<bool>
+{
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ static bool epsilon() { return false; }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ static bool (max)() { return true; }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ static bool (min)() { return false; }
+};
}
-#endif
+#endif // defined(EIGEN_GPU_COMPILE_PHASE) && !EIGEN_HAS_CXX11
/** \internal
- * A base class do disable default copy ctor and copy assignement operator.
+ * A base class do disable default copy ctor and copy assignment operator.
*/
class noncopyable
{
@@ -301,13 +429,82 @@
};
/** \internal
- * Convenient struct to get the result type of a unary or binary functor.
+ * Provides access to the number of elements in the object of as a compile-time constant expression.
+ * It "returns" Eigen::Dynamic if the size cannot be resolved at compile-time (default).
*
- * It supports both the current STL mechanism (using the result_type member) as well as
- * upcoming next STL generation (using a templated result member).
- * If none of these members is provided, then the type of the first argument is returned. FIXME, that behavior is a pretty bad hack.
+ * Similar to std::tuple_size, but more general.
+ *
+ * It currently supports:
+ * - any types T defining T::SizeAtCompileTime
+ * - plain C arrays as T[N]
+ * - std::array (c++11)
+ * - some internal types such as SingleRange and AllRange
+ *
+ * The second template parameter eases SFINAE-based specializations.
*/
-#if EIGEN_HAS_STD_RESULT_OF
+template<typename T, typename EnableIf = void> struct array_size {
+ enum { value = Dynamic };
+};
+
+template<typename T> struct array_size<T,typename internal::enable_if<((T::SizeAtCompileTime&0)==0)>::type> {
+ enum { value = T::SizeAtCompileTime };
+};
+
+template<typename T, int N> struct array_size<const T (&)[N]> {
+ enum { value = N };
+};
+template<typename T, int N> struct array_size<T (&)[N]> {
+ enum { value = N };
+};
+
+#if EIGEN_HAS_CXX11
+template<typename T, std::size_t N> struct array_size<const std::array<T,N> > {
+ enum { value = N };
+};
+template<typename T, std::size_t N> struct array_size<std::array<T,N> > {
+ enum { value = N };
+};
+#endif
+
+/** \internal
+ * Analogue of the std::size free function.
+ * It returns the size of the container or view \a x of type \c T
+ *
+ * It currently supports:
+ * - any types T defining a member T::size() const
+ * - plain C arrays as T[N]
+ *
+ */
+template<typename T>
+EIGEN_CONSTEXPR Index size(const T& x) { return x.size(); }
+
+template<typename T,std::size_t N>
+EIGEN_CONSTEXPR Index size(const T (&) [N]) { return N; }
+
+/** \internal
+ * Convenient struct to get the result type of a nullary, unary, binary, or
+ * ternary functor.
+ *
+ * Pre C++11:
+ * Supports both a Func::result_type member and templated
+ * Func::result<Func(ArgTypes...)>::type member.
+ *
+ * If none of these members is provided, then the type of the first
+ * argument is returned.
+ *
+ * Post C++11:
+ * This uses std::result_of. However, note the `type` member removes
+ * const and converts references/pointers to their corresponding value type.
+ */
+#if EIGEN_HAS_STD_INVOKE_RESULT
+template<typename T> struct result_of;
+
+template<typename F, typename... ArgTypes>
+struct result_of<F(ArgTypes...)> {
+ typedef typename std::invoke_result<F, ArgTypes...>::type type1;
+ typedef typename remove_all<type1>::type type;
+};
+#elif EIGEN_HAS_STD_RESULT_OF
template<typename T> struct result_of {
typedef typename std::result_of<T>::type type1;
typedef typename remove_all<type1>::type type;
@@ -319,6 +516,28 @@
struct has_std_result_type {int a[2];};
struct has_tr1_result {int a[3];};
+template<typename Func, int SizeOf>
+struct nullary_result_of_select {};
+
+template<typename Func>
+struct nullary_result_of_select<Func, sizeof(has_std_result_type)> {typedef typename Func::result_type type;};
+
+template<typename Func>
+struct nullary_result_of_select<Func, sizeof(has_tr1_result)> {typedef typename Func::template result<Func()>::type type;};
+
+template<typename Func>
+struct result_of<Func()> {
+ template<typename T>
+ static has_std_result_type testFunctor(T const *, typename T::result_type const * = 0);
+ template<typename T>
+ static has_tr1_result testFunctor(T const *, typename T::template result<T()>::type const * = 0);
+ static has_none testFunctor(...);
+
+ // note that the following indirection is needed for gcc-3.3
+ enum {FunctorType = sizeof(testFunctor(static_cast<Func*>(0)))};
+ typedef typename nullary_result_of_select<Func, FunctorType>::type type;
+};
+
template<typename Func, typename ArgType, int SizeOf=sizeof(has_none)>
struct unary_result_of_select {typedef typename internal::remove_all<ArgType>::type type;};
@@ -388,6 +607,45 @@
enum {FunctorType = sizeof(testFunctor(static_cast<Func*>(0)))};
typedef typename ternary_result_of_select<Func, ArgType0, ArgType1, ArgType2, FunctorType>::type type;
};
+
+#endif
+
+#if EIGEN_HAS_STD_INVOKE_RESULT
+template<typename F, typename... ArgTypes>
+struct invoke_result {
+ typedef typename std::invoke_result<F, ArgTypes...>::type type1;
+ typedef typename remove_all<type1>::type type;
+};
+#elif EIGEN_HAS_CXX11
+template<typename F, typename... ArgTypes>
+struct invoke_result {
+ typedef typename result_of<F(ArgTypes...)>::type type1;
+ typedef typename remove_all<type1>::type type;
+};
+#else
+template<typename F, typename ArgType0 = void, typename ArgType1 = void, typename ArgType2 = void>
+struct invoke_result {
+ typedef typename result_of<F(ArgType0, ArgType1, ArgType2)>::type type1;
+ typedef typename remove_all<type1>::type type;
+};
+
+template<typename F>
+struct invoke_result<F, void, void, void> {
+ typedef typename result_of<F()>::type type1;
+ typedef typename remove_all<type1>::type type;
+};
+
+template<typename F, typename ArgType0>
+struct invoke_result<F, ArgType0, void, void> {
+ typedef typename result_of<F(ArgType0)>::type type1;
+ typedef typename remove_all<type1>::type type;
+};
+
+template<typename F, typename ArgType0, typename ArgType1>
+struct invoke_result<F, ArgType0, ArgType1, void> {
+ typedef typename result_of<F(ArgType0, ArgType1)>::type type1;
+ typedef typename remove_all<type1>::type type;
+};
#endif
struct meta_yes { char a[1]; };
@@ -397,10 +655,10 @@
template <typename T>
struct has_ReturnType
{
- template <typename C> static meta_yes testFunctor(typename C::ReturnType const *);
- template <typename C> static meta_no testFunctor(...);
+ template <typename C> static meta_yes testFunctor(C const *, typename C::ReturnType const * = 0);
+ template <typename C> static meta_no testFunctor(...);
- enum { value = sizeof(testFunctor<T>(0)) == sizeof(meta_yes) };
+ enum { value = sizeof(testFunctor<T>(static_cast<T*>(0))) == sizeof(meta_yes) };
};
template<typename T> const T* return_ptr();
@@ -457,20 +715,25 @@
/** \internal Computes the least common multiple of two positive integer A and B
- * at compile-time. It implements a naive algorithm testing all multiples of A.
- * It thus works better if A>=B.
+ * at compile-time.
*/
-template<int A, int B, int K=1, bool Done = ((A*K)%B)==0>
+template<int A, int B, int K=1, bool Done = ((A*K)%B)==0, bool Big=(A>=B)>
struct meta_least_common_multiple
{
enum { ret = meta_least_common_multiple<A,B,K+1>::ret };
};
+template<int A, int B, int K, bool Done>
+struct meta_least_common_multiple<A,B,K,Done,false>
+{
+ enum { ret = meta_least_common_multiple<B,A,K>::ret };
+};
template<int A, int B, int K>
-struct meta_least_common_multiple<A,B,K,true>
+struct meta_least_common_multiple<A,B,K,true,true>
{
enum { ret = A*K };
};
+
/** \internal determines whether the product of two numeric types is allowed and what the return type is */
template<typename T, typename U> struct scalar_product_traits
{
@@ -483,17 +746,27 @@
// typedef typename scalar_product_traits<typename remove_all<ArgType0>::type, typename remove_all<ArgType1>::type>::ReturnType type;
// };
+/** \internal Obtains a POD type suitable to use as storage for an object of a size
+ * of at most Len bytes, aligned as specified by \c Align.
+ */
+template<unsigned Len, unsigned Align>
+struct aligned_storage {
+ struct type {
+ EIGEN_ALIGN_TO_BOUNDARY(Align) unsigned char data[Len];
+ };
+};
+
} // end namespace internal
namespace numext {
-
-#if defined(__CUDA_ARCH__)
+
+#if defined(EIGEN_GPU_COMPILE_PHASE)
template<typename T> EIGEN_DEVICE_FUNC void swap(T &a, T &b) { T tmp = b; b = a; a = tmp; }
#else
template<typename T> EIGEN_STRONG_INLINE void swap(T &a, T &b) { std::swap(a,b); }
#endif
-#if defined(__CUDA_ARCH__)
+#if defined(EIGEN_GPU_COMPILE_PHASE) && !EIGEN_HAS_CXX11
using internal::device::numeric_limits;
#else
using std::numeric_limits;
@@ -502,6 +775,7 @@
// Integer division with rounding up.
// T is assumed to be an integer type with a>=0, and b>0
template<typename T>
+EIGEN_DEVICE_FUNC
T div_ceil(const T &a, const T &b)
{
return (a+b-1) / b;
@@ -509,23 +783,27 @@
// The aim of the following functions is to bypass -Wfloat-equal warnings
// when we really want a strict equality comparison on floating points.
-template<typename X, typename Y> EIGEN_STRONG_INLINE
+template<typename X, typename Y> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
bool equal_strict(const X& x,const Y& y) { return x == y; }
-template<> EIGEN_STRONG_INLINE
+#if !defined(EIGEN_GPU_COMPILE_PHASE) || (!defined(EIGEN_CUDA_ARCH) && defined(EIGEN_CONSTEXPR_ARE_DEVICE_FUNC))
+template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
bool equal_strict(const float& x,const float& y) { return std::equal_to<float>()(x,y); }
-template<> EIGEN_STRONG_INLINE
+template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
bool equal_strict(const double& x,const double& y) { return std::equal_to<double>()(x,y); }
+#endif
-template<typename X, typename Y> EIGEN_STRONG_INLINE
+template<typename X, typename Y> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
bool not_equal_strict(const X& x,const Y& y) { return x != y; }
-template<> EIGEN_STRONG_INLINE
+#if !defined(EIGEN_GPU_COMPILE_PHASE) || (!defined(EIGEN_CUDA_ARCH) && defined(EIGEN_CONSTEXPR_ARE_DEVICE_FUNC))
+template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
bool not_equal_strict(const float& x,const float& y) { return std::not_equal_to<float>()(x,y); }
-template<> EIGEN_STRONG_INLINE
+template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
bool not_equal_strict(const double& x,const double& y) { return std::not_equal_to<double>()(x,y); }
+#endif
} // end namespace numext
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/ReenableStupidWarnings.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/ReenableStupidWarnings.h
index ecc82b7..1ce6fd1 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/ReenableStupidWarnings.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/ReenableStupidWarnings.h
@@ -1,4 +1,8 @@
-#ifdef EIGEN_WARNINGS_DISABLED
+#ifdef EIGEN_WARNINGS_DISABLED_2
+// "DisableStupidWarnings.h" was included twice recursively: Do not reenable warnings yet!
+# undef EIGEN_WARNINGS_DISABLED_2
+
+#elif defined(EIGEN_WARNINGS_DISABLED)
#undef EIGEN_WARNINGS_DISABLED
#ifndef EIGEN_PERMANENTLY_DISABLE_STUPID_WARNINGS
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/ReshapedHelper.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/ReshapedHelper.h
new file mode 100644
index 0000000..4124321
--- /dev/null
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/ReshapedHelper.h
@@ -0,0 +1,51 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2017 Gael Guennebaud <gael.guennebaud@inria.fr>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+
+#ifndef EIGEN_RESHAPED_HELPER_H
+#define EIGEN_RESHAPED_HELPER_H
+
+namespace Eigen {
+
+enum AutoSize_t { AutoSize };
+const int AutoOrder = 2;
+
+namespace internal {
+
+template<typename SizeType,typename OtherSize, int TotalSize>
+struct get_compiletime_reshape_size {
+ enum { value = get_fixed_value<SizeType>::value };
+};
+
+template<typename SizeType>
+Index get_runtime_reshape_size(SizeType size, Index /*other*/, Index /*total*/) {
+ return internal::get_runtime_value(size);
+}
+
+template<typename OtherSize, int TotalSize>
+struct get_compiletime_reshape_size<AutoSize_t,OtherSize,TotalSize> {
+ enum {
+ other_size = get_fixed_value<OtherSize>::value,
+ value = (TotalSize==Dynamic || other_size==Dynamic) ? Dynamic : TotalSize / other_size };
+};
+
+inline Index get_runtime_reshape_size(AutoSize_t /*size*/, Index other, Index total) {
+ return total/other;
+}
+
+template<int Flags, int Order>
+struct get_compiletime_reshape_order {
+ enum { value = Order == AutoOrder ? Flags & RowMajorBit : Order };
+};
+
+}
+
+} // end namespace Eigen
+
+#endif // EIGEN_RESHAPED_HELPER_H
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/StaticAssert.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/StaticAssert.h
index 500e477..c45de59 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/StaticAssert.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/StaticAssert.h
@@ -27,7 +27,7 @@
#ifndef EIGEN_STATIC_ASSERT
#ifndef EIGEN_NO_STATIC_ASSERT
- #if EIGEN_MAX_CPP_VER>=11 && (__has_feature(cxx_static_assert) || (defined(__cplusplus) && __cplusplus >= 201103L) || (EIGEN_COMP_MSVC >= 1600))
+ #if EIGEN_MAX_CPP_VER>=11 && (__has_feature(cxx_static_assert) || (EIGEN_COMP_CXXVER >= 11) || (EIGEN_COMP_MSVC >= 1600))
// if native static_assert is enabled, let's use it
#define EIGEN_STATIC_ASSERT(X,MSG) static_assert(X,#MSG);
@@ -103,7 +103,10 @@
STORAGE_KIND_MUST_MATCH=1,
STORAGE_INDEX_MUST_MATCH=1,
CHOLMOD_SUPPORTS_DOUBLE_PRECISION_ONLY=1,
- SELFADJOINTVIEW_ACCEPTS_UPPER_AND_LOWER_MODE_ONLY=1
+ SELFADJOINTVIEW_ACCEPTS_UPPER_AND_LOWER_MODE_ONLY=1,
+ INVALID_TEMPLATE_PARAMETER=1,
+ GPU_TENSOR_CONTRACTION_DOES_NOT_SUPPORT_OUTPUT_KERNELS=1,
+ THE_ARRAY_SIZE_SHOULD_EQUAL_WITH_PACKET_SIZE=1
};
};
@@ -182,7 +185,7 @@
)
#define EIGEN_STATIC_ASSERT_NON_INTEGER(TYPE) \
- EIGEN_STATIC_ASSERT(!NumTraits<TYPE>::IsInteger, THIS_FUNCTION_IS_NOT_FOR_INTEGER_NUMERIC_TYPES)
+ EIGEN_STATIC_ASSERT(!Eigen::NumTraits<TYPE>::IsInteger, THIS_FUNCTION_IS_NOT_FOR_INTEGER_NUMERIC_TYPES)
// static assertion failing if it is guaranteed at compile-time that the two matrix expression types have different sizes
@@ -192,8 +195,8 @@
YOU_MIXED_MATRICES_OF_DIFFERENT_SIZES)
#define EIGEN_STATIC_ASSERT_SIZE_1x1(TYPE) \
- EIGEN_STATIC_ASSERT((TYPE::RowsAtCompileTime == 1 || TYPE::RowsAtCompileTime == Dynamic) && \
- (TYPE::ColsAtCompileTime == 1 || TYPE::ColsAtCompileTime == Dynamic), \
+ EIGEN_STATIC_ASSERT((TYPE::RowsAtCompileTime == 1 || TYPE::RowsAtCompileTime == Eigen::Dynamic) && \
+ (TYPE::ColsAtCompileTime == 1 || TYPE::ColsAtCompileTime == Eigen::Dynamic), \
THIS_METHOD_IS_ONLY_FOR_1x1_EXPRESSIONS)
#define EIGEN_STATIC_ASSERT_LVALUE(Derived) \
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/SymbolicIndex.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/SymbolicIndex.h
new file mode 100644
index 0000000..354dd9a
--- /dev/null
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/SymbolicIndex.h
@@ -0,0 +1,293 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2017 Gael Guennebaud <gael.guennebaud@inria.fr>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_SYMBOLIC_INDEX_H
+#define EIGEN_SYMBOLIC_INDEX_H
+
+namespace Eigen {
+
+/** \namespace Eigen::symbolic
+ * \ingroup Core_Module
+ *
+ * This namespace defines a set of classes and functions to build and evaluate symbolic expressions of scalar type Index.
+ * Here is a simple example:
+ *
+ * \code
+ * // First step, define the symbols:
+ * struct x_tag {}; static const symbolic::SymbolExpr<x_tag> x;
+ * struct y_tag {}; static const symbolic::SymbolExpr<y_tag> y;
+ * struct z_tag {}; static const symbolic::SymbolExpr<z_tag> z;
+ *
+ * // Define an expression:
+ * auto expr = (x+3)/y+z;
+ *
+ * // And evaluate it (C++14 only):
+ * std::cout << expr.eval(x=6,y=3,z=-13) << "\n";
+ *
+ * // In C++98/11, only one symbol per expression is supported for now:
+ * auto expr98 = (3-x)/2;
+ * std::cout << expr98.eval(x=6) << "\n";
+ * \endcode
+ *
+ * It is currently only used internally to define and manipulate the Eigen::last and Eigen::lastp1 symbols in Eigen::seq and Eigen::seqN.
+ *
+ */
+namespace symbolic {
+
+template<typename Tag> class Symbol; // NOTE(review): no definition of Symbol appears in this header; possibly a leftover declaration
+template<typename Arg0> class NegateExpr; // forward declarations: BaseExpr's operators below return these types before they are defined
+template<typename Arg1,typename Arg2> class AddExpr;
+template<typename Arg1,typename Arg2> class ProductExpr;
+template<typename Arg1,typename Arg2> class QuotientExpr;
+
+// A simple wrapper around a runtime integral value that provides the eval_impl method used by the other expressions.
+// (A free function such as symbolic_eval could serve the same purpose.)
+template<typename IndexType=Index>
+class ValueExpr {
+public:
+ ValueExpr(IndexType val) : m_value(val) {} // implicit on purpose: BaseExpr's operators pass a plain Index where a ValueExpr<> is expected
+ template<typename T>
+ IndexType eval_impl(const T&) const { return m_value; } // symbol values are ignored; always yields the stored constant
+protected:
+ IndexType m_value;
+};
+
+// Specialization for a compile-time value (internal::FixedInt<N>).
+// It behaves like ValueExpr<>(N), but since N is a template constant this version helps the compiler generate better code.
+template<int N>
+class ValueExpr<internal::FixedInt<N> > {
+public:
+ ValueExpr() {}
+ template<typename T>
+ EIGEN_CONSTEXPR Index eval_impl(const T&) const { return N; } // always N, independent of the symbol values
+};
+
+
+/** \class BaseExpr
+ * \ingroup Core_Module
+ * Common CRTP base class of all symbolic expressions; \c Derived is the concrete expression type.
+ */
+template<typename Derived>
+class BaseExpr
+{
+public:
+ const Derived& derived() const { return *static_cast<const Derived*>(this); } // CRTP downcast
+
+ /** Evaluate the expression given the \a values of the symbols.
+ *
+ * \param values defines the values of the symbols; it can either be a SymbolValue or a std::tuple of SymbolValue
+ * as returned by SymbolExpr::operator=.
+ *
+ */
+ template<typename T>
+ Index eval(const T& values) const { return derived().eval_impl(values); }
+
+#if EIGEN_HAS_CXX14 // variadic form: several symbol values are bundled into a std::tuple
+ template<typename... Types>
+ Index eval(Types&&... values) const { return derived().eval_impl(std::make_tuple(values...)); }
+#endif
+
+ NegateExpr<Derived> operator-() const { return NegateExpr<Derived>(derived()); } // -expr
+
+ AddExpr<Derived,ValueExpr<> > operator+(Index b) const // expr (op) runtime Index
+ { return AddExpr<Derived,ValueExpr<> >(derived(), b); }
+ AddExpr<Derived,ValueExpr<> > operator-(Index a) const
+ { return AddExpr<Derived,ValueExpr<> >(derived(), -a); }
+ ProductExpr<Derived,ValueExpr<> > operator*(Index a) const
+ { return ProductExpr<Derived,ValueExpr<> >(derived(),a); }
+ QuotientExpr<Derived,ValueExpr<> > operator/(Index a) const
+ { return QuotientExpr<Derived,ValueExpr<> >(derived(),a); }
+
+ friend AddExpr<Derived,ValueExpr<> > operator+(Index a, const BaseExpr& b) // runtime Index (op) expr
+ { return AddExpr<Derived,ValueExpr<> >(b.derived(), a); }
+ friend AddExpr<NegateExpr<Derived>,ValueExpr<> > operator-(Index a, const BaseExpr& b)
+ { return AddExpr<NegateExpr<Derived>,ValueExpr<> >(-b.derived(), a); }
+ friend ProductExpr<ValueExpr<>,Derived> operator*(Index a, const BaseExpr& b)
+ { return ProductExpr<ValueExpr<>,Derived>(a,b.derived()); }
+ friend QuotientExpr<ValueExpr<>,Derived> operator/(Index a, const BaseExpr& b)
+ { return QuotientExpr<ValueExpr<>,Derived>(a,b.derived()); }
+
+ template<int N>
+ AddExpr<Derived,ValueExpr<internal::FixedInt<N> > > operator+(internal::FixedInt<N>) const // expr (op) compile-time FixedInt<N>
+ { return AddExpr<Derived,ValueExpr<internal::FixedInt<N> > >(derived(), ValueExpr<internal::FixedInt<N> >()); }
+ template<int N>
+ AddExpr<Derived,ValueExpr<internal::FixedInt<-N> > > operator-(internal::FixedInt<N>) const
+ { return AddExpr<Derived,ValueExpr<internal::FixedInt<-N> > >(derived(), ValueExpr<internal::FixedInt<-N> >()); }
+ template<int N>
+ ProductExpr<Derived,ValueExpr<internal::FixedInt<N> > > operator*(internal::FixedInt<N>) const
+ { return ProductExpr<Derived,ValueExpr<internal::FixedInt<N> > >(derived(),ValueExpr<internal::FixedInt<N> >()); }
+ template<int N>
+ QuotientExpr<Derived,ValueExpr<internal::FixedInt<N> > > operator/(internal::FixedInt<N>) const
+ { return QuotientExpr<Derived,ValueExpr<internal::FixedInt<N> > >(derived(),ValueExpr<internal::FixedInt<N> >()); }
+
+ template<int N>
+ friend AddExpr<Derived,ValueExpr<internal::FixedInt<N> > > operator+(internal::FixedInt<N>, const BaseExpr& b) // compile-time FixedInt<N> (op) expr
+ { return AddExpr<Derived,ValueExpr<internal::FixedInt<N> > >(b.derived(), ValueExpr<internal::FixedInt<N> >()); }
+ template<int N>
+ friend AddExpr<NegateExpr<Derived>,ValueExpr<internal::FixedInt<N> > > operator-(internal::FixedInt<N>, const BaseExpr& b)
+ { return AddExpr<NegateExpr<Derived>,ValueExpr<internal::FixedInt<N> > >(-b.derived(), ValueExpr<internal::FixedInt<N> >()); }
+ template<int N>
+ friend ProductExpr<ValueExpr<internal::FixedInt<N> >,Derived> operator*(internal::FixedInt<N>, const BaseExpr& b)
+ { return ProductExpr<ValueExpr<internal::FixedInt<N> >,Derived>(ValueExpr<internal::FixedInt<N> >(),b.derived()); }
+ template<int N>
+ friend QuotientExpr<ValueExpr<internal::FixedInt<N> >,Derived> operator/(internal::FixedInt<N>, const BaseExpr& b)
+ { return QuotientExpr<ValueExpr<internal::FixedInt<N> > ,Derived>(ValueExpr<internal::FixedInt<N> >(),b.derived()); }
+
+#if (!EIGEN_HAS_CXX14) // pre-C++14: also accept a function pointer returning FixedInt<N>
+ template<int N>
+ AddExpr<Derived,ValueExpr<internal::FixedInt<N> > > operator+(internal::FixedInt<N> (*)()) const
+ { return AddExpr<Derived,ValueExpr<internal::FixedInt<N> > >(derived(), ValueExpr<internal::FixedInt<N> >()); }
+ template<int N>
+ AddExpr<Derived,ValueExpr<internal::FixedInt<-N> > > operator-(internal::FixedInt<N> (*)()) const
+ { return AddExpr<Derived,ValueExpr<internal::FixedInt<-N> > >(derived(), ValueExpr<internal::FixedInt<-N> >()); }
+ template<int N>
+ ProductExpr<Derived,ValueExpr<internal::FixedInt<N> > > operator*(internal::FixedInt<N> (*)()) const
+ { return ProductExpr<Derived,ValueExpr<internal::FixedInt<N> > >(derived(),ValueExpr<internal::FixedInt<N> >()); }
+ template<int N>
+ QuotientExpr<Derived,ValueExpr<internal::FixedInt<N> > > operator/(internal::FixedInt<N> (*)()) const
+ { return QuotientExpr<Derived,ValueExpr<internal::FixedInt<N> > >(derived(),ValueExpr<internal::FixedInt<N> >()); }
+
+ template<int N>
+ friend AddExpr<Derived,ValueExpr<internal::FixedInt<N> > > operator+(internal::FixedInt<N> (*)(), const BaseExpr& b)
+ { return AddExpr<Derived,ValueExpr<internal::FixedInt<N> > >(b.derived(), ValueExpr<internal::FixedInt<N> >()); }
+ template<int N>
+ friend AddExpr<NegateExpr<Derived>,ValueExpr<internal::FixedInt<N> > > operator-(internal::FixedInt<N> (*)(), const BaseExpr& b)
+ { return AddExpr<NegateExpr<Derived>,ValueExpr<internal::FixedInt<N> > >(-b.derived(), ValueExpr<internal::FixedInt<N> >()); }
+ template<int N>
+ friend ProductExpr<ValueExpr<internal::FixedInt<N> >,Derived> operator*(internal::FixedInt<N> (*)(), const BaseExpr& b)
+ { return ProductExpr<ValueExpr<internal::FixedInt<N> >,Derived>(ValueExpr<internal::FixedInt<N> >(),b.derived()); }
+ template<int N>
+ friend QuotientExpr<ValueExpr<internal::FixedInt<N> >,Derived> operator/(internal::FixedInt<N> (*)(), const BaseExpr& b)
+ { return QuotientExpr<ValueExpr<internal::FixedInt<N> > ,Derived>(ValueExpr<internal::FixedInt<N> >(),b.derived()); }
+#endif
+
+
+ template<typename OtherDerived>
+ AddExpr<Derived,OtherDerived> operator+(const BaseExpr<OtherDerived> &b) const // expr (op) expr
+ { return AddExpr<Derived,OtherDerived>(derived(), b.derived()); }
+
+ template<typename OtherDerived>
+ AddExpr<Derived,NegateExpr<OtherDerived> > operator-(const BaseExpr<OtherDerived> &b) const
+ { return AddExpr<Derived,NegateExpr<OtherDerived> >(derived(), -b.derived()); }
+
+ template<typename OtherDerived>
+ ProductExpr<Derived,OtherDerived> operator*(const BaseExpr<OtherDerived> &b) const
+ { return ProductExpr<Derived,OtherDerived>(derived(), b.derived()); }
+
+ template<typename OtherDerived>
+ QuotientExpr<Derived,OtherDerived> operator/(const BaseExpr<OtherDerived> &b) const
+ { return QuotientExpr<Derived,OtherDerived>(derived(), b.derived()); }
+};
+
+template<typename T>
+struct is_symbolic { // value is true when T is a symbolic expression type, i.e. derives from BaseExpr<T>
+ // BaseExpr has no converting constructor, so T is convertible to BaseExpr<T> only if it statically casts to that base class.
+ enum { value = internal::is_convertible<T,BaseExpr<T> >::value };
+};
+
+/** Represents the actual value of a symbol identified by its tag
+ *
+ * It is the return type of SymbolExpr::operator=, and most of the time this is the only way it is used.
+ */
+template<typename Tag>
+class SymbolValue
+{
+public:
+ /** Constructs the symbol value from \a val (implicit conversion from Index) */
+ SymbolValue(Index val) : m_value(val) {}
+
+ /** \returns the stored value of the symbol */
+ Index value() const { return m_value; }
+protected:
+ Index m_value;
+};
+
+/** Expression of a symbol uniquely identified by the template parameter type \c tag */
+template<typename tag>
+class SymbolExpr : public BaseExpr<SymbolExpr<tag> >
+{
+public:
+ /** Alias to the template parameter \c tag */
+ typedef tag Tag;
+
+ SymbolExpr() {}
+
+ /** Associate the value \a val to the given symbol \c *this, uniquely identified by its \c Tag.
+ *
+ * The returned object should be passed to BaseExpr::eval() to evaluate a given expression with this specified runtime value.
+ */
+ SymbolValue<Tag> operator=(Index val) const { // const: builds a SymbolValue, does not modify *this
+ return SymbolValue<Tag>(val);
+ }
+
+ Index eval_impl(const SymbolValue<Tag> &values) const { return values.value(); } // single-symbol evaluation
+
+#if EIGEN_HAS_CXX14
+ // C++14 version for multiple symbols: picks this symbol's SymbolValue out of the tuple by type
+ template<typename... Types>
+ Index eval_impl(const std::tuple<Types...>& values) const { return std::get<SymbolValue<Tag> >(values).value(); }
+#endif
+};
+
+template<typename Arg0>
+class NegateExpr : public BaseExpr<NegateExpr<Arg0> > // symbolic -arg0
+{
+public:
+ NegateExpr(const Arg0& arg0) : m_arg0(arg0) {}
+
+ template<typename T>
+ Index eval_impl(const T& values) const { return -m_arg0.eval_impl(values); }
+protected:
+ Arg0 m_arg0; // operand stored by value
+};
+
+template<typename Arg0, typename Arg1>
+class AddExpr : public BaseExpr<AddExpr<Arg0,Arg1> > // symbolic arg0 + arg1 (BaseExpr also expresses subtraction with it)
+{
+public:
+ AddExpr(const Arg0& arg0, const Arg1& arg1) : m_arg0(arg0), m_arg1(arg1) {}
+
+ template<typename T>
+ Index eval_impl(const T& values) const { return m_arg0.eval_impl(values) + m_arg1.eval_impl(values); }
+protected:
+ Arg0 m_arg0; // operands stored by value
+ Arg1 m_arg1;
+};
+
+template<typename Arg0, typename Arg1>
+class ProductExpr : public BaseExpr<ProductExpr<Arg0,Arg1> > // symbolic arg0 * arg1
+{
+public:
+ ProductExpr(const Arg0& arg0, const Arg1& arg1) : m_arg0(arg0), m_arg1(arg1) {}
+
+ template<typename T>
+ Index eval_impl(const T& values) const { return m_arg0.eval_impl(values) * m_arg1.eval_impl(values); }
+protected:
+ Arg0 m_arg0; // operands stored by value
+ Arg1 m_arg1;
+};
+
+template<typename Arg0, typename Arg1>
+class QuotientExpr : public BaseExpr<QuotientExpr<Arg0,Arg1> > // symbolic arg0 / arg1
+{
+public:
+ QuotientExpr(const Arg0& arg0, const Arg1& arg1) : m_arg0(arg0), m_arg1(arg1) {}
+
+ template<typename T>
+ Index eval_impl(const T& values) const { return m_arg0.eval_impl(values) / m_arg1.eval_impl(values); } // NOTE: no guard against a zero divisor
+protected:
+ Arg0 m_arg0; // operands stored by value
+ Arg1 m_arg1;
+};
+
+} // end namespace symbolic
+
+} // end namespace Eigen
+
+#endif // EIGEN_SYMBOLIC_INDEX_H
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/XprHelper.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/XprHelper.h
index bf424a0..71c32b8 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/XprHelper.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Core/util/XprHelper.h
@@ -34,6 +34,26 @@
return IndexDest(idx);
}
+// true if T can be considered as an integral index (i.e., an integral type or enum)
+template<typename T> struct is_valid_index_type
+{
+ enum { value =
+#if EIGEN_HAS_TYPE_TRAITS
+ internal::is_integral<T>::value || std::is_enum<T>::value
+#elif EIGEN_COMP_MSVC // MSVC: use the compiler intrinsic __is_enum
+ internal::is_integral<T>::value || __is_enum(T)
+#else
+ // without C++11, we use is_convertible to Index instead of is_integral in order to treat enums as Index; float and double are excluded explicitly (NOTE(review): long double is not).
+ internal::is_convertible<T,Index>::value && !internal::is_same<T,float>::value && !is_same<T,double>::value
+#endif
+ };
+};
+
+// true unless both types are valid index types, i.e. at least one of RowIndices/ColIndices is not a plain integral/enum index
+template<typename RowIndices, typename ColIndices>
+struct valid_indexed_view_overload {
+ enum { value = !(internal::is_valid_index_type<RowIndices>::value && internal::is_valid_index_type<ColIndices>::value) };
+};
// promote_scalar_arg is an helper used in operation between an expression and a scalar, like:
// expression * scalar
@@ -109,19 +129,23 @@
template<typename T, int Value> class variable_if_dynamic
{
public:
- EIGEN_EMPTY_STRUCT_CTOR(variable_if_dynamic)
+ EIGEN_DEFAULT_EMPTY_CONSTRUCTOR_AND_DESTRUCTOR(variable_if_dynamic)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE explicit variable_if_dynamic(T v) { EIGEN_ONLY_USED_FOR_DEBUG(v); eigen_assert(v == T(Value)); }
- EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE T value() { return T(Value); }
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void setValue(T) {}
+ EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE EIGEN_CONSTEXPR
+ T value() { return T(Value); }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_CONSTEXPR
+ operator T() const { return T(Value); }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ void setValue(T v) const { EIGEN_ONLY_USED_FOR_DEBUG(v); eigen_assert(v == T(Value)); }
};
template<typename T> class variable_if_dynamic<T, Dynamic>
{
T m_value;
- EIGEN_DEVICE_FUNC variable_if_dynamic() { eigen_assert(false); }
public:
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE explicit variable_if_dynamic(T value) : m_value(value) {}
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE explicit variable_if_dynamic(T value = 0) EIGEN_NO_THROW : m_value(value) {}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T value() const { return m_value; }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE operator T() const { return m_value; }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void setValue(T value) { m_value = value; }
};
@@ -132,8 +156,10 @@
public:
EIGEN_EMPTY_STRUCT_CTOR(variable_if_dynamicindex)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE explicit variable_if_dynamicindex(T v) { EIGEN_ONLY_USED_FOR_DEBUG(v); eigen_assert(v == T(Value)); }
- EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE T value() { return T(Value); }
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void setValue(T) {}
+ EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE EIGEN_CONSTEXPR
+ T value() { return T(Value); }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ void setValue(T) {}
};
template<typename T> class variable_if_dynamicindex<T, DynamicIndex>
@@ -158,16 +184,7 @@
template<typename T> struct packet_traits;
-template<typename T> struct unpacket_traits
-{
- typedef T type;
- typedef T half;
- enum
- {
- size = 1,
- alignment = 1
- };
-};
+template<typename T> struct unpacket_traits;
template<int Size, typename PacketType,
bool Stop = Size==Dynamic || (Size%unpacket_traits<PacketType>::size)==0 || is_same<PacketType,typename unpacket_traits<PacketType>::half>::value>
@@ -386,7 +403,7 @@
typedef Matrix<typename traits<T>::Scalar,
Rows,
Cols,
- (MaxCols==1&&MaxRows!=1) ? RowMajor : ColMajor,
+ (MaxCols==1&&MaxRows!=1) ? ColMajor : RowMajor,
MaxRows,
MaxCols
> type;
@@ -403,7 +420,7 @@
T const&,
const T
>::type type;
-
+
typedef typename conditional<
bool(traits<T>::Flags & NestByRefBit),
T &,
@@ -441,7 +458,7 @@
{
enum {
ScalarReadCost = NumTraits<typename traits<T>::Scalar>::ReadCost,
- CoeffReadCost = evaluator<T>::CoeffReadCost, // NOTE What if an evaluator evaluate itself into a tempory?
+ CoeffReadCost = evaluator<T>::CoeffReadCost, // NOTE What if an evaluator evaluate itself into a temporary?
// Then CoeffReadCost will be small (e.g., 1) but we still have to evaluate, especially if n>1.
// This situation is already taken care by the EvalBeforeNestingBit flag, which is turned ON
// for all evaluator creating a temporary. This flag is then propagated by the parent evaluators.
@@ -582,14 +599,14 @@
struct plain_row_type
{
typedef Matrix<Scalar, 1, ExpressionType::ColsAtCompileTime,
- ExpressionType::PlainObject::Options | RowMajor, 1, ExpressionType::MaxColsAtCompileTime> MatrixRowType;
+ int(ExpressionType::PlainObject::Options) | int(RowMajor), 1, ExpressionType::MaxColsAtCompileTime> MatrixRowType;
typedef Array<Scalar, 1, ExpressionType::ColsAtCompileTime,
- ExpressionType::PlainObject::Options | RowMajor, 1, ExpressionType::MaxColsAtCompileTime> ArrayRowType;
+ int(ExpressionType::PlainObject::Options) | int(RowMajor), 1, ExpressionType::MaxColsAtCompileTime> ArrayRowType;
typedef typename conditional<
is_same< typename traits<ExpressionType>::XprKind, MatrixXpr >::value,
MatrixRowType,
- ArrayRowType
+ ArrayRowType
>::type type;
};
@@ -604,7 +621,7 @@
typedef typename conditional<
is_same< typename traits<ExpressionType>::XprKind, MatrixXpr >::value,
MatrixColType,
- ArrayColType
+ ArrayColType
>::type type;
};
@@ -620,7 +637,7 @@
typedef typename conditional<
is_same< typename traits<ExpressionType>::XprKind, MatrixXpr >::value,
MatrixDiagType,
- ArrayDiagType
+ ArrayDiagType
>::type type;
};
@@ -657,24 +674,39 @@
template<typename T, int S> struct is_diagonal<DiagonalMatrix<T,S> >
{ enum { ret = true }; };
+
+template<typename T> struct is_identity // primary template: false for arbitrary expression types
+{ enum { value = false }; };
+
+template<typename T> struct is_identity<CwiseNullaryOp<internal::scalar_identity_op<typename T::Scalar>, T> > // true for a nullary op built from scalar_identity_op (presumably Identity() expressions — confirm)
+{ enum { value = true }; };
+
+
template<typename S1, typename S2> struct glue_shapes;
template<> struct glue_shapes<DenseShape,TriangularShape> { typedef TriangularShape type; };
template<typename T1, typename T2>
-bool is_same_dense(const T1 &mat1, const T2 &mat2, typename enable_if<has_direct_access<T1>::ret&&has_direct_access<T2>::ret, T1>::type * = 0)
+struct possibly_same_dense {
+ enum { value = has_direct_access<T1>::ret && has_direct_access<T2>::ret && is_same<typename T1::Scalar,typename T2::Scalar>::value };
+};
+
+template<typename T1, typename T2>
+EIGEN_DEVICE_FUNC
+bool is_same_dense(const T1 &mat1, const T2 &mat2, typename enable_if<possibly_same_dense<T1,T2>::value>::type * = 0)
{
return (mat1.data()==mat2.data()) && (mat1.innerStride()==mat2.innerStride()) && (mat1.outerStride()==mat2.outerStride());
}
template<typename T1, typename T2>
-bool is_same_dense(const T1 &, const T2 &, typename enable_if<!(has_direct_access<T1>::ret&&has_direct_access<T2>::ret), T1>::type * = 0)
+EIGEN_DEVICE_FUNC
+bool is_same_dense(const T1 &, const T2 &, typename enable_if<!possibly_same_dense<T1,T2>::value>::type * = 0)
{
return false;
}
// Internal helper defining the cost of a scalar division for the type T.
// The default heuristic can be specialized for each scalar type and architecture.
-template<typename T,bool Vectorized=false,typename EnaleIf = void>
+template<typename T,bool Vectorized=false,typename EnableIf = void>
struct scalar_div_cost {
enum { value = 8*NumTraits<T>::MulCost };
};
@@ -721,7 +753,7 @@
if(f&DirectAccessBit) res += " | Direct";
if(f&NestByRefBit) res += " | NestByRef";
if(f&NoPreferredStorageOrderBit) res += " | NoPreferredStorageOrderBit";
-
+
return res;
}
#endif
@@ -818,7 +850,7 @@
#define EIGEN_CHECK_BINARY_COMPATIBILIY(BINOP,LHS,RHS) \
EIGEN_STATIC_ASSERT((Eigen::internal::has_ReturnType<ScalarBinaryOpTraits<LHS, RHS,BINOP> >::value), \
YOU_MIXED_DIFFERENT_NUMERIC_TYPES__YOU_NEED_TO_USE_THE_CAST_METHOD_OF_MATRIXBASE_TO_CAST_NUMERIC_TYPES_EXPLICITLY)
-
+
} // end namespace Eigen
#endif // EIGEN_XPRHELPER_H
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Eigenvalues/ComplexEigenSolver.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Eigenvalues/ComplexEigenSolver.h
index dc5fae0..081e918 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Eigenvalues/ComplexEigenSolver.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Eigenvalues/ComplexEigenSolver.h
@@ -214,7 +214,7 @@
/** \brief Reports whether previous computation was successful.
*
- * \returns \c Success if computation was succesful, \c NoConvergence otherwise.
+ * \returns \c Success if computation was successful, \c NoConvergence otherwise.
*/
ComputationInfo info() const
{
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Eigenvalues/ComplexSchur.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Eigenvalues/ComplexSchur.h
index 7f38919..fc71468 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Eigenvalues/ComplexSchur.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Eigenvalues/ComplexSchur.h
@@ -212,7 +212,7 @@
/** \brief Reports whether previous computation was successful.
*
- * \returns \c Success if computation was succesful, \c NoConvergence otherwise.
+ * \returns \c Success if computation was successful, \c NoConvergence otherwise.
*/
ComputationInfo info() const
{
@@ -300,10 +300,13 @@
ComplexScalar trace = t.coeff(0,0) + t.coeff(1,1);
ComplexScalar eival1 = (trace + disc) / RealScalar(2);
ComplexScalar eival2 = (trace - disc) / RealScalar(2);
-
- if(numext::norm1(eival1) > numext::norm1(eival2))
+ RealScalar eival1_norm = numext::norm1(eival1);
+ RealScalar eival2_norm = numext::norm1(eival2);
+ // A division by zero can only occur if eival1==eival2==0.
+ // In this case, det==0, and all we have to do is check that eival2_norm!=0
+ if(eival1_norm > eival2_norm)
eival2 = det / eival1;
- else
+ else if(eival2_norm!=RealScalar(0))
eival1 = det / eival2;
// choose the eigenvalue closest to the bottom entry of the diagonal
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Eigenvalues/EigenSolver.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Eigenvalues/EigenSolver.h
index f205b18..572b29e 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Eigenvalues/EigenSolver.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Eigenvalues/EigenSolver.h
@@ -110,7 +110,7 @@
*
* \sa compute() for an example.
*/
- EigenSolver() : m_eivec(), m_eivalues(), m_isInitialized(false), m_realSchur(), m_matT(), m_tmp() {}
+ EigenSolver() : m_eivec(), m_eivalues(), m_isInitialized(false), m_eigenvectorsOk(false), m_realSchur(), m_matT(), m_tmp() {}
/** \brief Default constructor with memory preallocation
*
@@ -277,7 +277,7 @@
template<typename InputType>
EigenSolver& compute(const EigenBase<InputType>& matrix, bool computeEigenvectors = true);
- /** \returns NumericalIssue if the input contains INF or NaN values or overflow occured. Returns Success otherwise. */
+ /** \returns NumericalIssue if the input contains INF or NaN values or overflow occurred. Returns Success otherwise. */
ComputationInfo info() const
{
eigen_assert(m_isInitialized && "EigenSolver is not initialized.");
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Eigenvalues/GeneralizedSelfAdjointEigenSolver.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Eigenvalues/GeneralizedSelfAdjointEigenSolver.h
index 5f6bb82..d0f9091 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Eigenvalues/GeneralizedSelfAdjointEigenSolver.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Eigenvalues/GeneralizedSelfAdjointEigenSolver.h
@@ -121,7 +121,7 @@
*
* \returns Reference to \c *this
*
- * Accoring to \p options, this function computes eigenvalues and (if requested)
+ * According to \p options, this function computes eigenvalues and (if requested)
* the eigenvectors of one of the following three generalized eigenproblems:
* - \c Ax_lBx: \f$ Ax = \lambda B x \f$
* - \c ABx_lx: \f$ ABx = \lambda x \f$
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Eigenvalues/HessenbergDecomposition.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Eigenvalues/HessenbergDecomposition.h
index f647f69..1f21139 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Eigenvalues/HessenbergDecomposition.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Eigenvalues/HessenbergDecomposition.h
@@ -267,7 +267,7 @@
private:
- typedef Matrix<Scalar, 1, Size, Options | RowMajor, 1, MaxSize> VectorType;
+ typedef Matrix<Scalar, 1, Size, int(Options) | int(RowMajor), 1, MaxSize> VectorType;
typedef typename NumTraits<Scalar>::Real RealScalar;
static void _compute(MatrixType& matA, CoeffVectorType& hCoeffs, VectorType& temp);
@@ -315,7 +315,7 @@
// A = A H'
matA.rightCols(remainingSize)
- .applyHouseholderOnTheRight(matA.col(i).tail(remainingSize-1).conjugate(), numext::conj(h), &temp.coeffRef(0));
+ .applyHouseholderOnTheRight(matA.col(i).tail(remainingSize-1), numext::conj(h), &temp.coeffRef(0));
}
}
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Eigenvalues/MatrixBaseEigenvalues.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Eigenvalues/MatrixBaseEigenvalues.h
index e4e4260..66e5a3d 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Eigenvalues/MatrixBaseEigenvalues.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Eigenvalues/MatrixBaseEigenvalues.h
@@ -84,7 +84,7 @@
* \sa SelfAdjointEigenSolver::eigenvalues(), MatrixBase::eigenvalues()
*/
template<typename MatrixType, unsigned int UpLo>
-inline typename SelfAdjointView<MatrixType, UpLo>::EigenvaluesReturnType
+EIGEN_DEVICE_FUNC inline typename SelfAdjointView<MatrixType, UpLo>::EigenvaluesReturnType
SelfAdjointView<MatrixType, UpLo>::eigenvalues() const
{
PlainObject thisAsMatrix(*this);
@@ -147,7 +147,7 @@
* \sa eigenvalues(), MatrixBase::operatorNorm()
*/
template<typename MatrixType, unsigned int UpLo>
-inline typename SelfAdjointView<MatrixType, UpLo>::RealScalar
+EIGEN_DEVICE_FUNC inline typename SelfAdjointView<MatrixType, UpLo>::RealScalar
SelfAdjointView<MatrixType, UpLo>::operatorNorm() const
{
return eigenvalues().cwiseAbs().maxCoeff();
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Eigenvalues/RealQZ.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Eigenvalues/RealQZ.h
index b3a910d..5091301 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Eigenvalues/RealQZ.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Eigenvalues/RealQZ.h
@@ -90,8 +90,9 @@
m_Z(size, size),
m_workspace(size*2),
m_maxIters(400),
- m_isInitialized(false)
- { }
+ m_isInitialized(false),
+ m_computeQZ(true)
+ {}
/** \brief Constructor; computes real QZ decomposition of given matrices
*
@@ -108,9 +109,11 @@
m_Z(A.rows(),A.cols()),
m_workspace(A.rows()*2),
m_maxIters(400),
- m_isInitialized(false) {
- compute(A, B, computeQZ);
- }
+ m_isInitialized(false),
+ m_computeQZ(true)
+ {
+ compute(A, B, computeQZ);
+ }
/** \brief Returns matrix Q in the QZ decomposition.
*
@@ -161,7 +164,7 @@
/** \brief Reports whether previous computation was successful.
*
- * \returns \c Success if computation was succesful, \c NoConvergence otherwise.
+ * \returns \c Success if computation was successful, \c NoConvergence otherwise.
*/
ComputationInfo info() const
{
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Eigenvalues/RealSchur.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Eigenvalues/RealSchur.h
index 17ea903..7304ef3 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Eigenvalues/RealSchur.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Eigenvalues/RealSchur.h
@@ -190,7 +190,7 @@
RealSchur& computeFromHessenberg(const HessMatrixType& matrixH, const OrthMatrixType& matrixQ, bool computeU);
/** \brief Reports whether previous computation was successful.
*
- * \returns \c Success if computation was succesful, \c NoConvergence otherwise.
+ * \returns \c Success if computation was successful, \c NoConvergence otherwise.
*/
ComputationInfo info() const
{
@@ -236,7 +236,7 @@
typedef Matrix<Scalar,3,1> Vector3s;
Scalar computeNormOfT();
- Index findSmallSubdiagEntry(Index iu);
+ Index findSmallSubdiagEntry(Index iu, const Scalar& considerAsZero);
void splitOffTwoRows(Index iu, bool computeU, const Scalar& exshift);
void computeShift(Index iu, Index iter, Scalar& exshift, Vector3s& shiftInfo);
void initFrancisQRStep(Index il, Index iu, const Vector3s& shiftInfo, Index& im, Vector3s& firstHouseholderVector);
@@ -270,8 +270,13 @@
// Step 1. Reduce to Hessenberg form
m_hess.compute(matrix.derived()/scale);
- // Step 2. Reduce to real Schur form
- computeFromHessenberg(m_hess.matrixH(), m_hess.matrixQ(), computeU);
+ // Step 2. Reduce to real Schur form
+ // Note: we copy m_hess.matrixQ() into m_matU here and not in computeFromHessenberg
+ // to be able to pass our working-space buffer for the Householder to Dense evaluation.
+ m_workspaceVector.resize(matrix.cols());
+ if(computeU)
+ m_hess.matrixQ().evalTo(m_matU, m_workspaceVector);
+ computeFromHessenberg(m_hess.matrixH(), m_matU, computeU);
m_matT *= scale;
@@ -284,13 +289,13 @@
using std::abs;
m_matT = matrixH;
- if(computeU)
+ m_workspaceVector.resize(m_matT.cols());
+ if(computeU && !internal::is_same_dense(m_matU,matrixQ))
m_matU = matrixQ;
Index maxIters = m_maxIters;
if (maxIters == -1)
maxIters = m_maxIterationsPerRow * matrixH.rows();
- m_workspaceVector.resize(m_matT.cols());
Scalar* workspace = &m_workspaceVector.coeffRef(0);
// The matrix m_matT is divided in three parts.
@@ -302,12 +307,16 @@
Index totalIter = 0; // iteration count for whole matrix
Scalar exshift(0); // sum of exceptional shifts
Scalar norm = computeNormOfT();
+ // sub-diagonal entries smaller than considerAsZero will be treated as zero.
+ // We use eps^2 to enable more precision in small eigenvalues.
+ Scalar considerAsZero = numext::maxi<Scalar>( norm * numext::abs2(NumTraits<Scalar>::epsilon()),
+ (std::numeric_limits<Scalar>::min)() );
if(norm!=Scalar(0))
{
while (iu >= 0)
{
- Index il = findSmallSubdiagEntry(iu);
+ Index il = findSmallSubdiagEntry(iu,considerAsZero);
// Check for convergence
if (il == iu) // One root found
@@ -364,14 +373,17 @@
/** \internal Look for single small sub-diagonal element and returns its index */
template<typename MatrixType>
-inline Index RealSchur<MatrixType>::findSmallSubdiagEntry(Index iu)
+inline Index RealSchur<MatrixType>::findSmallSubdiagEntry(Index iu, const Scalar& considerAsZero)
{
using std::abs;
Index res = iu;
while (res > 0)
{
Scalar s = abs(m_matT.coeff(res-1,res-1)) + abs(m_matT.coeff(res,res));
- if (abs(m_matT.coeff(res,res-1)) <= NumTraits<Scalar>::epsilon() * s)
+
+ s = numext::maxi<Scalar>(s * NumTraits<Scalar>::epsilon(), considerAsZero);
+
+ if (abs(m_matT.coeff(res,res-1)) <= s)
break;
res--;
}
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Eigenvalues/SelfAdjointEigenSolver.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Eigenvalues/SelfAdjointEigenSolver.h
index 9ddd553..1469236 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Eigenvalues/SelfAdjointEigenSolver.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Eigenvalues/SelfAdjointEigenSolver.h
@@ -20,7 +20,9 @@
namespace internal {
template<typename SolverType,int Size,bool IsComplex> struct direct_selfadjoint_eigenvalues;
+
template<typename MatrixType, typename DiagType, typename SubDiagType>
+EIGEN_DEVICE_FUNC
ComputationInfo computeFromTridiagonal_impl(DiagType& diag, SubDiagType& subdiag, const Index maxIterations, bool computeEigenvectors, MatrixType& eivec);
}
@@ -42,10 +44,14 @@
* \f$ v \f$ such that \f$ Av = \lambda v \f$. The eigenvalues of a
* selfadjoint matrix are always real. If \f$ D \f$ is a diagonal matrix with
* the eigenvalues on the diagonal, and \f$ V \f$ is a matrix with the
- * eigenvectors as its columns, then \f$ A = V D V^{-1} \f$ (for selfadjoint
- * matrices, the matrix \f$ V \f$ is always invertible). This is called the
+ * eigenvectors as its columns, then \f$ A = V D V^{-1} \f$. This is called the
* eigendecomposition.
*
+ * For a selfadjoint matrix, \f$ V \f$ is unitary, meaning its inverse is equal
+ * to its adjoint, \f$ V^{-1} = V^{\dagger} \f$. If \f$ A \f$ is real, then
+ * \f$ V \f$ is also real and therefore orthogonal, meaning its inverse is
+ * equal to its transpose, \f$ V^{-1} = V^T \f$.
+ *
* The algorithm exploits the fact that the matrix is selfadjoint, making it
* faster and more accurate than the general purpose eigenvalue algorithms
* implemented in EigenSolver and ComplexEigenSolver.
@@ -119,7 +125,10 @@
: m_eivec(),
m_eivalues(),
m_subdiag(),
- m_isInitialized(false)
+ m_hcoeffs(),
+ m_info(InvalidInput),
+ m_isInitialized(false),
+ m_eigenvectorsOk(false)
{ }
/** \brief Constructor, pre-allocates memory for dynamic-size matrices.
@@ -139,7 +148,9 @@
: m_eivec(size, size),
m_eivalues(size),
m_subdiag(size > 1 ? size - 1 : 1),
- m_isInitialized(false)
+ m_hcoeffs(size > 1 ? size - 1 : 1),
+ m_isInitialized(false),
+ m_eigenvectorsOk(false)
{}
/** \brief Constructor; computes eigendecomposition of given matrix.
@@ -163,7 +174,9 @@
: m_eivec(matrix.rows(), matrix.cols()),
m_eivalues(matrix.cols()),
m_subdiag(matrix.rows() > 1 ? matrix.rows() - 1 : 1),
- m_isInitialized(false)
+ m_hcoeffs(matrix.cols() > 1 ? matrix.cols() - 1 : 1),
+ m_isInitialized(false),
+ m_eigenvectorsOk(false)
{
compute(matrix.derived(), options);
}
@@ -250,6 +263,11 @@
* matrix \f$ A \f$, then the matrix returned by this function is the
* matrix \f$ V \f$ in the eigendecomposition \f$ A = V D V^{-1} \f$.
*
+ * For a selfadjoint matrix, \f$ V \f$ is unitary, meaning its inverse is equal
+ * to its adjoint, \f$ V^{-1} = V^{\dagger} \f$. If \f$ A \f$ is real, then
+ * \f$ V \f$ is also real and therefore orthogonal, meaning its inverse is
+ * equal to its transpose, \f$ V^{-1} = V^T \f$.
+ *
* Example: \include SelfAdjointEigenSolver_eigenvectors.cpp
* Output: \verbinclude SelfAdjointEigenSolver_eigenvectors.out
*
@@ -337,7 +355,7 @@
/** \brief Reports whether previous computation was successful.
*
- * \returns \c Success if computation was succesful, \c NoConvergence otherwise.
+ * \returns \c Success if computation was successful, \c NoConvergence otherwise.
*/
EIGEN_DEVICE_FUNC
ComputationInfo info() const
@@ -354,7 +372,8 @@
static const int m_maxIterations = 30;
protected:
- static void check_template_parameters()
+ static EIGEN_DEVICE_FUNC
+ void check_template_parameters()
{
EIGEN_STATIC_ASSERT_NON_INTEGER(Scalar);
}
@@ -362,6 +381,7 @@
EigenvectorsType m_eivec;
RealVectorType m_eivalues;
typename TridiagonalizationType::SubDiagonalType m_subdiag;
+ typename TridiagonalizationType::CoeffVectorType m_hcoeffs;
ComputationInfo m_info;
bool m_isInitialized;
bool m_eigenvectorsOk;
@@ -403,7 +423,7 @@
const InputType &matrix(a_matrix.derived());
- using std::abs;
+ EIGEN_USING_STD(abs);
eigen_assert(matrix.cols() == matrix.rows());
eigen_assert((options&~(EigVecMask|GenEigMask))==0
&& (options&EigVecMask)!=EigVecMask
@@ -434,7 +454,8 @@
if(scale==RealScalar(0)) scale = RealScalar(1);
mat.template triangularView<Lower>() /= scale;
m_subdiag.resize(n-1);
- internal::tridiagonalization_inplace(mat, diag, m_subdiag, computeEigenvectors);
+ m_hcoeffs.resize(n-1);
+ internal::tridiagonalization_inplace(mat, diag, m_subdiag, m_hcoeffs, computeEigenvectors);
m_info = internal::computeFromTridiagonal_impl(diag, m_subdiag, m_maxIterations, computeEigenvectors, m_eivec);
@@ -479,10 +500,9 @@
* \returns \c Success or \c NoConvergence
*/
template<typename MatrixType, typename DiagType, typename SubDiagType>
+EIGEN_DEVICE_FUNC
ComputationInfo computeFromTridiagonal_impl(DiagType& diag, SubDiagType& subdiag, const Index maxIterations, bool computeEigenvectors, MatrixType& eivec)
{
- using std::abs;
-
ComputationInfo info;
typedef typename MatrixType::Scalar Scalar;
@@ -493,15 +513,23 @@
typedef typename DiagType::RealScalar RealScalar;
const RealScalar considerAsZero = (std::numeric_limits<RealScalar>::min)();
- const RealScalar precision = RealScalar(2)*NumTraits<RealScalar>::epsilon();
-
+ const RealScalar precision_inv = RealScalar(1)/NumTraits<RealScalar>::epsilon();
while (end>0)
{
- for (Index i = start; i<end; ++i)
- if (internal::isMuchSmallerThan(abs(subdiag[i]),(abs(diag[i])+abs(diag[i+1])),precision) || abs(subdiag[i]) <= considerAsZero)
- subdiag[i] = 0;
+ for (Index i = start; i<end; ++i) {
+ if (numext::abs(subdiag[i]) < considerAsZero) {
+ subdiag[i] = RealScalar(0);
+ } else {
+ // abs(subdiag[i]) <= epsilon * sqrt(abs(diag[i]) + abs(diag[i+1]))
+ // Scaled to prevent underflows.
+ const RealScalar scaled_subdiag = precision_inv * subdiag[i];
+ if (scaled_subdiag * scaled_subdiag <= (numext::abs(diag[i])+numext::abs(diag[i+1]))) {
+ subdiag[i] = RealScalar(0);
+ }
+ }
+ }
- // find the largest unreduced block
+ // find the largest unreduced block at the end of the matrix.
while (end>0 && subdiag[end-1]==RealScalar(0))
{
end--;
@@ -535,7 +563,7 @@
diag.segment(i,n-i).minCoeff(&k);
if (k > 0)
{
- std::swap(diag[i], diag[k+i]);
+ numext::swap(diag[i], diag[k+i]);
if(computeEigenvectors)
eivec.col(i).swap(eivec.col(k+i));
}
@@ -566,10 +594,10 @@
EIGEN_DEVICE_FUNC
static inline void computeRoots(const MatrixType& m, VectorType& roots)
{
- EIGEN_USING_STD_MATH(sqrt)
- EIGEN_USING_STD_MATH(atan2)
- EIGEN_USING_STD_MATH(cos)
- EIGEN_USING_STD_MATH(sin)
+ EIGEN_USING_STD(sqrt)
+ EIGEN_USING_STD(atan2)
+ EIGEN_USING_STD(cos)
+ EIGEN_USING_STD(sin)
const Scalar s_inv3 = Scalar(1)/Scalar(3);
const Scalar s_sqrt3 = sqrt(Scalar(3));
@@ -605,7 +633,8 @@
EIGEN_DEVICE_FUNC
static inline bool extract_kernel(MatrixType& mat, Ref<VectorType> res, Ref<VectorType> representative)
{
- using std::abs;
+ EIGEN_USING_STD(abs);
+ EIGEN_USING_STD(sqrt);
Index i0;
// Find non-zero column i0 (by construction, there must exist a non zero coefficient on the diagonal):
mat.diagonal().cwiseAbs().maxCoeff(&i0);
@@ -616,8 +645,8 @@
VectorType c0, c1;
n0 = (c0 = representative.cross(mat.col((i0+1)%3))).squaredNorm();
n1 = (c1 = representative.cross(mat.col((i0+2)%3))).squaredNorm();
- if(n0>n1) res = c0/std::sqrt(n0);
- else res = c1/std::sqrt(n1);
+ if(n0>n1) res = c0/sqrt(n0);
+ else res = c1/sqrt(n1);
return true;
}
@@ -719,7 +748,7 @@
EIGEN_DEVICE_FUNC
static inline void computeRoots(const MatrixType& m, VectorType& roots)
{
- using std::sqrt;
+ EIGEN_USING_STD(sqrt);
const Scalar t0 = Scalar(0.5) * sqrt( numext::abs2(m(0,0)-m(1,1)) + Scalar(4)*numext::abs2(m(1,0)));
const Scalar t1 = Scalar(0.5) * (m(0,0) + m(1,1));
roots(0) = t1 - t0;
@@ -729,8 +758,8 @@
EIGEN_DEVICE_FUNC
static inline void run(SolverType& solver, const MatrixType& mat, int options)
{
- EIGEN_USING_STD_MATH(sqrt);
- EIGEN_USING_STD_MATH(abs);
+ EIGEN_USING_STD(sqrt);
+ EIGEN_USING_STD(abs);
eigen_assert(mat.cols() == 2 && mat.cols() == mat.rows());
eigen_assert((options&~(EigVecMask|GenEigMask))==0
@@ -803,32 +832,38 @@
}
namespace internal {
+
+// Francis implicit QR step.
template<int StorageOrder,typename RealScalar, typename Scalar, typename Index>
EIGEN_DEVICE_FUNC
static void tridiagonal_qr_step(RealScalar* diag, RealScalar* subdiag, Index start, Index end, Scalar* matrixQ, Index n)
{
- using std::abs;
+ // Wilkinson Shift.
RealScalar td = (diag[end-1] - diag[end])*RealScalar(0.5);
RealScalar e = subdiag[end-1];
// Note that thanks to scaling, e^2 or td^2 cannot overflow, however they can still
// underflow thus leading to inf/NaN values when using the following commented code:
-// RealScalar e2 = numext::abs2(subdiag[end-1]);
-// RealScalar mu = diag[end] - e2 / (td + (td>0 ? 1 : -1) * sqrt(td*td + e2));
+ // RealScalar e2 = numext::abs2(subdiag[end-1]);
+ // RealScalar mu = diag[end] - e2 / (td + (td>0 ? 1 : -1) * sqrt(td*td + e2));
// This explain the following, somewhat more complicated, version:
RealScalar mu = diag[end];
- if(td==RealScalar(0))
- mu -= abs(e);
- else
- {
- RealScalar e2 = numext::abs2(subdiag[end-1]);
- RealScalar h = numext::hypot(td,e);
- if(e2==RealScalar(0)) mu -= (e / (td + (td>RealScalar(0) ? RealScalar(1) : RealScalar(-1)))) * (e / h);
- else mu -= e2 / (td + (td>RealScalar(0) ? h : -h));
+ if(td==RealScalar(0)) {
+ mu -= numext::abs(e);
+ } else if (e != RealScalar(0)) {
+ const RealScalar e2 = numext::abs2(e);
+ const RealScalar h = numext::hypot(td,e);
+ if(e2 == RealScalar(0)) {
+ mu -= e / ((td + (td>RealScalar(0) ? h : -h)) / e);
+ } else {
+ mu -= e2 / (td + (td>RealScalar(0) ? h : -h));
+ }
}
-
+
RealScalar x = diag[start] - mu;
RealScalar z = subdiag[start];
- for (Index k = start; k < end; ++k)
+ // If z ever becomes zero, the Givens rotation will be the identity and
+ // z will stay zero for all future iterations.
+ for (Index k = start; k < end && z != RealScalar(0); ++k)
{
JacobiRotation<RealScalar> rot;
rot.makeGivens(x, z);
@@ -841,12 +876,11 @@
diag[k+1] = rot.s() * sdk + rot.c() * dkp1;
subdiag[k] = rot.c() * sdk - rot.s() * dkp1;
-
if (k > start)
subdiag[k - 1] = rot.c() * subdiag[k-1] - rot.s() * z;
+ // "Chasing the bulge" to return to triangular form.
x = subdiag[k];
-
if (k < end - 1)
{
z = -rot.s() * subdiag[k+1];
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Eigenvalues/Tridiagonalization.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Eigenvalues/Tridiagonalization.h
index 1d102c1..674c92a 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Eigenvalues/Tridiagonalization.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Eigenvalues/Tridiagonalization.h
@@ -11,10 +11,10 @@
#ifndef EIGEN_TRIDIAGONALIZATION_H
#define EIGEN_TRIDIAGONALIZATION_H
-namespace Eigen {
+namespace Eigen {
namespace internal {
-
+
template<typename MatrixType> struct TridiagonalizationMatrixTReturnType;
template<typename MatrixType>
struct traits<TridiagonalizationMatrixTReturnType<MatrixType> >
@@ -25,6 +25,7 @@
};
template<typename MatrixType, typename CoeffVectorType>
+EIGEN_DEVICE_FUNC
void tridiagonalization_inplace(MatrixType& matA, CoeffVectorType& hCoeffs);
}
@@ -344,6 +345,7 @@
* \sa Tridiagonalization::packedMatrix()
*/
template<typename MatrixType, typename CoeffVectorType>
+EIGEN_DEVICE_FUNC
void tridiagonalization_inplace(MatrixType& matA, CoeffVectorType& hCoeffs)
{
using numext::conj;
@@ -352,7 +354,7 @@
Index n = matA.rows();
eigen_assert(n==matA.cols());
eigen_assert(n==hCoeffs.size()+1 || n==1);
-
+
for (Index i = 0; i<n-1; ++i)
{
Index remainingSize = n-i-1;
@@ -423,11 +425,13 @@
*
* \sa class Tridiagonalization
*/
-template<typename MatrixType, typename DiagonalType, typename SubDiagonalType>
-void tridiagonalization_inplace(MatrixType& mat, DiagonalType& diag, SubDiagonalType& subdiag, bool extractQ)
+template<typename MatrixType, typename DiagonalType, typename SubDiagonalType, typename CoeffVectorType>
+EIGEN_DEVICE_FUNC
+void tridiagonalization_inplace(MatrixType& mat, DiagonalType& diag, SubDiagonalType& subdiag,
+ CoeffVectorType& hcoeffs, bool extractQ)
{
eigen_assert(mat.cols()==mat.rows() && diag.size()==mat.rows() && subdiag.size()==mat.rows()-1);
- tridiagonalization_inplace_selector<MatrixType>::run(mat, diag, subdiag, extractQ);
+ tridiagonalization_inplace_selector<MatrixType>::run(mat, diag, subdiag, hcoeffs, extractQ);
}
/** \internal
@@ -439,10 +443,10 @@
typedef typename Tridiagonalization<MatrixType>::CoeffVectorType CoeffVectorType;
typedef typename Tridiagonalization<MatrixType>::HouseholderSequenceType HouseholderSequenceType;
template<typename DiagonalType, typename SubDiagonalType>
- static void run(MatrixType& mat, DiagonalType& diag, SubDiagonalType& subdiag, bool extractQ)
+ static EIGEN_DEVICE_FUNC
+ void run(MatrixType& mat, DiagonalType& diag, SubDiagonalType& subdiag, CoeffVectorType& hCoeffs, bool extractQ)
{
- CoeffVectorType hCoeffs(mat.cols()-1);
- tridiagonalization_inplace(mat,hCoeffs);
+ tridiagonalization_inplace(mat, hCoeffs);
diag = mat.diagonal().real();
subdiag = mat.template diagonal<-1>().real();
if(extractQ)
@@ -462,8 +466,8 @@
typedef typename MatrixType::Scalar Scalar;
typedef typename MatrixType::RealScalar RealScalar;
- template<typename DiagonalType, typename SubDiagonalType>
- static void run(MatrixType& mat, DiagonalType& diag, SubDiagonalType& subdiag, bool extractQ)
+ template<typename DiagonalType, typename SubDiagonalType, typename CoeffVectorType>
+ static void run(MatrixType& mat, DiagonalType& diag, SubDiagonalType& subdiag, CoeffVectorType&, bool extractQ)
{
using std::sqrt;
const RealScalar tol = (std::numeric_limits<RealScalar>::min)();
@@ -507,8 +511,9 @@
{
typedef typename MatrixType::Scalar Scalar;
- template<typename DiagonalType, typename SubDiagonalType>
- static void run(MatrixType& mat, DiagonalType& diag, SubDiagonalType&, bool extractQ)
+ template<typename DiagonalType, typename SubDiagonalType, typename CoeffVectorType>
+ static EIGEN_DEVICE_FUNC
+ void run(MatrixType& mat, DiagonalType& diag, SubDiagonalType&, CoeffVectorType&, bool extractQ)
{
diag(0,0) = numext::real(mat(0,0));
if(extractQ)
@@ -542,8 +547,8 @@
result.template diagonal<-1>() = m_matrix.template diagonal<-1>();
}
- Index rows() const { return m_matrix.rows(); }
- Index cols() const { return m_matrix.cols(); }
+ EIGEN_CONSTEXPR Index rows() const EIGEN_NOEXCEPT { return m_matrix.rows(); }
+ EIGEN_CONSTEXPR Index cols() const EIGEN_NOEXCEPT { return m_matrix.cols(); }
protected:
typename MatrixType::Nested m_matrix;
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Householder/BlockHouseholder.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Householder/BlockHouseholder.h
index 01a7ed1..39ce1c2 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Householder/BlockHouseholder.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Householder/BlockHouseholder.h
@@ -63,8 +63,15 @@
triFactor.row(i).tail(rt).noalias() = -hCoeffs(i) * vectors.col(i).tail(rs).adjoint()
* vectors.bottomRightCorner(rs, rt).template triangularView<UnitLower>();
- // FIXME add .noalias() once the triangular product can work inplace
- triFactor.row(i).tail(rt) = triFactor.row(i).tail(rt) * triFactor.bottomRightCorner(rt,rt).template triangularView<Upper>();
+ // FIXME use the following line with .noalias() once the triangular product can work inplace
+ // triFactor.row(i).tail(rt) = triFactor.row(i).tail(rt) * triFactor.bottomRightCorner(rt,rt).template triangularView<Upper>();
+ for(Index j=nbVecs-1; j>i; --j)
+ {
+ typename TriangularFactorType::Scalar z = triFactor(i,j);
+ triFactor(i,j) = z * triFactor(j,j);
+ if(nbVecs-j-1>0)
+ triFactor.row(i).tail(nbVecs-j-1) += z * triFactor.row(j).tail(nbVecs-j-1);
+ }
}
triFactor(i,i) = hCoeffs(i);
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Householder/Householder.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Householder/Householder.h
index 80de2c3..5bc037f 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Householder/Householder.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Householder/Householder.h
@@ -39,6 +39,7 @@
* MatrixBase::applyHouseholderOnTheRight()
*/
template<typename Derived>
+EIGEN_DEVICE_FUNC
void MatrixBase<Derived>::makeHouseholderInPlace(Scalar& tau, RealScalar& beta)
{
VectorBlock<Derived, internal::decrement_size<Base::SizeAtCompileTime>::ret> essentialPart(derived(), 1, size()-1);
@@ -62,6 +63,7 @@
*/
template<typename Derived>
template<typename EssentialPart>
+EIGEN_DEVICE_FUNC
void MatrixBase<Derived>::makeHouseholder(
EssentialPart& essential,
Scalar& tau,
@@ -103,13 +105,14 @@
* \param essential the essential part of the vector \c v
* \param tau the scaling factor of the Householder transformation
* \param workspace a pointer to working space with at least
- * this->cols() * essential.size() entries
+ * this->cols() entries
*
* \sa MatrixBase::makeHouseholder(), MatrixBase::makeHouseholderInPlace(),
* MatrixBase::applyHouseholderOnTheRight()
*/
template<typename Derived>
template<typename EssentialPart>
+EIGEN_DEVICE_FUNC
void MatrixBase<Derived>::applyHouseholderOnTheLeft(
const EssentialPart& essential,
const Scalar& tau,
@@ -140,13 +143,14 @@
* \param essential the essential part of the vector \c v
* \param tau the scaling factor of the Householder transformation
* \param workspace a pointer to working space with at least
- * this->cols() * essential.size() entries
+ * this->rows() entries
*
* \sa MatrixBase::makeHouseholder(), MatrixBase::makeHouseholderInPlace(),
* MatrixBase::applyHouseholderOnTheLeft()
*/
template<typename Derived>
template<typename EssentialPart>
+EIGEN_DEVICE_FUNC
void MatrixBase<Derived>::applyHouseholderOnTheRight(
const EssentialPart& essential,
const Scalar& tau,
@@ -160,10 +164,10 @@
{
Map<typename internal::plain_col_type<PlainObject>::type> tmp(workspace,rows());
Block<Derived, Derived::RowsAtCompileTime, EssentialPart::SizeAtCompileTime> right(derived(), 0, 1, rows(), cols()-1);
- tmp.noalias() = right * essential.conjugate();
+ tmp.noalias() = right * essential;
tmp += this->col(0);
this->col(0) -= tau * tmp;
- right.noalias() -= tau * tmp * essential.transpose();
+ right.noalias() -= tau * tmp * essential.adjoint();
}
}
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Householder/HouseholderSequence.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Householder/HouseholderSequence.h
index 3ce0a69..022f6c3 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Householder/HouseholderSequence.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Householder/HouseholderSequence.h
@@ -11,7 +11,7 @@
#ifndef EIGEN_HOUSEHOLDER_SEQUENCE_H
#define EIGEN_HOUSEHOLDER_SEQUENCE_H
-namespace Eigen {
+namespace Eigen {
/** \ingroup Householder_Module
* \householder_module
@@ -34,8 +34,8 @@
* form \f$ H = \prod_{i=0}^{n-1} H_i \f$ where the i-th Householder reflection is \f$ H_i = I - h_i v_i
* v_i^* \f$. The i-th Householder coefficient \f$ h_i \f$ is a scalar and the i-th Householder vector \f$
* v_i \f$ is a vector of the form
- * \f[
- * v_i = [\underbrace{0, \ldots, 0}_{i-1\mbox{ zeros}}, 1, \underbrace{*, \ldots,*}_{n-i\mbox{ arbitrary entries}} ].
+ * \f[
+ * v_i = [\underbrace{0, \ldots, 0}_{i-1\mbox{ zeros}}, 1, \underbrace{*, \ldots,*}_{n-i\mbox{ arbitrary entries}} ].
* \f]
* The last \f$ n-i \f$ entries of \f$ v_i \f$ are called the essential part of the Householder vector.
*
@@ -87,7 +87,7 @@
{
typedef Block<const VectorsType, Dynamic, 1> EssentialVectorType;
typedef HouseholderSequence<VectorsType, CoeffsType, OnTheLeft> HouseholderSequenceType;
- static inline const EssentialVectorType essentialVector(const HouseholderSequenceType& h, Index k)
+ static EIGEN_DEVICE_FUNC inline const EssentialVectorType essentialVector(const HouseholderSequenceType& h, Index k)
{
Index start = k+1+h.m_shift;
return Block<const VectorsType,Dynamic,1>(h.m_vectors, start, k, h.rows()-start, 1);
@@ -120,7 +120,7 @@
: public EigenBase<HouseholderSequence<VectorsType,CoeffsType,Side> >
{
typedef typename internal::hseq_side_dependent_impl<VectorsType,CoeffsType,Side>::EssentialVectorType EssentialVectorType;
-
+
public:
enum {
RowsAtCompileTime = internal::traits<HouseholderSequence>::RowsAtCompileTime,
@@ -140,6 +140,28 @@
Side
> ConjugateReturnType;
+ typedef HouseholderSequence<
+ VectorsType,
+ typename internal::conditional<NumTraits<Scalar>::IsComplex,
+ typename internal::remove_all<typename CoeffsType::ConjugateReturnType>::type,
+ CoeffsType>::type,
+ Side
+ > AdjointReturnType;
+
+ typedef HouseholderSequence<
+ typename internal::conditional<NumTraits<Scalar>::IsComplex,
+ typename internal::remove_all<typename VectorsType::ConjugateReturnType>::type,
+ VectorsType>::type,
+ CoeffsType,
+ Side
+ > TransposeReturnType;
+
+ typedef HouseholderSequence<
+ typename internal::add_const<VectorsType>::type,
+ typename internal::add_const<CoeffsType>::type,
+ Side
+ > ConstHouseholderSequence;
+
/** \brief Constructor.
* \param[in] v %Matrix containing the essential parts of the Householder vectors
* \param[in] h Vector containing the Householder coefficients
@@ -157,33 +179,37 @@
*
* \sa setLength(), setShift()
*/
+ EIGEN_DEVICE_FUNC
HouseholderSequence(const VectorsType& v, const CoeffsType& h)
- : m_vectors(v), m_coeffs(h), m_trans(false), m_length(v.diagonalSize()),
+ : m_vectors(v), m_coeffs(h), m_reverse(false), m_length(v.diagonalSize()),
m_shift(0)
{
}
/** \brief Copy constructor. */
+ EIGEN_DEVICE_FUNC
HouseholderSequence(const HouseholderSequence& other)
: m_vectors(other.m_vectors),
m_coeffs(other.m_coeffs),
- m_trans(other.m_trans),
+ m_reverse(other.m_reverse),
m_length(other.m_length),
m_shift(other.m_shift)
{
}
/** \brief Number of rows of transformation viewed as a matrix.
- * \returns Number of rows
+ * \returns Number of rows
* \details This equals the dimension of the space that the transformation acts on.
*/
- Index rows() const { return Side==OnTheLeft ? m_vectors.rows() : m_vectors.cols(); }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ Index rows() const EIGEN_NOEXCEPT { return Side==OnTheLeft ? m_vectors.rows() : m_vectors.cols(); }
/** \brief Number of columns of transformation viewed as a matrix.
* \returns Number of columns
* \details This equals the dimension of the space that the transformation acts on.
*/
- Index cols() const { return rows(); }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ Index cols() const EIGEN_NOEXCEPT { return rows(); }
/** \brief Essential part of a Householder vector.
* \param[in] k Index of Householder reflection
@@ -191,14 +217,15 @@
*
* This function returns the essential part of the Householder vector \f$ v_i \f$. This is a vector of
* length \f$ n-i \f$ containing the last \f$ n-i \f$ entries of the vector
- * \f[
- * v_i = [\underbrace{0, \ldots, 0}_{i-1\mbox{ zeros}}, 1, \underbrace{*, \ldots,*}_{n-i\mbox{ arbitrary entries}} ].
+ * \f[
+ * v_i = [\underbrace{0, \ldots, 0}_{i-1\mbox{ zeros}}, 1, \underbrace{*, \ldots,*}_{n-i\mbox{ arbitrary entries}} ].
* \f]
* The index \f$ i \f$ equals \p k + shift(), corresponding to the k-th column of the matrix \p v
* passed to the constructor.
*
* \sa setShift(), shift()
*/
+ EIGEN_DEVICE_FUNC
const EssentialVectorType essentialVector(Index k) const
{
eigen_assert(k >= 0 && k < m_length);
@@ -206,31 +233,51 @@
}
/** \brief %Transpose of the Householder sequence. */
- HouseholderSequence transpose() const
+ TransposeReturnType transpose() const
{
- return HouseholderSequence(*this).setTrans(!m_trans);
+ return TransposeReturnType(m_vectors.conjugate(), m_coeffs)
+ .setReverseFlag(!m_reverse)
+ .setLength(m_length)
+ .setShift(m_shift);
}
/** \brief Complex conjugate of the Householder sequence. */
ConjugateReturnType conjugate() const
{
return ConjugateReturnType(m_vectors.conjugate(), m_coeffs.conjugate())
- .setTrans(m_trans)
+ .setReverseFlag(m_reverse)
.setLength(m_length)
.setShift(m_shift);
}
- /** \brief Adjoint (conjugate transpose) of the Householder sequence. */
- ConjugateReturnType adjoint() const
+ /** \returns an expression of the complex conjugate of \c *this if Cond==true,
+ * returns \c *this otherwise.
+ */
+ template<bool Cond>
+ EIGEN_DEVICE_FUNC
+ inline typename internal::conditional<Cond,ConjugateReturnType,ConstHouseholderSequence>::type
+ conjugateIf() const
{
- return conjugate().setTrans(!m_trans);
+ typedef typename internal::conditional<Cond,ConjugateReturnType,ConstHouseholderSequence>::type ReturnType;
+ return ReturnType(m_vectors.template conjugateIf<Cond>(), m_coeffs.template conjugateIf<Cond>());
+ }
+
+ /** \brief Adjoint (conjugate transpose) of the Householder sequence. */
+ AdjointReturnType adjoint() const
+ {
+ return AdjointReturnType(m_vectors, m_coeffs.conjugate())
+ .setReverseFlag(!m_reverse)
+ .setLength(m_length)
+ .setShift(m_shift);
}
/** \brief Inverse of the Householder sequence (equals the adjoint). */
- ConjugateReturnType inverse() const { return adjoint(); }
+ AdjointReturnType inverse() const { return adjoint(); }
/** \internal */
- template<typename DestType> inline void evalTo(DestType& dst) const
+ template<typename DestType>
+ inline EIGEN_DEVICE_FUNC
+ void evalTo(DestType& dst) const
{
Matrix<Scalar, DestType::RowsAtCompileTime, 1,
AutoAlign|ColMajor, DestType::MaxRowsAtCompileTime, 1> workspace(rows());
@@ -239,6 +286,7 @@
/** \internal */
template<typename Dest, typename Workspace>
+ EIGEN_DEVICE_FUNC
void evalTo(Dest& dst, Workspace& workspace) const
{
workspace.resize(rows());
@@ -251,7 +299,7 @@
for(Index k = vecs-1; k >= 0; --k)
{
Index cornerSize = rows() - k - m_shift;
- if(m_trans)
+ if(m_reverse)
dst.bottomRightCorner(cornerSize, cornerSize)
.applyHouseholderOnTheRight(essentialVector(k), m_coeffs.coeff(k), workspace.data());
else
@@ -265,18 +313,26 @@
for(Index k = 0; k<cols()-vecs ; ++k)
dst.col(k).tail(rows()-k-1).setZero();
}
+ else if(m_length>BlockSize)
+ {
+ dst.setIdentity(rows(), rows());
+ if(m_reverse)
+ applyThisOnTheLeft(dst,workspace,true);
+ else
+ applyThisOnTheLeft(dst,workspace,true);
+ }
else
{
dst.setIdentity(rows(), rows());
for(Index k = vecs-1; k >= 0; --k)
{
Index cornerSize = rows() - k - m_shift;
- if(m_trans)
+ if(m_reverse)
dst.bottomRightCorner(cornerSize, cornerSize)
- .applyHouseholderOnTheRight(essentialVector(k), m_coeffs.coeff(k), &workspace.coeffRef(0));
+ .applyHouseholderOnTheRight(essentialVector(k), m_coeffs.coeff(k), workspace.data());
else
dst.bottomRightCorner(cornerSize, cornerSize)
- .applyHouseholderOnTheLeft(essentialVector(k), m_coeffs.coeff(k), &workspace.coeffRef(0));
+ .applyHouseholderOnTheLeft(essentialVector(k), m_coeffs.coeff(k), workspace.data());
}
}
}
@@ -295,42 +351,52 @@
workspace.resize(dst.rows());
for(Index k = 0; k < m_length; ++k)
{
- Index actual_k = m_trans ? m_length-k-1 : k;
+ Index actual_k = m_reverse ? m_length-k-1 : k;
dst.rightCols(rows()-m_shift-actual_k)
.applyHouseholderOnTheRight(essentialVector(actual_k), m_coeffs.coeff(actual_k), workspace.data());
}
}
/** \internal */
- template<typename Dest> inline void applyThisOnTheLeft(Dest& dst) const
+ template<typename Dest> inline void applyThisOnTheLeft(Dest& dst, bool inputIsIdentity = false) const
{
Matrix<Scalar,1,Dest::ColsAtCompileTime,RowMajor,1,Dest::MaxColsAtCompileTime> workspace;
- applyThisOnTheLeft(dst, workspace);
+ applyThisOnTheLeft(dst, workspace, inputIsIdentity);
}
/** \internal */
template<typename Dest, typename Workspace>
- inline void applyThisOnTheLeft(Dest& dst, Workspace& workspace) const
+ inline void applyThisOnTheLeft(Dest& dst, Workspace& workspace, bool inputIsIdentity = false) const
{
- const Index BlockSize = 48;
+ if(inputIsIdentity && m_reverse)
+ inputIsIdentity = false;
// if the entries are large enough, then apply the reflectors by block
if(m_length>=BlockSize && dst.cols()>1)
{
- for(Index i = 0; i < m_length; i+=BlockSize)
+ // Make sure we have at least 2 useful blocks, otherwise it is point-less:
+ Index blockSize = m_length<Index(2*BlockSize) ? (m_length+1)/2 : Index(BlockSize);
+ for(Index i = 0; i < m_length; i+=blockSize)
{
- Index end = m_trans ? (std::min)(m_length,i+BlockSize) : m_length-i;
- Index k = m_trans ? i : (std::max)(Index(0),end-BlockSize);
+ Index end = m_reverse ? (std::min)(m_length,i+blockSize) : m_length-i;
+ Index k = m_reverse ? i : (std::max)(Index(0),end-blockSize);
Index bs = end-k;
Index start = k + m_shift;
-
+
typedef Block<typename internal::remove_all<VectorsType>::type,Dynamic,Dynamic> SubVectorsType;
SubVectorsType sub_vecs1(m_vectors.const_cast_derived(), Side==OnTheRight ? k : start,
Side==OnTheRight ? start : k,
Side==OnTheRight ? bs : m_vectors.rows()-start,
Side==OnTheRight ? m_vectors.cols()-start : bs);
typename internal::conditional<Side==OnTheRight, Transpose<SubVectorsType>, SubVectorsType&>::type sub_vecs(sub_vecs1);
- Block<Dest,Dynamic,Dynamic> sub_dst(dst,dst.rows()-rows()+m_shift+k,0, rows()-m_shift-k,dst.cols());
- apply_block_householder_on_the_left(sub_dst, sub_vecs, m_coeffs.segment(k, bs), !m_trans);
+
+ Index dstStart = dst.rows()-rows()+m_shift+k;
+ Index dstRows = rows()-m_shift-k;
+ Block<Dest,Dynamic,Dynamic> sub_dst(dst,
+ dstStart,
+ inputIsIdentity ? dstStart : 0,
+ dstRows,
+ inputIsIdentity ? dstRows : dst.cols());
+ apply_block_householder_on_the_left(sub_dst, sub_vecs, m_coeffs.segment(k, bs), !m_reverse);
}
}
else
@@ -338,8 +404,9 @@
workspace.resize(dst.cols());
for(Index k = 0; k < m_length; ++k)
{
- Index actual_k = m_trans ? k : m_length-k-1;
- dst.bottomRows(rows()-m_shift-actual_k)
+ Index actual_k = m_reverse ? k : m_length-k-1;
+ Index dstStart = rows()-m_shift-actual_k;
+ dst.bottomRightCorner(dstStart, inputIsIdentity ? dstStart : dst.cols())
.applyHouseholderOnTheLeft(essentialVector(actual_k), m_coeffs.coeff(actual_k), workspace.data());
}
}
@@ -357,7 +424,7 @@
{
typename internal::matrix_type_times_scalar_type<Scalar, OtherDerived>::Type
res(other.template cast<typename internal::matrix_type_times_scalar_type<Scalar,OtherDerived>::ResultScalar>());
- applyThisOnTheLeft(res);
+ applyThisOnTheLeft(res, internal::is_identity<OtherDerived>::value && res.rows()==res.cols());
return res;
}
@@ -372,6 +439,7 @@
*
* \sa length()
*/
+ EIGEN_DEVICE_FUNC
HouseholderSequence& setLength(Index length)
{
m_length = length;
@@ -389,13 +457,17 @@
*
* \sa shift()
*/
+ EIGEN_DEVICE_FUNC
HouseholderSequence& setShift(Index shift)
{
m_shift = shift;
return *this;
}
+ EIGEN_DEVICE_FUNC
Index length() const { return m_length; } /**< \brief Returns the length of the Householder sequence. */
+
+ EIGEN_DEVICE_FUNC
Index shift() const { return m_shift; } /**< \brief Returns the shift of the Householder sequence. */
/* Necessary for .adjoint() and .conjugate() */
@@ -403,27 +475,30 @@
protected:
- /** \brief Sets the transpose flag.
- * \param [in] trans New value of the transpose flag.
+ /** \internal
+ * \brief Sets the reverse flag.
+ * \param [in] reverse New value of the reverse flag.
*
- * By default, the transpose flag is not set. If the transpose flag is set, then this object represents
- * \f$ H^T = H_{n-1}^T \ldots H_1^T H_0^T \f$ instead of \f$ H = H_0 H_1 \ldots H_{n-1} \f$.
+ * By default, the reverse flag is not set. If the reverse flag is set, then this object represents
+ * \f$ H^r = H_{n-1} \ldots H_1 H_0 \f$ instead of \f$ H = H_0 H_1 \ldots H_{n-1} \f$.
+ * \note For real valued HouseholderSequence this is equivalent to transposing \f$ H \f$.
*
- * \sa trans()
+ * \sa reverseFlag(), transpose(), adjoint()
*/
- HouseholderSequence& setTrans(bool trans)
+ HouseholderSequence& setReverseFlag(bool reverse)
{
- m_trans = trans;
+ m_reverse = reverse;
return *this;
}
- bool trans() const { return m_trans; } /**< \brief Returns the transpose flag. */
+ bool reverseFlag() const { return m_reverse; } /**< \internal \brief Returns the reverse flag. */
typename VectorsType::Nested m_vectors;
typename CoeffsType::Nested m_coeffs;
- bool m_trans;
+ bool m_reverse;
Index m_length;
Index m_shift;
+ enum { BlockSize = 48 };
};
/** \brief Computes the product of a matrix with a Householder sequence.
@@ -444,7 +519,7 @@
}
/** \ingroup Householder_Module \householder_module
- * \brief Convenience function for constructing a Householder sequence.
+ * \brief Convenience function for constructing a Householder sequence.
* \returns A HouseholderSequence constructed from the specified arguments.
*/
template<typename VectorsType, typename CoeffsType>
@@ -454,7 +529,7 @@
}
/** \ingroup Householder_Module \householder_module
- * \brief Convenience function for constructing a Householder sequence.
+ * \brief Convenience function for constructing a Householder sequence.
* \returns A HouseholderSequence constructed from the specified arguments.
* \details This function differs from householderSequence() in that the template argument \p OnTheSide of
* the constructed HouseholderSequence is set to OnTheRight, instead of the default OnTheLeft.
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/Jacobi/Jacobi.h b/wpimath/src/main/native/eigeninclude/Eigen/src/Jacobi/Jacobi.h
index 1998c63..76668a5 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/Jacobi/Jacobi.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/Jacobi/Jacobi.h
@@ -11,7 +11,7 @@
#ifndef EIGEN_JACOBI_H
#define EIGEN_JACOBI_H
-namespace Eigen {
+namespace Eigen {
/** \ingroup Jacobi_Module
* \jacobi_module
@@ -37,17 +37,20 @@
typedef typename NumTraits<Scalar>::Real RealScalar;
/** Default constructor without any initialization. */
+ EIGEN_DEVICE_FUNC
JacobiRotation() {}
/** Construct a planar rotation from a cosine-sine pair (\a c, \c s). */
+ EIGEN_DEVICE_FUNC
JacobiRotation(const Scalar& c, const Scalar& s) : m_c(c), m_s(s) {}
- Scalar& c() { return m_c; }
- Scalar c() const { return m_c; }
- Scalar& s() { return m_s; }
- Scalar s() const { return m_s; }
+ EIGEN_DEVICE_FUNC Scalar& c() { return m_c; }
+ EIGEN_DEVICE_FUNC Scalar c() const { return m_c; }
+ EIGEN_DEVICE_FUNC Scalar& s() { return m_s; }
+ EIGEN_DEVICE_FUNC Scalar s() const { return m_s; }
/** Concatenates two planar rotation */
+ EIGEN_DEVICE_FUNC
JacobiRotation operator*(const JacobiRotation& other)
{
using numext::conj;
@@ -56,19 +59,26 @@
}
/** Returns the transposed transformation */
+ EIGEN_DEVICE_FUNC
JacobiRotation transpose() const { using numext::conj; return JacobiRotation(m_c, -conj(m_s)); }
/** Returns the adjoint transformation */
+ EIGEN_DEVICE_FUNC
JacobiRotation adjoint() const { using numext::conj; return JacobiRotation(conj(m_c), -m_s); }
template<typename Derived>
+ EIGEN_DEVICE_FUNC
bool makeJacobi(const MatrixBase<Derived>&, Index p, Index q);
+ EIGEN_DEVICE_FUNC
bool makeJacobi(const RealScalar& x, const Scalar& y, const RealScalar& z);
+ EIGEN_DEVICE_FUNC
void makeGivens(const Scalar& p, const Scalar& q, Scalar* r=0);
protected:
+ EIGEN_DEVICE_FUNC
void makeGivens(const Scalar& p, const Scalar& q, Scalar* r, internal::true_type);
+ EIGEN_DEVICE_FUNC
void makeGivens(const Scalar& p, const Scalar& q, Scalar* r, internal::false_type);
Scalar m_c, m_s;
@@ -80,10 +90,12 @@
* \sa MatrixBase::makeJacobi(const MatrixBase<Derived>&, Index, Index), MatrixBase::applyOnTheLeft(), MatrixBase::applyOnTheRight()
*/
template<typename Scalar>
+EIGEN_DEVICE_FUNC
bool JacobiRotation<Scalar>::makeJacobi(const RealScalar& x, const Scalar& y, const RealScalar& z)
{
using std::sqrt;
using std::abs;
+
RealScalar deno = RealScalar(2)*abs(y);
if(deno < (std::numeric_limits<RealScalar>::min)())
{
@@ -123,6 +135,7 @@
*/
template<typename Scalar>
template<typename Derived>
+EIGEN_DEVICE_FUNC
inline bool JacobiRotation<Scalar>::makeJacobi(const MatrixBase<Derived>& m, Index p, Index q)
{
return makeJacobi(numext::real(m.coeff(p,p)), m.coeff(p,q), numext::real(m.coeff(q,q)));
@@ -145,6 +158,7 @@
* \sa MatrixBase::applyOnTheLeft(), MatrixBase::applyOnTheRight()
*/
template<typename Scalar>
+EIGEN_DEVICE_FUNC
void JacobiRotation<Scalar>::makeGivens(const Scalar& p, const Scalar& q, Scalar* r)
{
makeGivens(p, q, r, typename internal::conditional<NumTraits<Scalar>::IsComplex, internal::true_type, internal::false_type>::type());
@@ -153,12 +167,13 @@
// specialization for complexes
template<typename Scalar>
+EIGEN_DEVICE_FUNC
void JacobiRotation<Scalar>::makeGivens(const Scalar& p, const Scalar& q, Scalar* r, internal::true_type)
{
using std::sqrt;
using std::abs;
using numext::conj;
-
+
if(q==Scalar(0))
{
m_c = numext::real(p)<0 ? Scalar(-1) : Scalar(1);
@@ -212,6 +227,7 @@
// specialization for reals
template<typename Scalar>
+EIGEN_DEVICE_FUNC
void JacobiRotation<Scalar>::makeGivens(const Scalar& p, const Scalar& q, Scalar* r, internal::false_type)
{
using std::sqrt;
@@ -257,12 +273,13 @@
namespace internal {
/** \jacobi_module
- * Applies the clock wise 2D rotation \a j to the set of 2D vectors of cordinates \a x and \a y:
+ * Applies the clock wise 2D rotation \a j to the set of 2D vectors of coordinates \a x and \a y:
* \f$ \left ( \begin{array}{cc} x \\ y \end{array} \right ) = J \left ( \begin{array}{cc} x \\ y \end{array} \right ) \f$
*
* \sa MatrixBase::applyOnTheLeft(), MatrixBase::applyOnTheRight()
*/
template<typename VectorX, typename VectorY, typename OtherScalar>
+EIGEN_DEVICE_FUNC
void apply_rotation_in_the_plane(DenseBase<VectorX>& xpr_x, DenseBase<VectorY>& xpr_y, const JacobiRotation<OtherScalar>& j);
}
@@ -274,6 +291,7 @@
*/
template<typename Derived>
template<typename OtherScalar>
+EIGEN_DEVICE_FUNC
inline void MatrixBase<Derived>::applyOnTheLeft(Index p, Index q, const JacobiRotation<OtherScalar>& j)
{
RowXpr x(this->row(p));
@@ -289,6 +307,7 @@
*/
template<typename Derived>
template<typename OtherScalar>
+EIGEN_DEVICE_FUNC
inline void MatrixBase<Derived>::applyOnTheRight(Index p, Index q, const JacobiRotation<OtherScalar>& j)
{
ColXpr x(this->col(p));
@@ -302,7 +321,8 @@
int SizeAtCompileTime, int MinAlignment, bool Vectorizable>
struct apply_rotation_in_the_plane_selector
{
- static inline void run(Scalar *x, Index incrx, Scalar *y, Index incry, Index size, OtherScalar c, OtherScalar s)
+ static EIGEN_DEVICE_FUNC
+ inline void run(Scalar *x, Index incrx, Scalar *y, Index incry, Index size, OtherScalar c, OtherScalar s)
{
for(Index i=0; i<size; ++i)
{
@@ -429,10 +449,11 @@
};
template<typename VectorX, typename VectorY, typename OtherScalar>
+EIGEN_DEVICE_FUNC
void /*EIGEN_DONT_INLINE*/ apply_rotation_in_the_plane(DenseBase<VectorX>& xpr_x, DenseBase<VectorY>& xpr_y, const JacobiRotation<OtherScalar>& j)
{
typedef typename VectorX::Scalar Scalar;
- const bool Vectorizable = (VectorX::Flags & VectorY::Flags & PacketAccessBit)
+ const bool Vectorizable = (int(VectorX::Flags) & int(VectorY::Flags) & PacketAccessBit)
&& (int(packet_traits<Scalar>::size) == int(packet_traits<OtherScalar>::size));
eigen_assert(xpr_x.size() == xpr_y.size());
@@ -442,7 +463,7 @@
Scalar* EIGEN_RESTRICT x = &xpr_x.derived().coeffRef(0);
Scalar* EIGEN_RESTRICT y = &xpr_y.derived().coeffRef(0);
-
+
OtherScalar c = j.c();
OtherScalar s = j.s();
if (c==OtherScalar(1) && s==OtherScalar(0))
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/LU/Determinant.h b/wpimath/src/main/native/eigeninclude/Eigen/src/LU/Determinant.h
index d6a3c1e..3a41e6f 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/LU/Determinant.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/LU/Determinant.h
@@ -15,6 +15,7 @@
namespace internal {
template<typename Derived>
+EIGEN_DEVICE_FUNC
inline const typename Derived::Scalar bruteforce_det3_helper
(const MatrixBase<Derived>& matrix, int a, int b, int c)
{
@@ -22,14 +23,6 @@
* (matrix.coeff(1,b) * matrix.coeff(2,c) - matrix.coeff(1,c) * matrix.coeff(2,b));
}
-template<typename Derived>
-const typename Derived::Scalar bruteforce_det4_helper
-(const MatrixBase<Derived>& matrix, int j, int k, int m, int n)
-{
- return (matrix.coeff(j,0) * matrix.coeff(k,1) - matrix.coeff(k,0) * matrix.coeff(j,1))
- * (matrix.coeff(m,2) * matrix.coeff(n,3) - matrix.coeff(n,2) * matrix.coeff(m,3));
-}
-
template<typename Derived,
int DeterminantType = Derived::RowsAtCompileTime
> struct determinant_impl
@@ -44,7 +37,8 @@
template<typename Derived> struct determinant_impl<Derived, 1>
{
- static inline typename traits<Derived>::Scalar run(const Derived& m)
+ static inline EIGEN_DEVICE_FUNC
+ typename traits<Derived>::Scalar run(const Derived& m)
{
return m.coeff(0,0);
}
@@ -52,7 +46,8 @@
template<typename Derived> struct determinant_impl<Derived, 2>
{
- static inline typename traits<Derived>::Scalar run(const Derived& m)
+ static inline EIGEN_DEVICE_FUNC
+ typename traits<Derived>::Scalar run(const Derived& m)
{
return m.coeff(0,0) * m.coeff(1,1) - m.coeff(1,0) * m.coeff(0,1);
}
@@ -60,7 +55,8 @@
template<typename Derived> struct determinant_impl<Derived, 3>
{
- static inline typename traits<Derived>::Scalar run(const Derived& m)
+ static inline EIGEN_DEVICE_FUNC
+ typename traits<Derived>::Scalar run(const Derived& m)
{
return bruteforce_det3_helper(m,0,1,2)
- bruteforce_det3_helper(m,1,0,2)
@@ -70,15 +66,34 @@
template<typename Derived> struct determinant_impl<Derived, 4>
{
- static typename traits<Derived>::Scalar run(const Derived& m)
+ typedef typename traits<Derived>::Scalar Scalar;
+ static EIGEN_DEVICE_FUNC
+ Scalar run(const Derived& m)
{
- // trick by Martin Costabel to compute 4x4 det with only 30 muls
- return bruteforce_det4_helper(m,0,1,2,3)
- - bruteforce_det4_helper(m,0,2,1,3)
- + bruteforce_det4_helper(m,0,3,1,2)
- + bruteforce_det4_helper(m,1,2,0,3)
- - bruteforce_det4_helper(m,1,3,0,2)
- + bruteforce_det4_helper(m,2,3,0,1);
+ Scalar d2_01 = det2(m, 0, 1);
+ Scalar d2_02 = det2(m, 0, 2);
+ Scalar d2_03 = det2(m, 0, 3);
+ Scalar d2_12 = det2(m, 1, 2);
+ Scalar d2_13 = det2(m, 1, 3);
+ Scalar d2_23 = det2(m, 2, 3);
+ Scalar d3_0 = det3(m, 1,d2_23, 2,d2_13, 3,d2_12);
+ Scalar d3_1 = det3(m, 0,d2_23, 2,d2_03, 3,d2_02);
+ Scalar d3_2 = det3(m, 0,d2_13, 1,d2_03, 3,d2_01);
+ Scalar d3_3 = det3(m, 0,d2_12, 1,d2_02, 2,d2_01);
+ return internal::pmadd(-m(0,3),d3_0, m(1,3)*d3_1) +
+ internal::pmadd(-m(2,3),d3_2, m(3,3)*d3_3);
+ }
+protected:
+ static EIGEN_DEVICE_FUNC
+ Scalar det2(const Derived& m, Index i0, Index i1)
+ {
+ return m(i0,0) * m(i1,1) - m(i1,0) * m(i0,1);
+ }
+
+ static EIGEN_DEVICE_FUNC
+ Scalar det3(const Derived& m, Index i0, const Scalar& d0, Index i1, const Scalar& d1, Index i2, const Scalar& d2)
+ {
+ return internal::pmadd(m(i0,2), d0, internal::pmadd(-m(i1,2), d1, m(i2,2)*d2));
}
};
@@ -89,6 +104,7 @@
* \returns the determinant of this matrix
*/
template<typename Derived>
+EIGEN_DEVICE_FUNC
inline typename internal::traits<Derived>::Scalar MatrixBase<Derived>::determinant() const
{
eigen_assert(rows() == cols());
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/LU/FullPivLU.h b/wpimath/src/main/native/eigeninclude/Eigen/src/LU/FullPivLU.h
index 03b6af7..ba1749f 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/LU/FullPivLU.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/LU/FullPivLU.h
@@ -18,6 +18,7 @@
{
typedef MatrixXpr XprKind;
typedef SolverStorage StorageKind;
+ typedef int StorageIndex;
enum { Flags = 0 };
};
@@ -48,12 +49,12 @@
* The data of the LU decomposition can be directly accessed through the methods matrixLU(),
* permutationP(), permutationQ().
*
- * As an exemple, here is how the original matrix can be retrieved:
+ * As an example, here is how the original matrix can be retrieved:
* \include class_FullPivLU.cpp
* Output: \verbinclude class_FullPivLU.out
*
* This class supports the \link InplaceDecomposition inplace decomposition \endlink mechanism.
- *
+ *
* \sa MatrixBase::fullPivLu(), MatrixBase::determinant(), MatrixBase::inverse()
*/
template<typename _MatrixType> class FullPivLU
@@ -62,9 +63,9 @@
public:
typedef _MatrixType MatrixType;
typedef SolverBase<FullPivLU> Base;
+ friend class SolverBase<FullPivLU>;
EIGEN_GENERIC_PUBLIC_INTERFACE(FullPivLU)
- // FIXME StorageIndex defined in EIGEN_GENERIC_PUBLIC_INTERFACE should be int
enum {
MaxRowsAtCompileTime = MatrixType::MaxRowsAtCompileTime,
MaxColsAtCompileTime = MatrixType::MaxColsAtCompileTime
@@ -218,6 +219,7 @@
return internal::image_retval<FullPivLU>(*this, originalMatrix);
}
+ #ifdef EIGEN_PARSED_BY_DOXYGEN
/** \return a solution x to the equation Ax=b, where A is the matrix of which
* *this is the LU decomposition.
*
@@ -237,14 +239,10 @@
*
* \sa TriangularView::solve(), kernel(), inverse()
*/
- // FIXME this is a copy-paste of the base-class member to add the isInitialized assertion.
template<typename Rhs>
inline const Solve<FullPivLU, Rhs>
- solve(const MatrixBase<Rhs>& b) const
- {
- eigen_assert(m_isInitialized && "LU is not initialized.");
- return Solve<FullPivLU, Rhs>(*this, b.derived());
- }
+ solve(const MatrixBase<Rhs>& b) const;
+ #endif
/** \returns an estimate of the reciprocal condition number of the matrix of which \c *this is
the LU decomposition.
@@ -320,7 +318,7 @@
return m_usePrescribedThreshold ? m_prescribedThreshold
// this formula comes from experimenting (see "LU precision tuning" thread on the list)
// and turns out to be identical to Higham's formula used already in LDLt.
- : NumTraits<Scalar>::epsilon() * m_lu.diagonalSize();
+ : NumTraits<Scalar>::epsilon() * RealScalar(m_lu.diagonalSize());
}
/** \returns the rank of the matrix of which *this is the LU decomposition.
@@ -406,16 +404,16 @@
MatrixType reconstructedMatrix() const;
- EIGEN_DEVICE_FUNC inline Index rows() const { return m_lu.rows(); }
- EIGEN_DEVICE_FUNC inline Index cols() const { return m_lu.cols(); }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ inline Index rows() const EIGEN_NOEXCEPT { return m_lu.rows(); }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
+ inline Index cols() const EIGEN_NOEXCEPT { return m_lu.cols(); }
#ifndef EIGEN_PARSED_BY_DOXYGEN
template<typename RhsType, typename DstType>
- EIGEN_DEVICE_FUNC
void _solve_impl(const RhsType &rhs, DstType &dst) const;
template<bool Conjugate, typename RhsType, typename DstType>
- EIGEN_DEVICE_FUNC
void _solve_impl_transposed(const RhsType &rhs, DstType &dst) const;
#endif
@@ -531,8 +529,8 @@
m_nonzero_pivots = k;
for(Index i = k; i < size; ++i)
{
- m_rowsTranspositions.coeffRef(i) = i;
- m_colsTranspositions.coeffRef(i) = i;
+ m_rowsTranspositions.coeffRef(i) = internal::convert_index<StorageIndex>(i);
+ m_colsTranspositions.coeffRef(i) = internal::convert_index<StorageIndex>(i);
}
break;
}
@@ -543,8 +541,8 @@
// Now that we've found the pivot, we need to apply the row/col swaps to
// bring it to the location (k,k).
- m_rowsTranspositions.coeffRef(k) = row_of_biggest_in_corner;
- m_colsTranspositions.coeffRef(k) = col_of_biggest_in_corner;
+ m_rowsTranspositions.coeffRef(k) = internal::convert_index<StorageIndex>(row_of_biggest_in_corner);
+ m_colsTranspositions.coeffRef(k) = internal::convert_index<StorageIndex>(col_of_biggest_in_corner);
if(k != row_of_biggest_in_corner) {
m_lu.row(k).swap(m_lu.row(row_of_biggest_in_corner));
++number_of_transpositions;
@@ -757,7 +755,6 @@
const Index rows = this->rows(),
cols = this->cols(),
nonzero_pivots = this->rank();
- eigen_assert(rhs.rows() == rows);
const Index smalldim = (std::min)(rows, cols);
if(nonzero_pivots == 0)
@@ -807,7 +804,6 @@
const Index rows = this->rows(), cols = this->cols(),
nonzero_pivots = this->rank();
- eigen_assert(rhs.rows() == cols);
const Index smalldim = (std::min)(rows, cols);
if(nonzero_pivots == 0)
@@ -821,29 +817,19 @@
// Step 1
c = permutationQ().inverse() * rhs;
- if (Conjugate) {
- // Step 2
- m_lu.topLeftCorner(nonzero_pivots, nonzero_pivots)
- .template triangularView<Upper>()
- .adjoint()
- .solveInPlace(c.topRows(nonzero_pivots));
- // Step 3
- m_lu.topLeftCorner(smalldim, smalldim)
- .template triangularView<UnitLower>()
- .adjoint()
- .solveInPlace(c.topRows(smalldim));
- } else {
- // Step 2
- m_lu.topLeftCorner(nonzero_pivots, nonzero_pivots)
- .template triangularView<Upper>()
- .transpose()
- .solveInPlace(c.topRows(nonzero_pivots));
- // Step 3
- m_lu.topLeftCorner(smalldim, smalldim)
- .template triangularView<UnitLower>()
- .transpose()
- .solveInPlace(c.topRows(smalldim));
- }
+ // Step 2
+ m_lu.topLeftCorner(nonzero_pivots, nonzero_pivots)
+ .template triangularView<Upper>()
+ .transpose()
+ .template conjugateIf<Conjugate>()
+ .solveInPlace(c.topRows(nonzero_pivots));
+
+ // Step 3
+ m_lu.topLeftCorner(smalldim, smalldim)
+ .template triangularView<UnitLower>()
+ .transpose()
+ .template conjugateIf<Conjugate>()
+ .solveInPlace(c.topRows(smalldim));
// Step 4
PermutationPType invp = permutationP().inverse().eval();
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/LU/InverseImpl.h b/wpimath/src/main/native/eigeninclude/Eigen/src/LU/InverseImpl.h
index f49f233..a40cefa 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/LU/InverseImpl.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/LU/InverseImpl.h
@@ -77,10 +77,11 @@
const MatrixType& matrix, const typename ResultType::Scalar& invdet,
ResultType& result)
{
+ typename ResultType::Scalar temp = matrix.coeff(0,0);
result.coeffRef(0,0) = matrix.coeff(1,1) * invdet;
result.coeffRef(1,0) = -matrix.coeff(1,0) * invdet;
result.coeffRef(0,1) = -matrix.coeff(0,1) * invdet;
- result.coeffRef(1,1) = matrix.coeff(0,0) * invdet;
+ result.coeffRef(1,1) = temp * invdet;
}
template<typename MatrixType, typename ResultType>
@@ -143,13 +144,18 @@
const Matrix<typename ResultType::Scalar,3,1>& cofactors_col0,
ResultType& result)
{
- result.row(0) = cofactors_col0 * invdet;
- result.coeffRef(1,0) = cofactor_3x3<MatrixType,0,1>(matrix) * invdet;
- result.coeffRef(1,1) = cofactor_3x3<MatrixType,1,1>(matrix) * invdet;
+ // Compute cofactors in a way that avoids aliasing issues.
+ typedef typename ResultType::Scalar Scalar;
+ const Scalar c01 = cofactor_3x3<MatrixType,0,1>(matrix) * invdet;
+ const Scalar c11 = cofactor_3x3<MatrixType,1,1>(matrix) * invdet;
+ const Scalar c02 = cofactor_3x3<MatrixType,0,2>(matrix) * invdet;
result.coeffRef(1,2) = cofactor_3x3<MatrixType,2,1>(matrix) * invdet;
- result.coeffRef(2,0) = cofactor_3x3<MatrixType,0,2>(matrix) * invdet;
result.coeffRef(2,1) = cofactor_3x3<MatrixType,1,2>(matrix) * invdet;
result.coeffRef(2,2) = cofactor_3x3<MatrixType,2,2>(matrix) * invdet;
+ result.coeffRef(1,0) = c01;
+ result.coeffRef(1,1) = c11;
+ result.coeffRef(2,0) = c02;
+ result.row(0) = cofactors_col0 * invdet;
}
template<typename MatrixType, typename ResultType>
@@ -181,14 +187,13 @@
bool& invertible
)
{
- using std::abs;
typedef typename ResultType::Scalar Scalar;
Matrix<Scalar,3,1> cofactors_col0;
cofactors_col0.coeffRef(0) = cofactor_3x3<MatrixType,0,0>(matrix);
cofactors_col0.coeffRef(1) = cofactor_3x3<MatrixType,1,0>(matrix);
cofactors_col0.coeffRef(2) = cofactor_3x3<MatrixType,2,0>(matrix);
determinant = (cofactors_col0.cwiseProduct(matrix.col(0))).sum();
- invertible = abs(determinant) > absDeterminantThreshold;
+ invertible = Eigen::numext::abs(determinant) > absDeterminantThreshold;
if(!invertible) return;
const Scalar invdet = Scalar(1) / determinant;
compute_inverse_size3_helper(matrix, invdet, cofactors_col0, inverse);
@@ -273,7 +278,13 @@
using std::abs;
determinant = matrix.determinant();
invertible = abs(determinant) > absDeterminantThreshold;
- if(invertible) compute_inverse<MatrixType, ResultType>::run(matrix, inverse);
+ if(invertible && extract_data(matrix) != extract_data(inverse)) {
+ compute_inverse<MatrixType, ResultType>::run(matrix, inverse);
+ }
+ else if(invertible) {
+ MatrixType matrix_t = matrix;
+ compute_inverse<MatrixType, ResultType>::run(matrix_t, inverse);
+ }
}
};
@@ -290,6 +301,7 @@
struct Assignment<DstXprType, Inverse<XprType>, internal::assign_op<typename DstXprType::Scalar,typename XprType::Scalar>, Dense2Dense>
{
typedef Inverse<XprType> SrcXprType;
+ EIGEN_DEVICE_FUNC
static void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op<typename DstXprType::Scalar,typename XprType::Scalar> &)
{
Index dstRows = src.rows();
@@ -332,6 +344,7 @@
* \sa computeInverseAndDetWithCheck()
*/
template<typename Derived>
+EIGEN_DEVICE_FUNC
inline const Inverse<Derived> MatrixBase<Derived>::inverse() const
{
EIGEN_STATIC_ASSERT(!NumTraits<Scalar>::IsInteger,THIS_FUNCTION_IS_NOT_FOR_INTEGER_NUMERIC_TYPES)
@@ -345,6 +358,8 @@
*
* This is only for fixed-size square matrices of size up to 4x4.
*
+ * Notice that it will trigger a copy of input matrix when trying to do the inverse in place.
+ *
* \param inverse Reference to the matrix in which to store the inverse.
* \param determinant Reference to the variable in which to store the determinant.
* \param invertible Reference to the bool variable in which to store whether the matrix is invertible.
@@ -385,6 +400,8 @@
*
* This is only for fixed-size square matrices of size up to 4x4.
*
+ * Notice that it will trigger a copy of input matrix when trying to do the inverse in place.
+ *
* \param inverse Reference to the matrix in which to store the inverse.
* \param invertible Reference to the bool variable in which to store whether the matrix is invertible.
* \param absDeterminantThreshold Optional parameter controlling the invertibility check.
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/LU/PartialPivLU.h b/wpimath/src/main/native/eigeninclude/Eigen/src/LU/PartialPivLU.h
index d439618..34aed72 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/LU/PartialPivLU.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/LU/PartialPivLU.h
@@ -19,6 +19,7 @@
{
typedef MatrixXpr XprKind;
typedef SolverStorage StorageKind;
+ typedef int StorageIndex;
typedef traits<_MatrixType> BaseTraits;
enum {
Flags = BaseTraits::Flags & RowMajorBit,
@@ -69,7 +70,7 @@
* The data of the LU decomposition can be directly accessed through the methods matrixLU(), permutationP().
*
* This class supports the \link InplaceDecomposition inplace decomposition \endlink mechanism.
- *
+ *
* \sa MatrixBase::partialPivLu(), MatrixBase::determinant(), MatrixBase::inverse(), MatrixBase::computeInverse(), class FullPivLU
*/
template<typename _MatrixType> class PartialPivLU
@@ -79,8 +80,9 @@
typedef _MatrixType MatrixType;
typedef SolverBase<PartialPivLU> Base;
+ friend class SolverBase<PartialPivLU>;
+
EIGEN_GENERIC_PUBLIC_INTERFACE(PartialPivLU)
- // FIXME StorageIndex defined in EIGEN_GENERIC_PUBLIC_INTERFACE should be int
enum {
MaxRowsAtCompileTime = MatrixType::MaxRowsAtCompileTime,
MaxColsAtCompileTime = MatrixType::MaxColsAtCompileTime
@@ -152,6 +154,7 @@
return m_p;
}
+ #ifdef EIGEN_PARSED_BY_DOXYGEN
/** This method returns the solution x to the equation Ax=b, where A is the matrix of which
* *this is the LU decomposition.
*
@@ -169,14 +172,10 @@
*
* \sa TriangularView::solve(), inverse(), computeInverse()
*/
- // FIXME this is a copy-paste of the base-class member to add the isInitialized assertion.
template<typename Rhs>
inline const Solve<PartialPivLU, Rhs>
- solve(const MatrixBase<Rhs>& b) const
- {
- eigen_assert(m_isInitialized && "PartialPivLU is not initialized.");
- return Solve<PartialPivLU, Rhs>(*this, b.derived());
- }
+ solve(const MatrixBase<Rhs>& b) const;
+ #endif
/** \returns an estimate of the reciprocal condition number of the matrix of which \c *this is
the LU decomposition.
@@ -217,8 +216,8 @@
MatrixType reconstructedMatrix() const;
- inline Index rows() const { return m_lu.rows(); }
- inline Index cols() const { return m_lu.cols(); }
+ EIGEN_CONSTEXPR inline Index rows() const EIGEN_NOEXCEPT { return m_lu.rows(); }
+ EIGEN_CONSTEXPR inline Index cols() const EIGEN_NOEXCEPT { return m_lu.cols(); }
#ifndef EIGEN_PARSED_BY_DOXYGEN
template<typename RhsType, typename DstType>
@@ -231,8 +230,6 @@
* Step 3: replace c by the solution x to Ux = c.
*/
- eigen_assert(rhs.rows() == m_lu.rows());
-
// Step 1
dst = permutationP() * rhs;
@@ -246,26 +243,21 @@
template<bool Conjugate, typename RhsType, typename DstType>
EIGEN_DEVICE_FUNC
void _solve_impl_transposed(const RhsType &rhs, DstType &dst) const {
- /* The decomposition PA = LU can be rewritten as A = P^{-1} L U.
+ /* The decomposition PA = LU can be rewritten as A^T = U^T L^T P.
* So we proceed as follows:
- * Step 1: compute c = Pb.
- * Step 2: replace c by the solution x to Lx = c.
- * Step 3: replace c by the solution x to Ux = c.
+ * Step 1: compute c as the solution to L^T c = b
+ * Step 2: replace c by the solution x to U^T x = c.
+ * Step 3: update c = P^-1 c.
*/
eigen_assert(rhs.rows() == m_lu.cols());
- if (Conjugate) {
- // Step 1
- dst = m_lu.template triangularView<Upper>().adjoint().solve(rhs);
- // Step 2
- m_lu.template triangularView<UnitLower>().adjoint().solveInPlace(dst);
- } else {
- // Step 1
- dst = m_lu.template triangularView<Upper>().transpose().solve(rhs);
- // Step 2
- m_lu.template triangularView<UnitLower>().transpose().solveInPlace(dst);
- }
+ // Step 1
+ dst = m_lu.template triangularView<Upper>().transpose()
+ .template conjugateIf<Conjugate>().solve(rhs);
+ // Step 2
+ m_lu.template triangularView<UnitLower>().transpose()
+ .template conjugateIf<Conjugate>().solveInPlace(dst);
// Step 3
dst = permutationP().transpose() * dst;
}
@@ -339,17 +331,18 @@
namespace internal {
/** \internal This is the blocked version of fullpivlu_unblocked() */
-template<typename Scalar, int StorageOrder, typename PivIndex>
+template<typename Scalar, int StorageOrder, typename PivIndex, int SizeAtCompileTime=Dynamic>
struct partial_lu_impl
{
- // FIXME add a stride to Map, so that the following mapping becomes easier,
- // another option would be to create an expression being able to automatically
- // warp any Map, Matrix, and Block expressions as a unique type, but since that's exactly
- // a Map + stride, why not adding a stride to Map, and convenient ctors from a Matrix,
- // and Block.
- typedef Map<Matrix<Scalar, Dynamic, Dynamic, StorageOrder> > MapLU;
- typedef Block<MapLU, Dynamic, Dynamic> MatrixType;
- typedef Block<MatrixType,Dynamic,Dynamic> BlockType;
+ static const int UnBlockedBound = 16;
+ static const bool UnBlockedAtCompileTime = SizeAtCompileTime!=Dynamic && SizeAtCompileTime<=UnBlockedBound;
+ static const int ActualSizeAtCompileTime = UnBlockedAtCompileTime ? SizeAtCompileTime : Dynamic;
+ // Remaining rows and columns at compile-time:
+ static const int RRows = SizeAtCompileTime==2 ? 1 : Dynamic;
+ static const int RCols = SizeAtCompileTime==2 ? 1 : Dynamic;
+ typedef Matrix<Scalar, ActualSizeAtCompileTime, ActualSizeAtCompileTime, StorageOrder> MatrixType;
+ typedef Ref<MatrixType> MatrixTypeRef;
+ typedef Ref<Matrix<Scalar, Dynamic, Dynamic, StorageOrder> > BlockType;
typedef typename MatrixType::RealScalar RealScalar;
/** \internal performs the LU decomposition in-place of the matrix \a lu
@@ -362,19 +355,22 @@
*
* \returns The index of the first pivot which is exactly zero if any, or a negative number otherwise.
*/
- static Index unblocked_lu(MatrixType& lu, PivIndex* row_transpositions, PivIndex& nb_transpositions)
+ static Index unblocked_lu(MatrixTypeRef& lu, PivIndex* row_transpositions, PivIndex& nb_transpositions)
{
typedef scalar_score_coeff_op<Scalar> Scoring;
typedef typename Scoring::result_type Score;
const Index rows = lu.rows();
const Index cols = lu.cols();
const Index size = (std::min)(rows,cols);
+ // For small compile-time matrices it is worth processing the last row separately:
+ // speedup: +100% for 2x2, +10% for others.
+ const Index endk = UnBlockedAtCompileTime ? size-1 : size;
nb_transpositions = 0;
Index first_zero_pivot = -1;
- for(Index k = 0; k < size; ++k)
+ for(Index k = 0; k < endk; ++k)
{
- Index rrows = rows-k-1;
- Index rcols = cols-k-1;
+ int rrows = internal::convert_index<int>(rows-k-1);
+ int rcols = internal::convert_index<int>(cols-k-1);
Index row_of_biggest_in_col;
Score biggest_in_corner
@@ -391,9 +387,7 @@
++nb_transpositions;
}
- // FIXME shall we introduce a safe quotient expression in cas 1/lu.coeff(k,k)
- // overflow but not the actual quotient?
- lu.col(k).tail(rrows) /= lu.coeff(k,k);
+ lu.col(k).tail(fix<RRows>(rrows)) /= lu.coeff(k,k);
}
else if(first_zero_pivot==-1)
{
@@ -403,8 +397,18 @@
}
if(k<rows-1)
- lu.bottomRightCorner(rrows,rcols).noalias() -= lu.col(k).tail(rrows) * lu.row(k).tail(rcols);
+ lu.bottomRightCorner(fix<RRows>(rrows),fix<RCols>(rcols)).noalias() -= lu.col(k).tail(fix<RRows>(rrows)) * lu.row(k).tail(fix<RCols>(rcols));
}
+
+ // special handling of the last entry
+ if(UnBlockedAtCompileTime)
+ {
+ Index k = endk;
+ row_transpositions[k] = PivIndex(k);
+ if (Scoring()(lu(k, k)) == Score(0) && first_zero_pivot == -1)
+ first_zero_pivot = k;
+ }
+
return first_zero_pivot;
}
@@ -420,18 +424,17 @@
* \returns The index of the first pivot which is exactly zero if any, or a negative number otherwise.
*
* \note This very low level interface using pointers, etc. is to:
- * 1 - reduce the number of instanciations to the strict minimum
- * 2 - avoid infinite recursion of the instanciations with Block<Block<Block<...> > >
+ * 1 - reduce the number of instantiations to the strict minimum
+ * 2 - avoid infinite recursion of the instantiations with Block<Block<Block<...> > >
*/
static Index blocked_lu(Index rows, Index cols, Scalar* lu_data, Index luStride, PivIndex* row_transpositions, PivIndex& nb_transpositions, Index maxBlockSize=256)
{
- MapLU lu1(lu_data,StorageOrder==RowMajor?rows:luStride,StorageOrder==RowMajor?luStride:cols);
- MatrixType lu(lu1,0,0,rows,cols);
+ MatrixTypeRef lu = MatrixType::Map(lu_data,rows, cols, OuterStride<>(luStride));
const Index size = (std::min)(rows,cols);
// if the matrix is too small, no blocking:
- if(size<=16)
+ if(UnBlockedAtCompileTime || size<=UnBlockedBound)
{
return unblocked_lu(lu, row_transpositions, nb_transpositions);
}
@@ -457,12 +460,12 @@
// A00 | A01 | A02
// lu = A_0 | A_1 | A_2 = A10 | A11 | A12
// A20 | A21 | A22
- BlockType A_0(lu,0,0,rows,k);
- BlockType A_2(lu,0,k+bs,rows,tsize);
- BlockType A11(lu,k,k,bs,bs);
- BlockType A12(lu,k,k+bs,bs,tsize);
- BlockType A21(lu,k+bs,k,trows,bs);
- BlockType A22(lu,k+bs,k+bs,trows,tsize);
+ BlockType A_0 = lu.block(0,0,rows,k);
+ BlockType A_2 = lu.block(0,k+bs,rows,tsize);
+ BlockType A11 = lu.block(k,k,bs,bs);
+ BlockType A12 = lu.block(k,k+bs,bs,tsize);
+ BlockType A21 = lu.block(k+bs,k,trows,bs);
+ BlockType A22 = lu.block(k+bs,k+bs,trows,tsize);
PivIndex nb_transpositions_in_panel;
// recursively call the blocked LU algorithm on [A11^T A21^T]^T
@@ -501,11 +504,18 @@
template<typename MatrixType, typename TranspositionType>
void partial_lu_inplace(MatrixType& lu, TranspositionType& row_transpositions, typename TranspositionType::StorageIndex& nb_transpositions)
{
+ // Special-case of zero matrix.
+ if (lu.rows() == 0 || lu.cols() == 0) {
+ nb_transpositions = 0;
+ return;
+ }
eigen_assert(lu.cols() == row_transpositions.size());
- eigen_assert((&row_transpositions.coeffRef(1)-&row_transpositions.coeffRef(0)) == 1);
+ eigen_assert(row_transpositions.size() < 2 || (&row_transpositions.coeffRef(1)-&row_transpositions.coeffRef(0)) == 1);
partial_lu_impl
- <typename MatrixType::Scalar, MatrixType::Flags&RowMajorBit?RowMajor:ColMajor, typename TranspositionType::StorageIndex>
+ < typename MatrixType::Scalar, MatrixType::Flags&RowMajorBit?RowMajor:ColMajor,
+ typename TranspositionType::StorageIndex,
+ EIGEN_SIZE_MIN_PREFER_FIXED(MatrixType::RowsAtCompileTime,MatrixType::ColsAtCompileTime)>
::blocked_lu(lu.rows(), lu.cols(), &lu.coeffRef(0,0), lu.outerStride(), &row_transpositions.coeffRef(0), nb_transpositions);
}
@@ -519,7 +529,10 @@
// the row permutation is stored as int indices, so just to be sure:
eigen_assert(m_lu.rows()<NumTraits<int>::highest());
- m_l1_norm = m_lu.cwiseAbs().colwise().sum().maxCoeff();
+ if(m_lu.cols()>0)
+ m_l1_norm = m_lu.cwiseAbs().colwise().sum().maxCoeff();
+ else
+ m_l1_norm = RealScalar(0);
eigen_assert(m_lu.rows() == m_lu.cols() && "PartialPivLU is only for square (and moreover invertible) matrices");
const Index size = m_lu.rows();
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/LU/arch/InverseSize4.h b/wpimath/src/main/native/eigeninclude/Eigen/src/LU/arch/InverseSize4.h
new file mode 100644
index 0000000..a232ffc
--- /dev/null
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/LU/arch/InverseSize4.h
@@ -0,0 +1,351 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2001 Intel Corporation
+// Copyright (C) 2010 Gael Guennebaud <gael.guennebaud@inria.fr>
+// Copyright (C) 2009 Benoit Jacob <jacob.benoit.1@gmail.com>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+//
+// The algorithm below is a reimplementation of former \src\LU\Inverse_SSE.h using PacketMath.
+// inv(M) = M#/|M|, where inv(M), M# and |M| denote the inverse of M,
+// adjugate of M and determinant of M respectively. M# is computed block-wise
+// using specific formulae. For proof, see:
+// https://lxjk.github.io/2017/09/03/Fast-4x4-Matrix-Inverse-with-SSE-SIMD-Explained.html
+// Variable names are adopted from \src\LU\Inverse_SSE.h.
+//
+// The SSE code for the 4x4 float and double matrix inverse in former (deprecated) \src\LU\Inverse_SSE.h
+// comes from the following Intel's library:
+// http://software.intel.com/en-us/articles/optimized-matrix-library-for-use-with-the-intel-pentiumr-4-processors-sse2-instructions/
+//
+// Here is the respective copyright and license statement:
+//
+// Copyright (c) 2001 Intel Corporation.
+//
+// Permition is granted to use, copy, distribute and prepare derivative works
+// of this library for any purpose and without fee, provided, that the above
+// copyright notice and this statement appear in all copies.
+// Intel makes no representations about the suitability of this software for
+// any purpose, and specifically disclaims all warranties.
+// See LEGAL.TXT for all the legal information.
+//
+// TODO: Unify implementations of different data types (i.e. float and double).
+#ifndef EIGEN_INVERSE_SIZE_4_H
+#define EIGEN_INVERSE_SIZE_4_H
+
+namespace Eigen
+{
+namespace internal
+{
+template <typename MatrixType, typename ResultType> // 4x4 float inverse with Packet4f ops: inv(M) = M# / |M|, computed on 2x2 blocks
+struct compute_inverse_size4<Architecture::Target, float, MatrixType, ResultType>
+{
+ enum
+ {
+ MatrixAlignment = traits<MatrixType>::Alignment,
+ ResultAlignment = traits<ResultType>::Alignment,
+ StorageOrdersMatch = (MatrixType::Flags & RowMajorBit) == (ResultType::Flags & RowMajorBit) // same RowMajorBit on input and result
+ };
+ typedef typename conditional<(MatrixType::Flags & LinearAccessBit), MatrixType const &, typename MatrixType::PlainObject>::type ActualMatrixType; // reference if linearly accessible, otherwise an evaluated PlainObject copy
+ // run() reads the 16 coefficients as four Packet4f loads from matrix.data() and stores the inverse into result.
+ static void run(const MatrixType &mat, ResultType &result)
+ {
+ ActualMatrixType matrix(mat);
+ // Load the coefficients in the input's storage order, at offsets 0, 4, 8, 12 (scaled by innerStride).
+ const float* data = matrix.data();
+ const Index stride = matrix.innerStride();
+ Packet4f _L1 = ploadt<Packet4f,MatrixAlignment>(data);
+ Packet4f _L2 = ploadt<Packet4f,MatrixAlignment>(data + stride*4);
+ Packet4f _L3 = ploadt<Packet4f,MatrixAlignment>(data + stride*8);
+ Packet4f _L4 = ploadt<Packet4f,MatrixAlignment>(data + stride*12);
+
+ // Four 2x2 sub-matrices of the input matrix, one Packet4f each:
+ // input = [[A, B],
+ // [C, D]]
+ Packet4f A, B, C, D;
+
+ if (!StorageOrdersMatch) // orders differ: gather the blocks by interleaving loads (unpacklo/unpackhi)
+ {
+ A = vec4f_unpacklo(_L1, _L2);
+ B = vec4f_unpacklo(_L3, _L4);
+ C = vec4f_unpackhi(_L1, _L2);
+ D = vec4f_unpackhi(_L3, _L4);
+ }
+ else // orders match: take low/high halves of adjacent loads (movelh/movehl)
+ {
+ A = vec4f_movelh(_L1, _L2);
+ B = vec4f_movehl(_L2, _L1);
+ C = vec4f_movelh(_L3, _L4);
+ D = vec4f_movehl(_L4, _L3);
+ }
+
+ Packet4f AB, DC;
+
+ // AB = A# * B, where A# denotes the adjugate of A, and * denotes matrix product.
+ AB = pmul(vec4f_swizzle2(A, A, 3, 3, 0, 0), B);
+ AB = psub(AB, pmul(vec4f_swizzle2(A, A, 1, 1, 2, 2), vec4f_swizzle2(B, B, 2, 3, 0, 1)));
+
+ // DC = D# * C, computed the same way as AB.
+ DC = pmul(vec4f_swizzle2(D, D, 3, 3, 0, 0), C);
+ DC = psub(DC, pmul(vec4f_swizzle2(D, D, 1, 1, 2, 2), vec4f_swizzle2(C, C, 2, 3, 0, 1)));
+
+ // determinants of the 2x2 sub-matrices; only lane 0 of each result is used below
+ Packet4f dA, dB, dC, dD;
+
+ dA = pmul(vec4f_swizzle2(A, A, 3, 3, 1, 1), A);
+ dA = psub(dA, vec4f_movehl(dA, dA));
+
+ dB = pmul(vec4f_swizzle2(B, B, 3, 3, 1, 1), B);
+ dB = psub(dB, vec4f_movehl(dB, dB));
+
+ dC = pmul(vec4f_swizzle2(C, C, 3, 3, 1, 1), C);
+ dC = psub(dC, vec4f_movehl(dC, dC));
+
+ dD = pmul(vec4f_swizzle2(D, D, 3, 3, 1, 1), D);
+ dD = psub(dD, vec4f_movehl(dD, dD));
+
+ Packet4f d, d1, d2;
+ // d = trace(A#*B*D#*C), summed horizontally into lane 0
+ d = pmul(vec4f_swizzle2(DC, DC, 0, 2, 1, 3), AB);
+ d = padd(d, vec4f_movehl(d, d));
+ d = padd(d, vec4f_swizzle2(d, d, 1, 0, 0, 0));
+ d1 = pmul(dA, dD);
+ d2 = pmul(dB, dC);
+
+ // determinant of the input matrix, det = |A||D| + |B||C| - trace(A#*B*D#*C)
+ Packet4f det = vec4f_duplane(psub(padd(d1, d2), d), 0); // broadcast lane 0 to all lanes
+
+ // reciprocal of the determinant of the input matrix, rd = 1/det (no check for det == 0 here)
+ Packet4f rd = pdiv(pset1<Packet4f>(1.0f), det);
+
+ // Four sub-matrices of the inverse (not yet scaled by rd)
+ Packet4f iA, iB, iC, iD;
+
+ // iD = D*|A| - C*A#*B
+ iD = pmul(vec4f_swizzle2(C, C, 0, 0, 2, 2), vec4f_movelh(AB, AB));
+ iD = padd(iD, pmul(vec4f_swizzle2(C, C, 1, 1, 3, 3), vec4f_movehl(AB, AB)));
+ iD = psub(pmul(D, vec4f_duplane(dA, 0)), iD);
+
+ // iA = A*|D| - B*D#*C
+ iA = pmul(vec4f_swizzle2(B, B, 0, 0, 2, 2), vec4f_movelh(DC, DC));
+ iA = padd(iA, pmul(vec4f_swizzle2(B, B, 1, 1, 3, 3), vec4f_movehl(DC, DC)));
+ iA = psub(pmul(A, vec4f_duplane(dD, 0)), iA);
+
+ // iB = C*|B| - D * (A#B)# = C*|B| - D*B#*A
+ iB = pmul(D, vec4f_swizzle2(AB, AB, 3, 0, 3, 0));
+ iB = psub(iB, pmul(vec4f_swizzle2(D, D, 1, 0, 3, 2), vec4f_swizzle2(AB, AB, 2, 1, 2, 1)));
+ iB = psub(pmul(C, vec4f_duplane(dB, 0)), iB);
+
+ // iC = B*|C| - A * (D#C)# = B*|C| - A*C#*D
+ iC = pmul(A, vec4f_swizzle2(DC, DC, 3, 0, 3, 0));
+ iC = psub(iC, pmul(vec4f_swizzle2(A, A, 1, 0, 3, 2), vec4f_swizzle2(DC, DC, 2, 1, 2, 1)));
+ iC = psub(pmul(B, vec4f_duplane(dC, 0)), iC);
+ // Scale by 1/det, flipping the sign of lanes 1 and 2 (pattern +,-,-,+; 0x80000000 is the float sign bit).
+ const float sign_mask[4] = {0.0f, numext::bit_cast<float>(0x80000000u), numext::bit_cast<float>(0x80000000u), 0.0f};
+ const Packet4f p4f_sign_PNNP = ploadu<Packet4f>(sign_mask);
+ rd = pxor(rd, p4f_sign_PNNP);
+ iA = pmul(iA, rd);
+ iB = pmul(iB, rd);
+ iC = pmul(iC, rd);
+ iD = pmul(iD, rd);
+ // Write four outer vectors (rows or columns, per the result's storage order) of the inverse.
+ Index res_stride = result.outerStride();
+ float *res = result.data();
+
+ pstoret<float, Packet4f, ResultAlignment>(res + 0, vec4f_swizzle2(iA, iB, 3, 1, 3, 1));
+ pstoret<float, Packet4f, ResultAlignment>(res + res_stride, vec4f_swizzle2(iA, iB, 2, 0, 2, 0));
+ pstoret<float, Packet4f, ResultAlignment>(res + 2 * res_stride, vec4f_swizzle2(iC, iD, 3, 1, 3, 1));
+ pstoret<float, Packet4f, ResultAlignment>(res + 3 * res_stride, vec4f_swizzle2(iC, iD, 2, 0, 2, 0));
+ }
+};
+
+#if !(defined EIGEN_VECTORIZE_NEON && !(EIGEN_ARCH_ARM64 && !EIGEN_APPLE_DOUBLE_NEON_BUG)) // excluded on NEON unless ARM64 without EIGEN_APPLE_DOUBLE_NEON_BUG
+// Same algorithm as the float specialization above, except that each 2x2 operand is split
+// into two Packet2d halves (e.g. A1/A2), so every step operates on a pair of registers.
+template <typename MatrixType, typename ResultType>
+struct compute_inverse_size4<Architecture::Target, double, MatrixType, ResultType>
+{
+ enum
+ {
+ MatrixAlignment = traits<MatrixType>::Alignment,
+ ResultAlignment = traits<ResultType>::Alignment,
+ StorageOrdersMatch = (MatrixType::Flags & RowMajorBit) == (ResultType::Flags & RowMajorBit) // same RowMajorBit on input and result
+ };
+ typedef typename conditional<(MatrixType::Flags & LinearAccessBit),
+ MatrixType const &,
+ typename MatrixType::PlainObject>::type
+ ActualMatrixType; // reference if linearly accessible, otherwise an evaluated PlainObject copy
+ // run() reads the 16 coefficients as eight Packet2d loads from matrix.data() and stores the inverse into result.
+ static void run(const MatrixType &mat, ResultType &result)
+ {
+ ActualMatrixType matrix(mat);
+
+ // Four 2x2 sub-matrices of the input matrix, each further divided into an upper and a lower
+ // row, e.g. A1 is the upper row of A and A2 the lower row of A:
+ // input = [[A, B], = [[[A1, [B1,
+ // [C, D]] A2], B2]],
+ // [[C1, [D1,
+ // C2], D2]]]
+
+ Packet2d A1, A2, B1, B2, C1, C2, D1, D2;
+
+ const double* data = matrix.data();
+ const Index stride = matrix.innerStride();
+ if (StorageOrdersMatch) // orders match: each load is already one half of a block
+ {
+ A1 = ploadt<Packet2d,MatrixAlignment>(data + stride*0);
+ B1 = ploadt<Packet2d,MatrixAlignment>(data + stride*2);
+ A2 = ploadt<Packet2d,MatrixAlignment>(data + stride*4);
+ B2 = ploadt<Packet2d,MatrixAlignment>(data + stride*6);
+ C1 = ploadt<Packet2d,MatrixAlignment>(data + stride*8);
+ D1 = ploadt<Packet2d,MatrixAlignment>(data + stride*10);
+ C2 = ploadt<Packet2d,MatrixAlignment>(data + stride*12);
+ D2 = ploadt<Packet2d,MatrixAlignment>(data + stride*14);
+ }
+ else // orders differ: load, then regroup each pair with unpacklo/unpackhi
+ {
+ Packet2d temp;
+ A1 = ploadt<Packet2d,MatrixAlignment>(data + stride*0);
+ C1 = ploadt<Packet2d,MatrixAlignment>(data + stride*2);
+ A2 = ploadt<Packet2d,MatrixAlignment>(data + stride*4);
+ C2 = ploadt<Packet2d,MatrixAlignment>(data + stride*6);
+ temp = A1;
+ A1 = vec2d_unpacklo(A1, A2);
+ A2 = vec2d_unpackhi(temp, A2);
+
+ temp = C1;
+ C1 = vec2d_unpacklo(C1, C2);
+ C2 = vec2d_unpackhi(temp, C2);
+
+ B1 = ploadt<Packet2d,MatrixAlignment>(data + stride*8);
+ D1 = ploadt<Packet2d,MatrixAlignment>(data + stride*10);
+ B2 = ploadt<Packet2d,MatrixAlignment>(data + stride*12);
+ D2 = ploadt<Packet2d,MatrixAlignment>(data + stride*14);
+
+ temp = B1;
+ B1 = vec2d_unpacklo(B1, B2);
+ B2 = vec2d_unpackhi(temp, B2);
+
+ temp = D1;
+ D1 = vec2d_unpacklo(D1, D2);
+ D2 = vec2d_unpackhi(temp, D2);
+ }
+
+ // determinants of the 2x2 sub-matrices; lane 0 of each holds the value used below
+ Packet2d dA, dB, dC, dD;
+
+ dA = vec2d_swizzle2(A2, A2, 1);
+ dA = pmul(A1, dA);
+ dA = psub(dA, vec2d_duplane(dA, 1));
+
+ dB = vec2d_swizzle2(B2, B2, 1);
+ dB = pmul(B1, dB);
+ dB = psub(dB, vec2d_duplane(dB, 1));
+
+ dC = vec2d_swizzle2(C2, C2, 1);
+ dC = pmul(C1, dC);
+ dC = psub(dC, vec2d_duplane(dC, 1));
+
+ dD = vec2d_swizzle2(D2, D2, 1);
+ dD = pmul(D1, dD);
+ dD = psub(dD, vec2d_duplane(dD, 1));
+
+ Packet2d DC1, DC2, AB1, AB2;
+
+ // AB = A# * B, where A# denotes the adjugate of A, and * denotes matrix product.
+ AB1 = pmul(B1, vec2d_duplane(A2, 1));
+ AB2 = pmul(B2, vec2d_duplane(A1, 0));
+ AB1 = psub(AB1, pmul(B2, vec2d_duplane(A1, 1)));
+ AB2 = psub(AB2, pmul(B1, vec2d_duplane(A2, 0)));
+
+ // DC = D# * C, computed the same way as AB
+ DC1 = pmul(C1, vec2d_duplane(D2, 1));
+ DC2 = pmul(C2, vec2d_duplane(D1, 0));
+ DC1 = psub(DC1, pmul(C2, vec2d_duplane(D1, 1)));
+ DC2 = psub(DC2, pmul(C1, vec2d_duplane(D2, 0)));
+
+ Packet2d d1, d2;
+
+ // determinant of the input matrix, det = |A||D| + |B||C| - trace(A#*B*D#*C)
+ Packet2d det;
+
+ // reciprocal of the determinant of the input matrix, rd = 1/det
+ Packet2d rd;
+ // rd first holds trace(A#*B*D#*C), summed into lane 0; it is reused for 1/det afterwards.
+ d1 = pmul(AB1, vec2d_swizzle2(DC1, DC2, 0));
+ d2 = pmul(AB2, vec2d_swizzle2(DC1, DC2, 3));
+ rd = padd(d1, d2);
+ rd = padd(rd, vec2d_duplane(rd, 1));
+
+ d1 = pmul(dA, dD);
+ d2 = pmul(dB, dC);
+
+ det = padd(d1, d2);
+ det = psub(det, rd);
+ det = vec2d_duplane(det, 0); // broadcast lane 0
+ rd = pdiv(pset1<Packet2d>(1.0), det); // no check for det == 0 here
+
+ // rows of the four sub-matrices of the inverse (scaled by 1/det only when stored)
+ Packet2d iA1, iA2, iB1, iB2, iC1, iC2, iD1, iD2;
+
+ // iD = D*|A| - C*A#*B
+ iD1 = pmul(AB1, vec2d_duplane(C1, 0));
+ iD2 = pmul(AB1, vec2d_duplane(C2, 0));
+ iD1 = padd(iD1, pmul(AB2, vec2d_duplane(C1, 1)));
+ iD2 = padd(iD2, pmul(AB2, vec2d_duplane(C2, 1)));
+ dA = vec2d_duplane(dA, 0);
+ iD1 = psub(pmul(D1, dA), iD1);
+ iD2 = psub(pmul(D2, dA), iD2);
+
+ // iA = A*|D| - B*D#*C
+ iA1 = pmul(DC1, vec2d_duplane(B1, 0));
+ iA2 = pmul(DC1, vec2d_duplane(B2, 0));
+ iA1 = padd(iA1, pmul(DC2, vec2d_duplane(B1, 1)));
+ iA2 = padd(iA2, pmul(DC2, vec2d_duplane(B2, 1)));
+ dD = vec2d_duplane(dD, 0);
+ iA1 = psub(pmul(A1, dD), iA1);
+ iA2 = psub(pmul(A2, dD), iA2);
+
+ // iB = C*|B| - D * (A#B)# = C*|B| - D*B#*A
+ iB1 = pmul(D1, vec2d_swizzle2(AB2, AB1, 1));
+ iB2 = pmul(D2, vec2d_swizzle2(AB2, AB1, 1));
+ iB1 = psub(iB1, pmul(vec2d_swizzle2(D1, D1, 1), vec2d_swizzle2(AB2, AB1, 2)));
+ iB2 = psub(iB2, pmul(vec2d_swizzle2(D2, D2, 1), vec2d_swizzle2(AB2, AB1, 2)));
+ dB = vec2d_duplane(dB, 0);
+ iB1 = psub(pmul(C1, dB), iB1);
+ iB2 = psub(pmul(C2, dB), iB2);
+
+ // iC = B*|C| - A * (D#C)# = B*|C| - A*C#*D
+ iC1 = pmul(A1, vec2d_swizzle2(DC2, DC1, 1));
+ iC2 = pmul(A2, vec2d_swizzle2(DC2, DC1, 1));
+ iC1 = psub(iC1, pmul(vec2d_swizzle2(A1, A1, 1), vec2d_swizzle2(DC2, DC1, 2)));
+ iC2 = psub(iC2, pmul(vec2d_swizzle2(A2, A2, 1), vec2d_swizzle2(DC2, DC1, 2)));
+ dC = vec2d_duplane(dC, 0);
+ iC1 = psub(pmul(B1, dC), iC1);
+ iC2 = psub(pmul(B2, dC), iC2);
+ // d1 = rd with lane 1 negated (PN), d2 = rd with lane 0 negated (NP); 0x8000000000000000 is the double sign bit.
+ const double sign_mask1[2] = {0.0, numext::bit_cast<double>(0x8000000000000000ull)};
+ const double sign_mask2[2] = {numext::bit_cast<double>(0x8000000000000000ull), 0.0};
+ const Packet2d sign_PN = ploadu<Packet2d>(sign_mask1);
+ const Packet2d sign_NP = ploadu<Packet2d>(sign_mask2);
+ d1 = pxor(rd, sign_PN);
+ d2 = pxor(rd, sign_NP);
+ // Write the inverse two coefficients per store, stepping by the result's outer stride.
+ Index res_stride = result.outerStride();
+ double *res = result.data();
+ pstoret<double, Packet2d, ResultAlignment>(res + 0, pmul(vec2d_swizzle2(iA2, iA1, 3), d1));
+ pstoret<double, Packet2d, ResultAlignment>(res + res_stride, pmul(vec2d_swizzle2(iA2, iA1, 0), d2));
+ pstoret<double, Packet2d, ResultAlignment>(res + 2, pmul(vec2d_swizzle2(iB2, iB1, 3), d1));
+ pstoret<double, Packet2d, ResultAlignment>(res + res_stride + 2, pmul(vec2d_swizzle2(iB2, iB1, 0), d2));
+ pstoret<double, Packet2d, ResultAlignment>(res + 2 * res_stride, pmul(vec2d_swizzle2(iC2, iC1, 3), d1));
+ pstoret<double, Packet2d, ResultAlignment>(res + 3 * res_stride, pmul(vec2d_swizzle2(iC2, iC1, 0), d2));
+ pstoret<double, Packet2d, ResultAlignment>(res + 2 * res_stride + 2, pmul(vec2d_swizzle2(iD2, iD1, 3), d1));
+ pstoret<double, Packet2d, ResultAlignment>(res + 3 * res_stride + 2, pmul(vec2d_swizzle2(iD2, iD1, 0), d2));
+ }
+};
+#endif // NEON guard
+} // namespace internal
+} // namespace Eigen
+#endif
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/LU/arch/Inverse_SSE.h b/wpimath/src/main/native/eigeninclude/Eigen/src/LU/arch/Inverse_SSE.h
deleted file mode 100644
index ebb64a6..0000000
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/LU/arch/Inverse_SSE.h
+++ /dev/null
@@ -1,338 +0,0 @@
-// This file is part of Eigen, a lightweight C++ template library
-// for linear algebra.
-//
-// Copyright (C) 2001 Intel Corporation
-// Copyright (C) 2010 Gael Guennebaud <gael.guennebaud@inria.fr>
-// Copyright (C) 2009 Benoit Jacob <jacob.benoit.1@gmail.com>
-//
-// This Source Code Form is subject to the terms of the Mozilla
-// Public License v. 2.0. If a copy of the MPL was not distributed
-// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-// The SSE code for the 4x4 float and double matrix inverse in this file
-// comes from the following Intel's library:
-// http://software.intel.com/en-us/articles/optimized-matrix-library-for-use-with-the-intel-pentiumr-4-processors-sse2-instructions/
-//
-// Here is the respective copyright and license statement:
-//
-// Copyright (c) 2001 Intel Corporation.
-//
-// Permition is granted to use, copy, distribute and prepare derivative works
-// of this library for any purpose and without fee, provided, that the above
-// copyright notice and this statement appear in all copies.
-// Intel makes no representations about the suitability of this software for
-// any purpose, and specifically disclaims all warranties.
-// See LEGAL.TXT for all the legal information.
-
-#ifndef EIGEN_INVERSE_SSE_H
-#define EIGEN_INVERSE_SSE_H
-
-namespace Eigen {
-
-namespace internal {
-
-template<typename MatrixType, typename ResultType>
-struct compute_inverse_size4<Architecture::SSE, float, MatrixType, ResultType>
-{
- enum {
- MatrixAlignment = traits<MatrixType>::Alignment,
- ResultAlignment = traits<ResultType>::Alignment,
- StorageOrdersMatch = (MatrixType::Flags&RowMajorBit) == (ResultType::Flags&RowMajorBit)
- };
- typedef typename conditional<(MatrixType::Flags&LinearAccessBit),MatrixType const &,typename MatrixType::PlainObject>::type ActualMatrixType;
-
- static void run(const MatrixType& mat, ResultType& result)
- {
- ActualMatrixType matrix(mat);
- EIGEN_ALIGN16 const unsigned int _Sign_PNNP[4] = { 0x00000000, 0x80000000, 0x80000000, 0x00000000 };
-
- // Load the full matrix into registers
- __m128 _L1 = matrix.template packet<MatrixAlignment>( 0);
- __m128 _L2 = matrix.template packet<MatrixAlignment>( 4);
- __m128 _L3 = matrix.template packet<MatrixAlignment>( 8);
- __m128 _L4 = matrix.template packet<MatrixAlignment>(12);
-
- // The inverse is calculated using "Divide and Conquer" technique. The
- // original matrix is divide into four 2x2 sub-matrices. Since each
- // register holds four matrix element, the smaller matrices are
- // represented as a registers. Hence we get a better locality of the
- // calculations.
-
- __m128 A, B, C, D; // the four sub-matrices
- if(!StorageOrdersMatch)
- {
- A = _mm_unpacklo_ps(_L1, _L2);
- B = _mm_unpacklo_ps(_L3, _L4);
- C = _mm_unpackhi_ps(_L1, _L2);
- D = _mm_unpackhi_ps(_L3, _L4);
- }
- else
- {
- A = _mm_movelh_ps(_L1, _L2);
- B = _mm_movehl_ps(_L2, _L1);
- C = _mm_movelh_ps(_L3, _L4);
- D = _mm_movehl_ps(_L4, _L3);
- }
-
- __m128 iA, iB, iC, iD, // partial inverse of the sub-matrices
- DC, AB;
- __m128 dA, dB, dC, dD; // determinant of the sub-matrices
- __m128 det, d, d1, d2;
- __m128 rd; // reciprocal of the determinant
-
- // AB = A# * B
- AB = _mm_mul_ps(_mm_shuffle_ps(A,A,0x0F), B);
- AB = _mm_sub_ps(AB,_mm_mul_ps(_mm_shuffle_ps(A,A,0xA5), _mm_shuffle_ps(B,B,0x4E)));
- // DC = D# * C
- DC = _mm_mul_ps(_mm_shuffle_ps(D,D,0x0F), C);
- DC = _mm_sub_ps(DC,_mm_mul_ps(_mm_shuffle_ps(D,D,0xA5), _mm_shuffle_ps(C,C,0x4E)));
-
- // dA = |A|
- dA = _mm_mul_ps(_mm_shuffle_ps(A, A, 0x5F),A);
- dA = _mm_sub_ss(dA, _mm_movehl_ps(dA,dA));
- // dB = |B|
- dB = _mm_mul_ps(_mm_shuffle_ps(B, B, 0x5F),B);
- dB = _mm_sub_ss(dB, _mm_movehl_ps(dB,dB));
-
- // dC = |C|
- dC = _mm_mul_ps(_mm_shuffle_ps(C, C, 0x5F),C);
- dC = _mm_sub_ss(dC, _mm_movehl_ps(dC,dC));
- // dD = |D|
- dD = _mm_mul_ps(_mm_shuffle_ps(D, D, 0x5F),D);
- dD = _mm_sub_ss(dD, _mm_movehl_ps(dD,dD));
-
- // d = trace(AB*DC) = trace(A#*B*D#*C)
- d = _mm_mul_ps(_mm_shuffle_ps(DC,DC,0xD8),AB);
-
- // iD = C*A#*B
- iD = _mm_mul_ps(_mm_shuffle_ps(C,C,0xA0), _mm_movelh_ps(AB,AB));
- iD = _mm_add_ps(iD,_mm_mul_ps(_mm_shuffle_ps(C,C,0xF5), _mm_movehl_ps(AB,AB)));
- // iA = B*D#*C
- iA = _mm_mul_ps(_mm_shuffle_ps(B,B,0xA0), _mm_movelh_ps(DC,DC));
- iA = _mm_add_ps(iA,_mm_mul_ps(_mm_shuffle_ps(B,B,0xF5), _mm_movehl_ps(DC,DC)));
-
- // d = trace(AB*DC) = trace(A#*B*D#*C) [continue]
- d = _mm_add_ps(d, _mm_movehl_ps(d, d));
- d = _mm_add_ss(d, _mm_shuffle_ps(d, d, 1));
- d1 = _mm_mul_ss(dA,dD);
- d2 = _mm_mul_ss(dB,dC);
-
- // iD = D*|A| - C*A#*B
- iD = _mm_sub_ps(_mm_mul_ps(D,_mm_shuffle_ps(dA,dA,0)), iD);
-
- // iA = A*|D| - B*D#*C;
- iA = _mm_sub_ps(_mm_mul_ps(A,_mm_shuffle_ps(dD,dD,0)), iA);
-
- // det = |A|*|D| + |B|*|C| - trace(A#*B*D#*C)
- det = _mm_sub_ss(_mm_add_ss(d1,d2),d);
- rd = _mm_div_ss(_mm_set_ss(1.0f), det);
-
-// #ifdef ZERO_SINGULAR
-// rd = _mm_and_ps(_mm_cmpneq_ss(det,_mm_setzero_ps()), rd);
-// #endif
-
- // iB = D * (A#B)# = D*B#*A
- iB = _mm_mul_ps(D, _mm_shuffle_ps(AB,AB,0x33));
- iB = _mm_sub_ps(iB, _mm_mul_ps(_mm_shuffle_ps(D,D,0xB1), _mm_shuffle_ps(AB,AB,0x66)));
- // iC = A * (D#C)# = A*C#*D
- iC = _mm_mul_ps(A, _mm_shuffle_ps(DC,DC,0x33));
- iC = _mm_sub_ps(iC, _mm_mul_ps(_mm_shuffle_ps(A,A,0xB1), _mm_shuffle_ps(DC,DC,0x66)));
-
- rd = _mm_shuffle_ps(rd,rd,0);
- rd = _mm_xor_ps(rd, _mm_load_ps((float*)_Sign_PNNP));
-
- // iB = C*|B| - D*B#*A
- iB = _mm_sub_ps(_mm_mul_ps(C,_mm_shuffle_ps(dB,dB,0)), iB);
-
- // iC = B*|C| - A*C#*D;
- iC = _mm_sub_ps(_mm_mul_ps(B,_mm_shuffle_ps(dC,dC,0)), iC);
-
- // iX = iX / det
- iA = _mm_mul_ps(rd,iA);
- iB = _mm_mul_ps(rd,iB);
- iC = _mm_mul_ps(rd,iC);
- iD = _mm_mul_ps(rd,iD);
-
- Index res_stride = result.outerStride();
- float* res = result.data();
- pstoret<float, Packet4f, ResultAlignment>(res+0, _mm_shuffle_ps(iA,iB,0x77));
- pstoret<float, Packet4f, ResultAlignment>(res+res_stride, _mm_shuffle_ps(iA,iB,0x22));
- pstoret<float, Packet4f, ResultAlignment>(res+2*res_stride, _mm_shuffle_ps(iC,iD,0x77));
- pstoret<float, Packet4f, ResultAlignment>(res+3*res_stride, _mm_shuffle_ps(iC,iD,0x22));
- }
-
-};
-
-template<typename MatrixType, typename ResultType>
-struct compute_inverse_size4<Architecture::SSE, double, MatrixType, ResultType>
-{
- enum {
- MatrixAlignment = traits<MatrixType>::Alignment,
- ResultAlignment = traits<ResultType>::Alignment,
- StorageOrdersMatch = (MatrixType::Flags&RowMajorBit) == (ResultType::Flags&RowMajorBit)
- };
- typedef typename conditional<(MatrixType::Flags&LinearAccessBit),MatrixType const &,typename MatrixType::PlainObject>::type ActualMatrixType;
-
- static void run(const MatrixType& mat, ResultType& result)
- {
- ActualMatrixType matrix(mat);
- const __m128d _Sign_NP = _mm_castsi128_pd(_mm_set_epi32(0x0,0x0,0x80000000,0x0));
- const __m128d _Sign_PN = _mm_castsi128_pd(_mm_set_epi32(0x80000000,0x0,0x0,0x0));
-
- // The inverse is calculated using "Divide and Conquer" technique. The
- // original matrix is divide into four 2x2 sub-matrices. Since each
- // register of the matrix holds two elements, the smaller matrices are
- // consisted of two registers. Hence we get a better locality of the
- // calculations.
-
- // the four sub-matrices
- __m128d A1, A2, B1, B2, C1, C2, D1, D2;
-
- if(StorageOrdersMatch)
- {
- A1 = matrix.template packet<MatrixAlignment>( 0); B1 = matrix.template packet<MatrixAlignment>( 2);
- A2 = matrix.template packet<MatrixAlignment>( 4); B2 = matrix.template packet<MatrixAlignment>( 6);
- C1 = matrix.template packet<MatrixAlignment>( 8); D1 = matrix.template packet<MatrixAlignment>(10);
- C2 = matrix.template packet<MatrixAlignment>(12); D2 = matrix.template packet<MatrixAlignment>(14);
- }
- else
- {
- __m128d tmp;
- A1 = matrix.template packet<MatrixAlignment>( 0); C1 = matrix.template packet<MatrixAlignment>( 2);
- A2 = matrix.template packet<MatrixAlignment>( 4); C2 = matrix.template packet<MatrixAlignment>( 6);
- tmp = A1;
- A1 = _mm_unpacklo_pd(A1,A2);
- A2 = _mm_unpackhi_pd(tmp,A2);
- tmp = C1;
- C1 = _mm_unpacklo_pd(C1,C2);
- C2 = _mm_unpackhi_pd(tmp,C2);
-
- B1 = matrix.template packet<MatrixAlignment>( 8); D1 = matrix.template packet<MatrixAlignment>(10);
- B2 = matrix.template packet<MatrixAlignment>(12); D2 = matrix.template packet<MatrixAlignment>(14);
- tmp = B1;
- B1 = _mm_unpacklo_pd(B1,B2);
- B2 = _mm_unpackhi_pd(tmp,B2);
- tmp = D1;
- D1 = _mm_unpacklo_pd(D1,D2);
- D2 = _mm_unpackhi_pd(tmp,D2);
- }
-
- __m128d iA1, iA2, iB1, iB2, iC1, iC2, iD1, iD2, // partial invese of the sub-matrices
- DC1, DC2, AB1, AB2;
- __m128d dA, dB, dC, dD; // determinant of the sub-matrices
- __m128d det, d1, d2, rd;
-
- // dA = |A|
- dA = _mm_shuffle_pd(A2, A2, 1);
- dA = _mm_mul_pd(A1, dA);
- dA = _mm_sub_sd(dA, _mm_shuffle_pd(dA,dA,3));
- // dB = |B|
- dB = _mm_shuffle_pd(B2, B2, 1);
- dB = _mm_mul_pd(B1, dB);
- dB = _mm_sub_sd(dB, _mm_shuffle_pd(dB,dB,3));
-
- // AB = A# * B
- AB1 = _mm_mul_pd(B1, _mm_shuffle_pd(A2,A2,3));
- AB2 = _mm_mul_pd(B2, _mm_shuffle_pd(A1,A1,0));
- AB1 = _mm_sub_pd(AB1, _mm_mul_pd(B2, _mm_shuffle_pd(A1,A1,3)));
- AB2 = _mm_sub_pd(AB2, _mm_mul_pd(B1, _mm_shuffle_pd(A2,A2,0)));
-
- // dC = |C|
- dC = _mm_shuffle_pd(C2, C2, 1);
- dC = _mm_mul_pd(C1, dC);
- dC = _mm_sub_sd(dC, _mm_shuffle_pd(dC,dC,3));
- // dD = |D|
- dD = _mm_shuffle_pd(D2, D2, 1);
- dD = _mm_mul_pd(D1, dD);
- dD = _mm_sub_sd(dD, _mm_shuffle_pd(dD,dD,3));
-
- // DC = D# * C
- DC1 = _mm_mul_pd(C1, _mm_shuffle_pd(D2,D2,3));
- DC2 = _mm_mul_pd(C2, _mm_shuffle_pd(D1,D1,0));
- DC1 = _mm_sub_pd(DC1, _mm_mul_pd(C2, _mm_shuffle_pd(D1,D1,3)));
- DC2 = _mm_sub_pd(DC2, _mm_mul_pd(C1, _mm_shuffle_pd(D2,D2,0)));
-
- // rd = trace(AB*DC) = trace(A#*B*D#*C)
- d1 = _mm_mul_pd(AB1, _mm_shuffle_pd(DC1, DC2, 0));
- d2 = _mm_mul_pd(AB2, _mm_shuffle_pd(DC1, DC2, 3));
- rd = _mm_add_pd(d1, d2);
- rd = _mm_add_sd(rd, _mm_shuffle_pd(rd, rd,3));
-
- // iD = C*A#*B
- iD1 = _mm_mul_pd(AB1, _mm_shuffle_pd(C1,C1,0));
- iD2 = _mm_mul_pd(AB1, _mm_shuffle_pd(C2,C2,0));
- iD1 = _mm_add_pd(iD1, _mm_mul_pd(AB2, _mm_shuffle_pd(C1,C1,3)));
- iD2 = _mm_add_pd(iD2, _mm_mul_pd(AB2, _mm_shuffle_pd(C2,C2,3)));
-
- // iA = B*D#*C
- iA1 = _mm_mul_pd(DC1, _mm_shuffle_pd(B1,B1,0));
- iA2 = _mm_mul_pd(DC1, _mm_shuffle_pd(B2,B2,0));
- iA1 = _mm_add_pd(iA1, _mm_mul_pd(DC2, _mm_shuffle_pd(B1,B1,3)));
- iA2 = _mm_add_pd(iA2, _mm_mul_pd(DC2, _mm_shuffle_pd(B2,B2,3)));
-
- // iD = D*|A| - C*A#*B
- dA = _mm_shuffle_pd(dA,dA,0);
- iD1 = _mm_sub_pd(_mm_mul_pd(D1, dA), iD1);
- iD2 = _mm_sub_pd(_mm_mul_pd(D2, dA), iD2);
-
- // iA = A*|D| - B*D#*C;
- dD = _mm_shuffle_pd(dD,dD,0);
- iA1 = _mm_sub_pd(_mm_mul_pd(A1, dD), iA1);
- iA2 = _mm_sub_pd(_mm_mul_pd(A2, dD), iA2);
-
- d1 = _mm_mul_sd(dA, dD);
- d2 = _mm_mul_sd(dB, dC);
-
- // iB = D * (A#B)# = D*B#*A
- iB1 = _mm_mul_pd(D1, _mm_shuffle_pd(AB2,AB1,1));
- iB2 = _mm_mul_pd(D2, _mm_shuffle_pd(AB2,AB1,1));
- iB1 = _mm_sub_pd(iB1, _mm_mul_pd(_mm_shuffle_pd(D1,D1,1), _mm_shuffle_pd(AB2,AB1,2)));
- iB2 = _mm_sub_pd(iB2, _mm_mul_pd(_mm_shuffle_pd(D2,D2,1), _mm_shuffle_pd(AB2,AB1,2)));
-
- // det = |A|*|D| + |B|*|C| - trace(A#*B*D#*C)
- det = _mm_add_sd(d1, d2);
- det = _mm_sub_sd(det, rd);
-
- // iC = A * (D#C)# = A*C#*D
- iC1 = _mm_mul_pd(A1, _mm_shuffle_pd(DC2,DC1,1));
- iC2 = _mm_mul_pd(A2, _mm_shuffle_pd(DC2,DC1,1));
- iC1 = _mm_sub_pd(iC1, _mm_mul_pd(_mm_shuffle_pd(A1,A1,1), _mm_shuffle_pd(DC2,DC1,2)));
- iC2 = _mm_sub_pd(iC2, _mm_mul_pd(_mm_shuffle_pd(A2,A2,1), _mm_shuffle_pd(DC2,DC1,2)));
-
- rd = _mm_div_sd(_mm_set_sd(1.0), det);
-// #ifdef ZERO_SINGULAR
-// rd = _mm_and_pd(_mm_cmpneq_sd(det,_mm_setzero_pd()), rd);
-// #endif
- rd = _mm_shuffle_pd(rd,rd,0);
-
- // iB = C*|B| - D*B#*A
- dB = _mm_shuffle_pd(dB,dB,0);
- iB1 = _mm_sub_pd(_mm_mul_pd(C1, dB), iB1);
- iB2 = _mm_sub_pd(_mm_mul_pd(C2, dB), iB2);
-
- d1 = _mm_xor_pd(rd, _Sign_PN);
- d2 = _mm_xor_pd(rd, _Sign_NP);
-
- // iC = B*|C| - A*C#*D;
- dC = _mm_shuffle_pd(dC,dC,0);
- iC1 = _mm_sub_pd(_mm_mul_pd(B1, dC), iC1);
- iC2 = _mm_sub_pd(_mm_mul_pd(B2, dC), iC2);
-
- Index res_stride = result.outerStride();
- double* res = result.data();
- pstoret<double, Packet2d, ResultAlignment>(res+0, _mm_mul_pd(_mm_shuffle_pd(iA2, iA1, 3), d1));
- pstoret<double, Packet2d, ResultAlignment>(res+res_stride, _mm_mul_pd(_mm_shuffle_pd(iA2, iA1, 0), d2));
- pstoret<double, Packet2d, ResultAlignment>(res+2, _mm_mul_pd(_mm_shuffle_pd(iB2, iB1, 3), d1));
- pstoret<double, Packet2d, ResultAlignment>(res+res_stride+2, _mm_mul_pd(_mm_shuffle_pd(iB2, iB1, 0), d2));
- pstoret<double, Packet2d, ResultAlignment>(res+2*res_stride, _mm_mul_pd(_mm_shuffle_pd(iC2, iC1, 3), d1));
- pstoret<double, Packet2d, ResultAlignment>(res+3*res_stride, _mm_mul_pd(_mm_shuffle_pd(iC2, iC1, 0), d2));
- pstoret<double, Packet2d, ResultAlignment>(res+2*res_stride+2,_mm_mul_pd(_mm_shuffle_pd(iD2, iD1, 3), d1));
- pstoret<double, Packet2d, ResultAlignment>(res+3*res_stride+2,_mm_mul_pd(_mm_shuffle_pd(iD2, iD1, 0), d2));
- }
-};
-
-} // end namespace internal
-
-} // end namespace Eigen
-
-#endif // EIGEN_INVERSE_SSE_H
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/QR/ColPivHouseholderQR.h b/wpimath/src/main/native/eigeninclude/Eigen/src/QR/ColPivHouseholderQR.h
index a7b47d5..9b677e9 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/QR/ColPivHouseholderQR.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/QR/ColPivHouseholderQR.h
@@ -17,6 +17,9 @@
template<typename _MatrixType> struct traits<ColPivHouseholderQR<_MatrixType> >
: traits<_MatrixType>
{
+ typedef MatrixXpr XprKind; // solver-kind traits; presumably required now that the class derives from SolverBase
+ typedef SolverStorage StorageKind;
+ typedef int StorageIndex; // int, replacing the class-level MatrixType::StorageIndex typedef that carried "FIXME should be int"
enum { Flags = 0 };
};
@@ -46,20 +49,19 @@
* \sa MatrixBase::colPivHouseholderQr()
*/
template<typename _MatrixType> class ColPivHouseholderQR
+ : public SolverBase<ColPivHouseholderQR<_MatrixType> >
{
public:
typedef _MatrixType MatrixType;
+ typedef SolverBase<ColPivHouseholderQR> Base;
+ friend class SolverBase<ColPivHouseholderQR>;
+
+ EIGEN_GENERIC_PUBLIC_INTERFACE(ColPivHouseholderQR)
enum {
- RowsAtCompileTime = MatrixType::RowsAtCompileTime,
- ColsAtCompileTime = MatrixType::ColsAtCompileTime,
MaxRowsAtCompileTime = MatrixType::MaxRowsAtCompileTime,
MaxColsAtCompileTime = MatrixType::MaxColsAtCompileTime
};
- typedef typename MatrixType::Scalar Scalar;
- typedef typename MatrixType::RealScalar RealScalar;
- // FIXME should be int
- typedef typename MatrixType::StorageIndex StorageIndex;
typedef typename internal::plain_diag_type<MatrixType>::type HCoeffsType;
typedef PermutationMatrix<ColsAtCompileTime, MaxColsAtCompileTime> PermutationType;
typedef typename internal::plain_row_type<MatrixType, Index>::type IntRowVectorType;
@@ -156,6 +158,7 @@
computeInPlace();
}
+ #ifdef EIGEN_PARSED_BY_DOXYGEN
/** This method finds a solution x to the equation Ax=b, where A is the matrix of which
* *this is the QR decomposition, if any exists.
*
@@ -172,11 +175,8 @@
*/
template<typename Rhs>
inline const Solve<ColPivHouseholderQR, Rhs>
- solve(const MatrixBase<Rhs>& b) const
- {
- eigen_assert(m_isInitialized && "ColPivHouseholderQR is not initialized.");
- return Solve<ColPivHouseholderQR, Rhs>(*this, b.derived());
- }
+ solve(const MatrixBase<Rhs>& b) const;
+ #endif
HouseholderSequenceType householderQ() const;
HouseholderSequenceType matrixQ() const
@@ -402,7 +402,7 @@
*/
RealScalar maxPivot() const { return m_maxpivot; }
- /** \brief Reports whether the QR factorization was succesful.
+ /** \brief Reports whether the QR factorization was successful.
*
* \note This function always returns \c Success. It is provided for compatibility
* with other factorization routines.
@@ -416,8 +416,10 @@
#ifndef EIGEN_PARSED_BY_DOXYGEN
template<typename RhsType, typename DstType>
- EIGEN_DEVICE_FUNC
void _solve_impl(const RhsType &rhs, DstType &dst) const;
+
+ template<bool Conjugate, typename RhsType, typename DstType>
+ void _solve_impl_transposed(const RhsType &rhs, DstType &dst) const;
#endif
protected:
@@ -584,8 +586,6 @@
template<typename RhsType, typename DstType>
void ColPivHouseholderQR<_MatrixType>::_solve_impl(const RhsType &rhs, DstType &dst) const
{
- eigen_assert(rhs.rows() == rows());
-
const Index nonzero_pivots = nonzeroPivots();
if(nonzero_pivots == 0)
@@ -596,11 +596,7 @@
typename RhsType::PlainObject c(rhs);
- // Note that the matrix Q = H_0^* H_1^*... so its inverse is Q^* = (H_0 H_1 ...)^T
- c.applyOnTheLeft(householderSequence(m_qr, m_hCoeffs)
- .setLength(nonzero_pivots)
- .transpose()
- );
+ c.applyOnTheLeft(householderQ().setLength(nonzero_pivots).adjoint() );
m_qr.topLeftCorner(nonzero_pivots, nonzero_pivots)
.template triangularView<Upper>()
@@ -609,6 +605,31 @@
for(Index i = 0; i < nonzero_pivots; ++i) dst.row(m_colsPermutation.indices().coeff(i)) = c.row(i);
for(Index i = nonzero_pivots; i < cols(); ++i) dst.row(m_colsPermutation.indices().coeff(i)).setZero();
}
+// Solve using the transposed factors; Conjugate additionally conjugates them (presumably backs transpose()/adjoint().solve()).
+template<typename _MatrixType>
+template<bool Conjugate, typename RhsType, typename DstType>
+void ColPivHouseholderQR<_MatrixType>::_solve_impl_transposed(const RhsType &rhs, DstType &dst) const
+{
+ const Index nonzero_pivots = nonzeroPivots();
+ // No nonzero pivots: the result is all zeros.
+ if(nonzero_pivots == 0)
+ {
+ dst.setZero();
+ return;
+ }
+ // Apply the transposed column permutation first: c = P^T * rhs.
+ typename RhsType::PlainObject c(m_colsPermutation.transpose()*rhs);
+ // Solve in place with the transposed (and, if Conjugate, conjugated) leading upper-triangular block of m_qr.
+ m_qr.topLeftCorner(nonzero_pivots, nonzero_pivots)
+ .template triangularView<Upper>()
+ .transpose().template conjugateIf<Conjugate>()
+ .solveInPlace(c.topRows(nonzero_pivots));
+ // Keep the solved rows and zero the remaining rows.
+ dst.topRows(nonzero_pivots) = c.topRows(nonzero_pivots);
+ dst.bottomRows(rows()-nonzero_pivots).setZero();
+ // Finally apply the Householder Q (truncated to nonzero_pivots reflectors), conjugated when Conjugate is false.
+ dst.applyOnTheLeft(householderQ().setLength(nonzero_pivots).template conjugateIf<!Conjugate>() );
+}
#endif
namespace internal {
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/QR/CompleteOrthogonalDecomposition.h b/wpimath/src/main/native/eigeninclude/Eigen/src/QR/CompleteOrthogonalDecomposition.h
index 34c637b..486d337 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/QR/CompleteOrthogonalDecomposition.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/QR/CompleteOrthogonalDecomposition.h
@@ -16,6 +16,9 @@
template <typename _MatrixType>
struct traits<CompleteOrthogonalDecomposition<_MatrixType> >
: traits<_MatrixType> {
+ typedef MatrixXpr XprKind; // solver-kind traits; presumably required now that the class derives from SolverBase
+ typedef SolverStorage StorageKind;
+ typedef int StorageIndex; // replaces the removed class-level MatrixType::StorageIndex typedef (presumably exposed via EIGEN_GENERIC_PUBLIC_INTERFACE)
enum { Flags = 0 };
};
@@ -44,19 +47,21 @@
*
* \sa MatrixBase::completeOrthogonalDecomposition()
*/
-template <typename _MatrixType>
-class CompleteOrthogonalDecomposition {
+template <typename _MatrixType> class CompleteOrthogonalDecomposition
+ : public SolverBase<CompleteOrthogonalDecomposition<_MatrixType> >
+{
public:
typedef _MatrixType MatrixType;
+ typedef SolverBase<CompleteOrthogonalDecomposition> Base;
+
+ template<typename Derived>
+ friend struct internal::solve_assertion;
+
+ EIGEN_GENERIC_PUBLIC_INTERFACE(CompleteOrthogonalDecomposition)
enum {
- RowsAtCompileTime = MatrixType::RowsAtCompileTime,
- ColsAtCompileTime = MatrixType::ColsAtCompileTime,
MaxRowsAtCompileTime = MatrixType::MaxRowsAtCompileTime,
MaxColsAtCompileTime = MatrixType::MaxColsAtCompileTime
};
- typedef typename MatrixType::Scalar Scalar;
- typedef typename MatrixType::RealScalar RealScalar;
- typedef typename MatrixType::StorageIndex StorageIndex;
typedef typename internal::plain_diag_type<MatrixType>::type HCoeffsType;
typedef PermutationMatrix<ColsAtCompileTime, MaxColsAtCompileTime>
PermutationType;
@@ -131,9 +136,9 @@
m_temp(matrix.cols())
{
computeInPlace();
- }
+ }
-
+ #ifdef EIGEN_PARSED_BY_DOXYGEN
/** This method computes the minimum-norm solution X to a least squares
* problem \f[\mathrm{minimize} \|A X - B\|, \f] where \b A is the matrix of
* which \c *this is the complete orthogonal decomposition.
@@ -145,11 +150,8 @@
*/
template <typename Rhs>
inline const Solve<CompleteOrthogonalDecomposition, Rhs> solve(
- const MatrixBase<Rhs>& b) const {
- eigen_assert(m_cpqr.m_isInitialized &&
- "CompleteOrthogonalDecomposition is not initialized.");
- return Solve<CompleteOrthogonalDecomposition, Rhs>(*this, b.derived());
- }
+ const MatrixBase<Rhs>& b) const;
+ #endif
HouseholderSequenceType householderQ(void) const;
HouseholderSequenceType matrixQ(void) const { return m_cpqr.householderQ(); }
@@ -158,8 +160,8 @@
*/
MatrixType matrixZ() const {
MatrixType Z = MatrixType::Identity(m_cpqr.cols(), m_cpqr.cols());
- applyZAdjointOnTheLeftInPlace(Z);
- return Z.adjoint();
+ applyZOnTheLeftInPlace<false>(Z);
+ return Z;
}
/** \returns a reference to the matrix where the complete orthogonal
@@ -275,6 +277,7 @@
*/
inline const Inverse<CompleteOrthogonalDecomposition> pseudoInverse() const
{
+ eigen_assert(m_cpqr.m_isInitialized && "CompleteOrthogonalDecomposition is not initialized.");
return Inverse<CompleteOrthogonalDecomposition>(*this);
}
@@ -353,7 +356,7 @@
inline RealScalar maxPivot() const { return m_cpqr.maxPivot(); }
/** \brief Reports whether the complete orthogonal decomposition was
- * succesful.
+ * successful.
*
* \note This function always returns \c Success. It is provided for
* compatibility
@@ -367,7 +370,10 @@
#ifndef EIGEN_PARSED_BY_DOXYGEN
template <typename RhsType, typename DstType>
- EIGEN_DEVICE_FUNC void _solve_impl(const RhsType& rhs, DstType& dst) const;
+ void _solve_impl(const RhsType& rhs, DstType& dst) const;
+
+ template<bool Conjugate, typename RhsType, typename DstType>
+ void _solve_impl_transposed(const RhsType &rhs, DstType &dst) const;
#endif
protected:
@@ -375,8 +381,22 @@
EIGEN_STATIC_ASSERT_NON_INTEGER(Scalar);
}
+ template<bool Transpose_, typename Rhs>
+ void _check_solve_assertion(const Rhs& b) const {
+ EIGEN_ONLY_USED_FOR_DEBUG(b);
+ eigen_assert(m_cpqr.m_isInitialized && "CompleteOrthogonalDecomposition is not initialized.");
+ eigen_assert((Transpose_?derived().cols():derived().rows())==b.rows() && "CompleteOrthogonalDecomposition::solve(): invalid number of rows of the right hand side matrix b");
+ }
+
void computeInPlace();
+ /** Overwrites \b rhs with \f$ \mathbf{Z} * \mathbf{rhs} \f$ or
+ * \f$ \mathbf{\overline Z} * \mathbf{rhs} \f$ if \c Conjugate
+ * is set to \c true.
+ */
+ template <bool Conjugate, typename Rhs>
+ void applyZOnTheLeftInPlace(Rhs& rhs) const;
+
/** Overwrites \b rhs with \f$ \mathbf{Z}^* * \mathbf{rhs} \f$.
*/
template <typename Rhs>
@@ -452,7 +472,7 @@
// Apply Z(k) to the first k rows of X_k
m_cpqr.m_qr.topRightCorner(k, cols - rank + 1)
.applyHouseholderOnTheRight(
- m_cpqr.m_qr.row(k).tail(cols - rank).transpose(), m_zCoeffs(k),
+ m_cpqr.m_qr.row(k).tail(cols - rank).adjoint(), m_zCoeffs(k),
&m_temp(0));
}
if (k != rank - 1) {
@@ -465,13 +485,35 @@
}
template <typename MatrixType>
+template <bool Conjugate, typename Rhs>
+void CompleteOrthogonalDecomposition<MatrixType>::applyZOnTheLeftInPlace(
+ Rhs& rhs) const {
+ const Index cols = this->cols();
+ const Index nrhs = rhs.cols();
+ const Index rank = this->rank();
+ Matrix<typename Rhs::Scalar, Dynamic, 1> temp((std::max)(cols, nrhs));
+ for (Index k = rank-1; k >= 0; --k) {
+ if (k != rank - 1) {
+ rhs.row(k).swap(rhs.row(rank - 1));
+ }
+ rhs.middleRows(rank - 1, cols - rank + 1)
+ .applyHouseholderOnTheLeft(
+ matrixQTZ().row(k).tail(cols - rank).transpose().template conjugateIf<!Conjugate>(), zCoeffs().template conjugateIf<Conjugate>()(k),
+ &temp(0));
+ if (k != rank - 1) {
+ rhs.row(k).swap(rhs.row(rank - 1));
+ }
+ }
+}
+
+template <typename MatrixType>
template <typename Rhs>
void CompleteOrthogonalDecomposition<MatrixType>::applyZAdjointOnTheLeftInPlace(
Rhs& rhs) const {
const Index cols = this->cols();
const Index nrhs = rhs.cols();
const Index rank = this->rank();
- Matrix<typename MatrixType::Scalar, Dynamic, 1> temp((std::max)(cols, nrhs));
+ Matrix<typename Rhs::Scalar, Dynamic, 1> temp((std::max)(cols, nrhs));
for (Index k = 0; k < rank; ++k) {
if (k != rank - 1) {
rhs.row(k).swap(rhs.row(rank - 1));
@@ -491,8 +533,6 @@
template <typename RhsType, typename DstType>
void CompleteOrthogonalDecomposition<_MatrixType>::_solve_impl(
const RhsType& rhs, DstType& dst) const {
- eigen_assert(rhs.rows() == this->rows());
-
const Index rank = this->rank();
if (rank == 0) {
dst.setZero();
@@ -500,11 +540,8 @@
}
// Compute c = Q^* * rhs
- // Note that the matrix Q = H_0^* H_1^*... so its inverse is
- // Q^* = (H_0 H_1 ...)^T
typename RhsType::PlainObject c(rhs);
- c.applyOnTheLeft(
- householderSequence(matrixQTZ(), hCoeffs()).setLength(rank).transpose());
+ c.applyOnTheLeft(matrixQ().setLength(rank).adjoint());
// Solve T z = c(1:rank, :)
dst.topRows(rank) = matrixT()
@@ -523,10 +560,45 @@
// Undo permutation to get x = P^{-1} * y.
dst = colsPermutation() * dst;
}
+
+template<typename _MatrixType>
+template<bool Conjugate, typename RhsType, typename DstType>
+void CompleteOrthogonalDecomposition<_MatrixType>::_solve_impl_transposed(const RhsType &rhs, DstType &dst) const
+{
+ const Index rank = this->rank();
+
+ if (rank == 0) {
+ dst.setZero();
+ return;
+ }
+
+ typename RhsType::PlainObject c(colsPermutation().transpose()*rhs);
+
+ if (rank < cols()) {
+ applyZOnTheLeftInPlace<!Conjugate>(c);
+ }
+
+ matrixT().topLeftCorner(rank, rank)
+ .template triangularView<Upper>()
+ .transpose().template conjugateIf<Conjugate>()
+ .solveInPlace(c.topRows(rank));
+
+ dst.topRows(rank) = c.topRows(rank);
+ dst.bottomRows(rows()-rank).setZero();
+
+ dst.applyOnTheLeft(householderQ().setLength(rank).template conjugateIf<!Conjugate>() );
+}
#endif
namespace internal {
+template<typename MatrixType>
+struct traits<Inverse<CompleteOrthogonalDecomposition<MatrixType> > >
+ : traits<typename Transpose<typename MatrixType::PlainObject>::PlainObject>
+{
+ enum { Flags = 0 };
+};
+
template<typename DstXprType, typename MatrixType>
struct Assignment<DstXprType, Inverse<CompleteOrthogonalDecomposition<MatrixType> >, internal::assign_op<typename DstXprType::Scalar,typename CompleteOrthogonalDecomposition<MatrixType>::Scalar>, Dense2Dense>
{
@@ -534,7 +606,8 @@
typedef Inverse<CodType> SrcXprType;
static void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op<typename DstXprType::Scalar,typename CodType::Scalar> &)
{
- dst = src.nestedExpression().solve(MatrixType::Identity(src.rows(), src.rows()));
+ typedef Matrix<typename CodType::Scalar, CodType::RowsAtCompileTime, CodType::RowsAtCompileTime, 0, CodType::MaxRowsAtCompileTime, CodType::MaxRowsAtCompileTime> IdentityMatrixType;
+ dst = src.nestedExpression().solve(IdentityMatrixType::Identity(src.cols(), src.cols()));
}
};
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/QR/FullPivHouseholderQR.h b/wpimath/src/main/native/eigeninclude/Eigen/src/QR/FullPivHouseholderQR.h
index e489bdd..d0664a1 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/QR/FullPivHouseholderQR.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/QR/FullPivHouseholderQR.h
@@ -18,6 +18,9 @@
template<typename _MatrixType> struct traits<FullPivHouseholderQR<_MatrixType> >
: traits<_MatrixType>
{
+ typedef MatrixXpr XprKind;
+ typedef SolverStorage StorageKind;
+ typedef int StorageIndex;
enum { Flags = 0 };
};
@@ -55,20 +58,19 @@
* \sa MatrixBase::fullPivHouseholderQr()
*/
template<typename _MatrixType> class FullPivHouseholderQR
+ : public SolverBase<FullPivHouseholderQR<_MatrixType> >
{
public:
typedef _MatrixType MatrixType;
+ typedef SolverBase<FullPivHouseholderQR> Base;
+ friend class SolverBase<FullPivHouseholderQR>;
+
+ EIGEN_GENERIC_PUBLIC_INTERFACE(FullPivHouseholderQR)
enum {
- RowsAtCompileTime = MatrixType::RowsAtCompileTime,
- ColsAtCompileTime = MatrixType::ColsAtCompileTime,
MaxRowsAtCompileTime = MatrixType::MaxRowsAtCompileTime,
MaxColsAtCompileTime = MatrixType::MaxColsAtCompileTime
};
- typedef typename MatrixType::Scalar Scalar;
- typedef typename MatrixType::RealScalar RealScalar;
- // FIXME should be int
- typedef typename MatrixType::StorageIndex StorageIndex;
typedef internal::FullPivHouseholderQRMatrixQReturnType<MatrixType> MatrixQReturnType;
typedef typename internal::plain_diag_type<MatrixType>::type HCoeffsType;
typedef Matrix<StorageIndex, 1,
@@ -156,6 +158,7 @@
computeInPlace();
}
+ #ifdef EIGEN_PARSED_BY_DOXYGEN
/** This method finds a solution x to the equation Ax=b, where A is the matrix of which
* \c *this is the QR decomposition.
*
@@ -173,11 +176,8 @@
*/
template<typename Rhs>
inline const Solve<FullPivHouseholderQR, Rhs>
- solve(const MatrixBase<Rhs>& b) const
- {
- eigen_assert(m_isInitialized && "FullPivHouseholderQR is not initialized.");
- return Solve<FullPivHouseholderQR, Rhs>(*this, b.derived());
- }
+ solve(const MatrixBase<Rhs>& b) const;
+ #endif
/** \returns Expression object representing the matrix Q
*/
@@ -392,22 +392,24 @@
* diagonal coefficient of U.
*/
RealScalar maxPivot() const { return m_maxpivot; }
-
+
#ifndef EIGEN_PARSED_BY_DOXYGEN
template<typename RhsType, typename DstType>
- EIGEN_DEVICE_FUNC
void _solve_impl(const RhsType &rhs, DstType &dst) const;
+
+ template<bool Conjugate, typename RhsType, typename DstType>
+ void _solve_impl_transposed(const RhsType &rhs, DstType &dst) const;
#endif
protected:
-
+
static void check_template_parameters()
{
EIGEN_STATIC_ASSERT_NON_INTEGER(Scalar);
}
-
+
void computeInPlace();
-
+
MatrixType m_qr;
HCoeffsType m_hCoeffs;
IntDiagSizeVectorType m_rows_transpositions;
@@ -499,15 +501,15 @@
m_nonzero_pivots = k;
for(Index i = k; i < size; i++)
{
- m_rows_transpositions.coeffRef(i) = i;
- m_cols_transpositions.coeffRef(i) = i;
+ m_rows_transpositions.coeffRef(i) = internal::convert_index<StorageIndex>(i);
+ m_cols_transpositions.coeffRef(i) = internal::convert_index<StorageIndex>(i);
m_hCoeffs.coeffRef(i) = Scalar(0);
}
break;
}
- m_rows_transpositions.coeffRef(k) = row_of_biggest_in_corner;
- m_cols_transpositions.coeffRef(k) = col_of_biggest_in_corner;
+ m_rows_transpositions.coeffRef(k) = internal::convert_index<StorageIndex>(row_of_biggest_in_corner);
+ m_cols_transpositions.coeffRef(k) = internal::convert_index<StorageIndex>(col_of_biggest_in_corner);
if(k != row_of_biggest_in_corner) {
m_qr.row(k).tail(cols-k).swap(m_qr.row(row_of_biggest_in_corner).tail(cols-k));
++number_of_transpositions;
@@ -541,7 +543,6 @@
template<typename RhsType, typename DstType>
void FullPivHouseholderQR<_MatrixType>::_solve_impl(const RhsType &rhs, DstType &dst) const
{
- eigen_assert(rhs.rows() == rows());
const Index l_rank = rank();
// FIXME introduce nonzeroPivots() and use it here. and more generally,
@@ -554,7 +555,7 @@
typename RhsType::PlainObject c(rhs);
- Matrix<Scalar,1,RhsType::ColsAtCompileTime> temp(rhs.cols());
+ Matrix<typename RhsType::Scalar,1,RhsType::ColsAtCompileTime> temp(rhs.cols());
for (Index k = 0; k < l_rank; ++k)
{
Index remainingSize = rows()-k;
@@ -571,6 +572,42 @@
for(Index i = 0; i < l_rank; ++i) dst.row(m_cols_permutation.indices().coeff(i)) = c.row(i);
for(Index i = l_rank; i < cols(); ++i) dst.row(m_cols_permutation.indices().coeff(i)).setZero();
}
+
+template<typename _MatrixType>
+template<bool Conjugate, typename RhsType, typename DstType>
+void FullPivHouseholderQR<_MatrixType>::_solve_impl_transposed(const RhsType &rhs, DstType &dst) const
+{
+ const Index l_rank = rank();
+
+ if(l_rank == 0)
+ {
+ dst.setZero();
+ return;
+ }
+
+ typename RhsType::PlainObject c(m_cols_permutation.transpose()*rhs);
+
+ m_qr.topLeftCorner(l_rank, l_rank)
+ .template triangularView<Upper>()
+ .transpose().template conjugateIf<Conjugate>()
+ .solveInPlace(c.topRows(l_rank));
+
+ dst.topRows(l_rank) = c.topRows(l_rank);
+ dst.bottomRows(rows()-l_rank).setZero();
+
+ Matrix<Scalar, 1, DstType::ColsAtCompileTime> temp(dst.cols());
+ const Index size = (std::min)(rows(), cols());
+ for (Index k = size-1; k >= 0; --k)
+ {
+ Index remainingSize = rows()-k;
+
+ dst.bottomRightCorner(remainingSize, dst.cols())
+ .applyHouseholderOnTheLeft(m_qr.col(k).tail(remainingSize-1).template conjugateIf<!Conjugate>(),
+ m_hCoeffs.template conjugateIf<Conjugate>().coeff(k), &temp.coeffRef(0));
+
+ dst.row(k).swap(dst.row(m_rows_transpositions.coeff(k)));
+ }
+}
#endif
namespace internal {
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/QR/HouseholderQR.h b/wpimath/src/main/native/eigeninclude/Eigen/src/QR/HouseholderQR.h
index 3513d99..801739f 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/QR/HouseholderQR.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/QR/HouseholderQR.h
@@ -14,6 +14,18 @@
namespace Eigen {
+namespace internal {
+template<typename _MatrixType> struct traits<HouseholderQR<_MatrixType> >
+ : traits<_MatrixType>
+{
+ typedef MatrixXpr XprKind;
+ typedef SolverStorage StorageKind;
+ typedef int StorageIndex;
+ enum { Flags = 0 };
+};
+
+} // end namespace internal
+
/** \ingroup QR_Module
*
*
@@ -42,20 +54,19 @@
* \sa MatrixBase::householderQr()
*/
template<typename _MatrixType> class HouseholderQR
+ : public SolverBase<HouseholderQR<_MatrixType> >
{
public:
typedef _MatrixType MatrixType;
+ typedef SolverBase<HouseholderQR> Base;
+ friend class SolverBase<HouseholderQR>;
+
+ EIGEN_GENERIC_PUBLIC_INTERFACE(HouseholderQR)
enum {
- RowsAtCompileTime = MatrixType::RowsAtCompileTime,
- ColsAtCompileTime = MatrixType::ColsAtCompileTime,
MaxRowsAtCompileTime = MatrixType::MaxRowsAtCompileTime,
MaxColsAtCompileTime = MatrixType::MaxColsAtCompileTime
};
- typedef typename MatrixType::Scalar Scalar;
- typedef typename MatrixType::RealScalar RealScalar;
- // FIXME should be int
- typedef typename MatrixType::StorageIndex StorageIndex;
typedef Matrix<Scalar, RowsAtCompileTime, RowsAtCompileTime, (MatrixType::Flags&RowMajorBit) ? RowMajor : ColMajor, MaxRowsAtCompileTime, MaxRowsAtCompileTime> MatrixQType;
typedef typename internal::plain_diag_type<MatrixType>::type HCoeffsType;
typedef typename internal::plain_row_type<MatrixType>::type RowVectorType;
@@ -121,6 +132,7 @@
computeInPlace();
}
+ #ifdef EIGEN_PARSED_BY_DOXYGEN
/** This method finds a solution x to the equation Ax=b, where A is the matrix of which
* *this is the QR decomposition, if any exists.
*
@@ -137,11 +149,8 @@
*/
template<typename Rhs>
inline const Solve<HouseholderQR, Rhs>
- solve(const MatrixBase<Rhs>& b) const
- {
- eigen_assert(m_isInitialized && "HouseholderQR is not initialized.");
- return Solve<HouseholderQR, Rhs>(*this, b.derived());
- }
+ solve(const MatrixBase<Rhs>& b) const;
+ #endif
/** This method returns an expression of the unitary matrix Q as a sequence of Householder transformations.
*
@@ -204,28 +213,30 @@
inline Index rows() const { return m_qr.rows(); }
inline Index cols() const { return m_qr.cols(); }
-
+
/** \returns a const reference to the vector of Householder coefficients used to represent the factor \c Q.
*
* For advanced uses only.
*/
const HCoeffsType& hCoeffs() const { return m_hCoeffs; }
-
+
#ifndef EIGEN_PARSED_BY_DOXYGEN
template<typename RhsType, typename DstType>
- EIGEN_DEVICE_FUNC
void _solve_impl(const RhsType &rhs, DstType &dst) const;
+
+ template<bool Conjugate, typename RhsType, typename DstType>
+ void _solve_impl_transposed(const RhsType &rhs, DstType &dst) const;
#endif
protected:
-
+
static void check_template_parameters()
{
EIGEN_STATIC_ASSERT_NON_INTEGER(Scalar);
}
void computeInPlace();
-
+
MatrixType m_qr;
HCoeffsType m_hCoeffs;
RowVectorType m_temp;
@@ -292,7 +303,7 @@
bool InnerStrideIsOne = (MatrixQR::InnerStrideAtCompileTime == 1 && HCoeffs::InnerStrideAtCompileTime == 1)>
struct householder_qr_inplace_blocked
{
- // This is specialized for MKL-supported Scalar types in HouseholderQR_MKL.h
+ // This is specialized for LAPACK-supported Scalar types in HouseholderQR_LAPACKE.h
static void run(MatrixQR& mat, HCoeffs& hCoeffs, Index maxBlockSize=32,
typename MatrixQR::Scalar* tempData = 0)
{
@@ -350,15 +361,10 @@
void HouseholderQR<_MatrixType>::_solve_impl(const RhsType &rhs, DstType &dst) const
{
const Index rank = (std::min)(rows(), cols());
- eigen_assert(rhs.rows() == rows());
typename RhsType::PlainObject c(rhs);
- // Note that the matrix Q = H_0^* H_1^*... so its inverse is Q^* = (H_0 H_1 ...)^T
- c.applyOnTheLeft(householderSequence(
- m_qr.leftCols(rank),
- m_hCoeffs.head(rank)).transpose()
- );
+ c.applyOnTheLeft(householderQ().setLength(rank).adjoint() );
m_qr.topLeftCorner(rank, rank)
.template triangularView<Upper>()
@@ -367,6 +373,25 @@
dst.topRows(rank) = c.topRows(rank);
dst.bottomRows(cols()-rank).setZero();
}
+
+template<typename _MatrixType>
+template<bool Conjugate, typename RhsType, typename DstType>
+void HouseholderQR<_MatrixType>::_solve_impl_transposed(const RhsType &rhs, DstType &dst) const
+{
+ const Index rank = (std::min)(rows(), cols());
+
+ typename RhsType::PlainObject c(rhs);
+
+ m_qr.topLeftCorner(rank, rank)
+ .template triangularView<Upper>()
+ .transpose().template conjugateIf<Conjugate>()
+ .solveInPlace(c.topRows(rank));
+
+ dst.topRows(rank) = c.topRows(rank);
+ dst.bottomRows(rows()-rank).setZero();
+
+ dst.applyOnTheLeft(householderQ().setLength(rank).template conjugateIf<!Conjugate>() );
+}
#endif
/** Performs the QR factorization of the given matrix \a matrix. The result of
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/SVD/BDCSVD.h b/wpimath/src/main/native/eigeninclude/Eigen/src/SVD/BDCSVD.h
index 1134d66..17f8e44 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/SVD/BDCSVD.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/SVD/BDCSVD.h
@@ -22,6 +22,11 @@
// #define EIGEN_BDCSVD_DEBUG_VERBOSE
// #define EIGEN_BDCSVD_SANITY_CHECKS
+#ifdef EIGEN_BDCSVD_SANITY_CHECKS
+#undef eigen_internal_assert
+#define eigen_internal_assert(X) assert(X);
+#endif
+
namespace Eigen {
#ifdef EIGEN_BDCSVD_DEBUG_VERBOSE
@@ -34,6 +39,7 @@
template<typename _MatrixType>
struct traits<BDCSVD<_MatrixType> >
+ : traits<_MatrixType>
{
typedef _MatrixType MatrixType;
};
@@ -57,7 +63,7 @@
* recommended and can several order of magnitude faster.
*
* \warning this algorithm is unlikely to provide accurate result when compiled with unsafe math optimizations.
- * For instance, this concerns Intel's compiler (ICC), which perfroms such optimization by default unless
+ * For instance, this concerns Intel's compiler (ICC), which performs such optimization by default unless
* you compile with the \c -fp-model \c precise option. Likewise, the \c -ffast-math option of GCC or clang will
* significantly degrade the accuracy.
*
@@ -105,7 +111,7 @@
* The default constructor is useful in cases in which the user intends to
* perform decompositions via BDCSVD::compute(const MatrixType&).
*/
- BDCSVD() : m_algoswap(16), m_numIters(0)
+ BDCSVD() : m_algoswap(16), m_isTranspose(false), m_compU(false), m_compV(false), m_numIters(0)
{}
@@ -202,6 +208,7 @@
using Base::m_computeThinV;
using Base::m_matrixU;
using Base::m_matrixV;
+ using Base::m_info;
using Base::m_isInitialized;
using Base::m_nonzeroSingularValues;
@@ -212,7 +219,7 @@
// Method to allocate and initialize matrix and attributes
template<typename MatrixType>
-void BDCSVD<MatrixType>::allocate(Index rows, Index cols, unsigned int computationOptions)
+void BDCSVD<MatrixType>::allocate(Eigen::Index rows, Eigen::Index cols, unsigned int computationOptions)
{
m_isTranspose = (cols > rows);
@@ -250,16 +257,25 @@
{
// FIXME this line involves temporaries
JacobiSVD<MatrixType> jsvd(matrix,computationOptions);
- if(computeU()) m_matrixU = jsvd.matrixU();
- if(computeV()) m_matrixV = jsvd.matrixV();
- m_singularValues = jsvd.singularValues();
- m_nonzeroSingularValues = jsvd.nonzeroSingularValues();
m_isInitialized = true;
+ m_info = jsvd.info();
+ if (m_info == Success || m_info == NoConvergence) {
+ if(computeU()) m_matrixU = jsvd.matrixU();
+ if(computeV()) m_matrixV = jsvd.matrixV();
+ m_singularValues = jsvd.singularValues();
+ m_nonzeroSingularValues = jsvd.nonzeroSingularValues();
+ }
return *this;
}
//**** step 0 - Copy the input matrix and apply scaling to reduce over/under-flows
- RealScalar scale = matrix.cwiseAbs().maxCoeff();
+ RealScalar scale = matrix.cwiseAbs().template maxCoeff<PropagateNaN>();
+ if (!(numext::isfinite)(scale)) {
+ m_isInitialized = true;
+ m_info = InvalidInput;
+ return *this;
+ }
+
if(scale==Literal(0)) scale = Literal(1);
MatrixX copy;
if (m_isTranspose) copy = matrix.adjoint()/scale;
@@ -276,7 +292,11 @@
m_computed.topRows(m_diagSize) = bid.bidiagonal().toDenseMatrix().transpose();
m_computed.template bottomRows<1>().setZero();
divide(0, m_diagSize - 1, 0, 0, 0);
-
+ if (m_info != Success && m_info != NoConvergence) {
+ m_isInitialized = true;
+ return *this;
+ }
+
//**** step 3 - Copy singular values and vectors
for (int i=0; i<m_diagSize; i++)
{
@@ -388,7 +408,7 @@
//@param shift : Each time one takes the left submatrix, one must add 1 to the shift. Why? Because! We actually want the last column of the U submatrix
// to become the first column (*coeff) and to shift all the other columns to the right. There are more details on the reference paper.
template<typename MatrixType>
-void BDCSVD<MatrixType>::divide (Index firstCol, Index lastCol, Index firstRowW, Index firstColW, Index shift)
+void BDCSVD<MatrixType>::divide(Eigen::Index firstCol, Eigen::Index lastCol, Eigen::Index firstRowW, Eigen::Index firstColW, Eigen::Index shift)
{
// requires rows = cols + 1;
using std::pow;
@@ -408,6 +428,8 @@
{
// FIXME this line involves temporaries
JacobiSVD<MatrixXr> b(m_computed.block(firstCol, firstCol, n + 1, n), ComputeFullU | (m_compV ? ComputeFullV : 0));
+ m_info = b.info();
+ if (m_info != Success && m_info != NoConvergence) return;
if (m_compU)
m_naiveU.block(firstCol, firstCol, n + 1, n + 1).real() = b.matrixU();
else
@@ -427,7 +449,9 @@
// and the divide of the right submatrice reads one column of the left submatrice. That's why we need to treat the
// right submatrix before the left one.
divide(k + 1 + firstCol, lastCol, k + 1 + firstRowW, k + 1 + firstColW, shift);
+ if (m_info != Success && m_info != NoConvergence) return;
divide(firstCol, k - 1 + firstCol, firstRowW, firstColW + 1, shift + 1);
+ if (m_info != Success && m_info != NoConvergence) return;
if (m_compU)
{
@@ -568,7 +592,7 @@
// handling of round-off errors, be consistent in ordering
// For instance, to solve the secular equation using FMM, see http://www.stat.uchicago.edu/~lekheng/courses/302/classics/greengard-rokhlin.pdf
template <typename MatrixType>
-void BDCSVD<MatrixType>::computeSVDofM(Index firstCol, Index n, MatrixXr& U, VectorType& singVals, MatrixXr& V)
+void BDCSVD<MatrixType>::computeSVDofM(Eigen::Index firstCol, Eigen::Index n, MatrixXr& U, VectorType& singVals, MatrixXr& V)
{
const RealScalar considerZero = (std::numeric_limits<RealScalar>::min)();
using std::abs;
@@ -591,7 +615,7 @@
// but others are interleaved and we must ignore them at this stage.
// To this end, let's compute a permutation skipping them:
Index actual_n = n;
- while(actual_n>1 && diag(actual_n-1)==Literal(0)) --actual_n;
+ while(actual_n>1 && diag(actual_n-1)==Literal(0)) {--actual_n; eigen_internal_assert(col0(actual_n)==Literal(0)); }
Index m = 0; // size of the deflated problem
for(Index k=0;k<actual_n;++k)
if(abs(col0(k))>considerZero)
@@ -618,13 +642,11 @@
std::cout << " shift: " << shifts.transpose() << "\n";
{
- Index actual_n = n;
- while(actual_n>1 && abs(col0(actual_n-1))<considerZero) --actual_n;
std::cout << "\n\n mus: " << mus.head(actual_n).transpose() << "\n\n";
std::cout << " check1 (expect0) : " << ((singVals.array()-(shifts+mus)) / singVals.array()).head(actual_n).transpose() << "\n\n";
+ assert((((singVals.array()-(shifts+mus)) / singVals.array()).head(actual_n) >= 0).all());
std::cout << " check2 (>0) : " << ((singVals.array()-diag) / singVals.array()).head(actual_n).transpose() << "\n\n";
- std::cout << " check3 (>0) : " << ((diag.segment(1,actual_n-1)-singVals.head(actual_n-1).array()) / singVals.head(actual_n-1).array()).transpose() << "\n\n\n";
- std::cout << " check4 (>0) : " << ((singVals.segment(1,actual_n-1)-singVals.head(actual_n-1))).transpose() << "\n\n\n";
+ assert((((singVals.array()-diag) / singVals.array()).head(actual_n) >= 0).all());
}
#endif
@@ -652,13 +674,13 @@
#endif
#ifdef EIGEN_BDCSVD_SANITY_CHECKS
- assert(U.allFinite());
- assert(V.allFinite());
- assert((U.transpose() * U - MatrixXr(MatrixXr::Identity(U.cols(),U.cols()))).norm() < 1e-14 * n);
- assert((V.transpose() * V - MatrixXr(MatrixXr::Identity(V.cols(),V.cols()))).norm() < 1e-14 * n);
assert(m_naiveU.allFinite());
assert(m_naiveV.allFinite());
assert(m_computed.allFinite());
+ assert(U.allFinite());
+ assert(V.allFinite());
+// assert((U.transpose() * U - MatrixXr(MatrixXr::Identity(U.cols(),U.cols()))).norm() < 100*NumTraits<RealScalar>::epsilon() * n);
+// assert((V.transpose() * V - MatrixXr(MatrixXr::Identity(V.cols(),V.cols()))).norm() < 100*NumTraits<RealScalar>::epsilon() * n);
#endif
// Because of deflation, the singular values might not be completely sorted.
@@ -673,6 +695,15 @@
if(m_compV) V.col(i).swap(V.col(i+1));
}
}
+
+#ifdef EIGEN_BDCSVD_SANITY_CHECKS
+ {
+ bool singular_values_sorted = (((singVals.segment(1,actual_n-1)-singVals.head(actual_n-1))).array() >= 0).all();
+ if(!singular_values_sorted)
+ std::cout << "Singular values are not sorted: " << singVals.segment(1,actual_n).transpose() << "\n";
+ assert(singular_values_sorted);
+ }
+#endif
// Reverse order so that singular values in increased order
// Because of deflation, the zeros singular-values are already at the end
@@ -749,25 +780,43 @@
RealScalar mid = left + (right-left) / Literal(2);
RealScalar fMid = secularEq(mid, col0, diag, perm, diag, Literal(0));
#ifdef EIGEN_BDCSVD_DEBUG_VERBOSE
- std::cout << right-left << "\n";
- std::cout << "fMid = " << fMid << " " << secularEq(mid-left, col0, diag, perm, diag-left, left) << " " << secularEq(mid-right, col0, diag, perm, diag-right, right) << "\n";
- std::cout << " = " << secularEq(0.1*(left+right), col0, diag, perm, diag, 0)
- << " " << secularEq(0.2*(left+right), col0, diag, perm, diag, 0)
- << " " << secularEq(0.3*(left+right), col0, diag, perm, diag, 0)
- << " " << secularEq(0.4*(left+right), col0, diag, perm, diag, 0)
- << " " << secularEq(0.49*(left+right), col0, diag, perm, diag, 0)
- << " " << secularEq(0.5*(left+right), col0, diag, perm, diag, 0)
- << " " << secularEq(0.51*(left+right), col0, diag, perm, diag, 0)
- << " " << secularEq(0.6*(left+right), col0, diag, perm, diag, 0)
- << " " << secularEq(0.7*(left+right), col0, diag, perm, diag, 0)
- << " " << secularEq(0.8*(left+right), col0, diag, perm, diag, 0)
- << " " << secularEq(0.9*(left+right), col0, diag, perm, diag, 0) << "\n";
+ std::cout << "right-left = " << right-left << "\n";
+// std::cout << "fMid = " << fMid << " " << secularEq(mid-left, col0, diag, perm, ArrayXr(diag-left), left)
+// << " " << secularEq(mid-right, col0, diag, perm, ArrayXr(diag-right), right) << "\n";
+ std::cout << " = " << secularEq(left+RealScalar(0.000001)*(right-left), col0, diag, perm, diag, 0)
+ << " " << secularEq(left+RealScalar(0.1) *(right-left), col0, diag, perm, diag, 0)
+ << " " << secularEq(left+RealScalar(0.2) *(right-left), col0, diag, perm, diag, 0)
+ << " " << secularEq(left+RealScalar(0.3) *(right-left), col0, diag, perm, diag, 0)
+ << " " << secularEq(left+RealScalar(0.4) *(right-left), col0, diag, perm, diag, 0)
+ << " " << secularEq(left+RealScalar(0.49) *(right-left), col0, diag, perm, diag, 0)
+ << " " << secularEq(left+RealScalar(0.5) *(right-left), col0, diag, perm, diag, 0)
+ << " " << secularEq(left+RealScalar(0.51) *(right-left), col0, diag, perm, diag, 0)
+ << " " << secularEq(left+RealScalar(0.6) *(right-left), col0, diag, perm, diag, 0)
+ << " " << secularEq(left+RealScalar(0.7) *(right-left), col0, diag, perm, diag, 0)
+ << " " << secularEq(left+RealScalar(0.8) *(right-left), col0, diag, perm, diag, 0)
+ << " " << secularEq(left+RealScalar(0.9) *(right-left), col0, diag, perm, diag, 0)
+ << " " << secularEq(left+RealScalar(0.999999)*(right-left), col0, diag, perm, diag, 0) << "\n";
#endif
RealScalar shift = (k == actual_n-1 || fMid > Literal(0)) ? left : right;
// measure everything relative to shift
Map<ArrayXr> diagShifted(m_workspace.data()+4*n, n);
diagShifted = diag - shift;
+
+ if(k!=actual_n-1)
+ {
+ // check that after the shift, f(mid) is still negative:
+ RealScalar midShifted = (right - left) / RealScalar(2);
+ if(shift==right)
+ midShifted = -midShifted;
+ RealScalar fMidShifted = secularEq(midShifted, col0, diag, perm, diagShifted, shift);
+ if(fMidShifted>0)
+ {
+ // fMid was erroneous, fix it:
+ shift = fMidShifted > Literal(0) ? left : right;
+ diagShifted = diag - shift;
+ }
+ }
// initial guess
RealScalar muPrev, muCur;
@@ -804,13 +853,16 @@
// And find mu such that f(mu)==0:
RealScalar muZero = -a/b;
RealScalar fZero = secularEq(muZero, col0, diag, perm, diagShifted, shift);
+
+#ifdef EIGEN_BDCSVD_SANITY_CHECKS
+ assert((numext::isfinite)(fZero));
+#endif
muPrev = muCur;
fPrev = fCur;
muCur = muZero;
fCur = fZero;
-
if (shift == left && (muCur < Literal(0) || muCur > right - left)) useBisection = true;
if (shift == right && (muCur < -(right - left) || muCur > Literal(0))) useBisection = true;
if (abs(fCur)>abs(fPrev)) useBisection = true;
@@ -843,44 +895,82 @@
else
rightShifted = -(std::numeric_limits<RealScalar>::min)();
}
-
- RealScalar fLeft = secularEq(leftShifted, col0, diag, perm, diagShifted, shift);
-#if defined EIGEN_INTERNAL_DEBUGGING || defined EIGEN_BDCSVD_DEBUG_VERBOSE
+ RealScalar fLeft = secularEq(leftShifted, col0, diag, perm, diagShifted, shift);
+ eigen_internal_assert(fLeft<Literal(0));
+
+#if defined EIGEN_INTERNAL_DEBUGGING || defined EIGEN_BDCSVD_SANITY_CHECKS
RealScalar fRight = secularEq(rightShifted, col0, diag, perm, diagShifted, shift);
#endif
+#ifdef EIGEN_BDCSVD_SANITY_CHECKS
+ if(!(numext::isfinite)(fLeft))
+ std::cout << "f(" << leftShifted << ") =" << fLeft << " ; " << left << " " << shift << " " << right << "\n";
+ assert((numext::isfinite)(fLeft));
+
+ if(!(numext::isfinite)(fRight))
+ std::cout << "f(" << rightShifted << ") =" << fRight << " ; " << left << " " << shift << " " << right << "\n";
+ // assert((numext::isfinite)(fRight));
+#endif
+
#ifdef EIGEN_BDCSVD_DEBUG_VERBOSE
if(!(fLeft * fRight<0))
{
- std::cout << "fLeft: " << leftShifted << " - " << diagShifted.head(10).transpose() << "\n ; " << bool(left==shift) << " " << (left-shift) << "\n";
- std::cout << k << " : " << fLeft << " * " << fRight << " == " << fLeft * fRight << " ; " << left << " - " << right << " -> " << leftShifted << " " << rightShifted << " shift=" << shift << "\n";
+ std::cout << "f(leftShifted) using leftShifted=" << leftShifted << " ; diagShifted(1:10):" << diagShifted.head(10).transpose() << "\n ; "
+ << "left==shift=" << bool(left==shift) << " ; left-shift = " << (left-shift) << "\n";
+ std::cout << "k=" << k << ", " << fLeft << " * " << fRight << " == " << fLeft * fRight << " ; "
+ << "[" << left << " .. " << right << "] -> [" << leftShifted << " " << rightShifted << "], shift=" << shift
+ << " , f(right)=" << secularEq(0, col0, diag, perm, diagShifted, shift)
+ << " == " << secularEq(right, col0, diag, perm, diag, 0) << " == " << fRight << "\n";
}
#endif
eigen_internal_assert(fLeft * fRight < Literal(0));
-
- while (rightShifted - leftShifted > Literal(2) * NumTraits<RealScalar>::epsilon() * numext::maxi<RealScalar>(abs(leftShifted), abs(rightShifted)))
- {
- RealScalar midShifted = (leftShifted + rightShifted) / Literal(2);
- fMid = secularEq(midShifted, col0, diag, perm, diagShifted, shift);
- if (fLeft * fMid < Literal(0))
- {
- rightShifted = midShifted;
- }
- else
- {
- leftShifted = midShifted;
- fLeft = fMid;
- }
- }
- muCur = (leftShifted + rightShifted) / Literal(2);
+ if(fLeft<Literal(0))
+ {
+ while (rightShifted - leftShifted > Literal(2) * NumTraits<RealScalar>::epsilon() * numext::maxi<RealScalar>(abs(leftShifted), abs(rightShifted)))
+ {
+ RealScalar midShifted = (leftShifted + rightShifted) / Literal(2);
+ fMid = secularEq(midShifted, col0, diag, perm, diagShifted, shift);
+ eigen_internal_assert((numext::isfinite)(fMid));
+
+ if (fLeft * fMid < Literal(0))
+ {
+ rightShifted = midShifted;
+ }
+ else
+ {
+ leftShifted = midShifted;
+ fLeft = fMid;
+ }
+ }
+ muCur = (leftShifted + rightShifted) / Literal(2);
+ }
+ else
+ {
+ // We have a problem as shifting on the left or right gives either a positive or negative value
+ // at the middle of [left,right]...
+ // Instead of aborting or entering an infinite loop,
+ // let's just use the middle as the estimated zero-crossing:
+ muCur = (right - left) * RealScalar(0.5);
+ if(shift == right)
+ muCur = -muCur;
+ }
}
singVals[k] = shift + muCur;
shifts[k] = shift;
mus[k] = muCur;
+#ifdef EIGEN_BDCSVD_DEBUG_VERBOSE
+ if(k+1<n)
+ std::cout << "found " << singVals[k] << " == " << shift << " + " << muCur << " from " << diag(k) << " .. " << diag(k+1) << "\n";
+#endif
+#ifdef EIGEN_BDCSVD_SANITY_CHECKS
+ assert(k==0 || singVals[k]>=singVals[k-1]);
+ assert(singVals[k]>=diag(k));
+#endif
+
// perturb singular value slightly if it equals diagonal entry to avoid division by zero later
// (deflation is supposed to avoid this from happening)
// - this does no seem to be necessary anymore -
@@ -904,7 +994,7 @@
zhat.setZero();
return;
}
- Index last = perm(m-1);
+ Index lastIdx = perm(m-1);
// The offset permits to skip deflated entries while computing zhat
for (Index k = 0; k < n; ++k)
{
@@ -914,27 +1004,58 @@
{
// see equation (3.6)
RealScalar dk = diag(k);
- RealScalar prod = (singVals(last) + dk) * (mus(last) + (shifts(last) - dk));
+ RealScalar prod = (singVals(lastIdx) + dk) * (mus(lastIdx) + (shifts(lastIdx) - dk));
+#ifdef EIGEN_BDCSVD_SANITY_CHECKS
+ if(prod<0) {
+ std::cout << "k = " << k << " ; z(k)=" << col0(k) << ", diag(k)=" << dk << "\n";
+ std::cout << "prod = " << "(" << singVals(lastIdx) << " + " << dk << ") * (" << mus(lastIdx) << " + (" << shifts(lastIdx) << " - " << dk << "))" << "\n";
+ std::cout << " = " << singVals(lastIdx) + dk << " * " << mus(lastIdx) + (shifts(lastIdx) - dk) << "\n";
+ }
+ assert(prod>=0);
+#endif
for(Index l = 0; l<m; ++l)
{
Index i = perm(l);
if(i!=k)
{
+#ifdef EIGEN_BDCSVD_SANITY_CHECKS
+ if(i>=k && (l==0 || l-1>=m))
+ {
+ std::cout << "Error in perturbCol0\n";
+ std::cout << " " << k << "/" << n << " " << l << "/" << m << " " << i << "/" << n << " ; " << col0(k) << " " << diag(k) << " " << "\n";
+ std::cout << " " <<diag(i) << "\n";
+ Index j = (i<k /*|| l==0*/) ? i : perm(l-1);
+ std::cout << " " << "j=" << j << "\n";
+ }
+#endif
Index j = i<k ? i : perm(l-1);
+#ifdef EIGEN_BDCSVD_SANITY_CHECKS
+ if(!(dk!=Literal(0) || diag(i)!=Literal(0)))
+ {
+ std::cout << "k=" << k << ", i=" << i << ", l=" << l << ", perm.size()=" << perm.size() << "\n";
+ }
+ assert(dk!=Literal(0) || diag(i)!=Literal(0));
+#endif
prod *= ((singVals(j)+dk) / ((diag(i)+dk))) * ((mus(j)+(shifts(j)-dk)) / ((diag(i)-dk)));
+#ifdef EIGEN_BDCSVD_SANITY_CHECKS
+ assert(prod>=0);
+#endif
#ifdef EIGEN_BDCSVD_DEBUG_VERBOSE
- if(i!=k && std::abs(((singVals(j)+dk)*(mus(j)+(shifts(j)-dk)))/((diag(i)+dk)*(diag(i)-dk)) - 1) > 0.9 )
+ if(i!=k && numext::abs(((singVals(j)+dk)*(mus(j)+(shifts(j)-dk)))/((diag(i)+dk)*(diag(i)-dk)) - 1) > 0.9 )
std::cout << " " << ((singVals(j)+dk)*(mus(j)+(shifts(j)-dk)))/((diag(i)+dk)*(diag(i)-dk)) << " == (" << (singVals(j)+dk) << " * " << (mus(j)+(shifts(j)-dk))
<< ") / (" << (diag(i)+dk) << " * " << (diag(i)-dk) << ")\n";
#endif
}
}
#ifdef EIGEN_BDCSVD_DEBUG_VERBOSE
- std::cout << "zhat(" << k << ") = sqrt( " << prod << ") ; " << (singVals(last) + dk) << " * " << mus(last) + shifts(last) << " - " << dk << "\n";
+ std::cout << "zhat(" << k << ") = sqrt( " << prod << ") ; " << (singVals(lastIdx) + dk) << " * " << mus(lastIdx) + shifts(lastIdx) << " - " << dk << "\n";
#endif
RealScalar tmp = sqrt(prod);
- zhat(k) = col0(k) > Literal(0) ? tmp : -tmp;
+#ifdef EIGEN_BDCSVD_SANITY_CHECKS
+ assert((numext::isfinite)(tmp));
+#endif
+ zhat(k) = col0(k) > Literal(0) ? RealScalar(tmp) : RealScalar(-tmp);
}
}
}
@@ -987,7 +1108,7 @@
// i >= 1, di almost null and zi non null.
// We use a rotation to zero out zi applied to the left of M
template <typename MatrixType>
-void BDCSVD<MatrixType>::deflation43(Index firstCol, Index shift, Index i, Index size)
+void BDCSVD<MatrixType>::deflation43(Eigen::Index firstCol, Eigen::Index shift, Eigen::Index i, Eigen::Index size)
{
using std::abs;
using std::sqrt;
@@ -1016,7 +1137,7 @@
// We apply two rotations to have zj = 0;
// TODO deflation44 is still broken and not properly tested
template <typename MatrixType>
-void BDCSVD<MatrixType>::deflation44(Index firstColu , Index firstColm, Index firstRowW, Index firstColW, Index i, Index j, Index size)
+void BDCSVD<MatrixType>::deflation44(Eigen::Index firstColu , Eigen::Index firstColm, Eigen::Index firstRowW, Eigen::Index firstColW, Eigen::Index i, Eigen::Index j, Eigen::Index size)
{
using std::abs;
using std::sqrt;
@@ -1043,7 +1164,7 @@
}
c/=r;
s/=r;
- m_computed(firstColm + i, firstColm) = r;
+ m_computed(firstColm + i, firstColm) = r;
m_computed(firstColm + j, firstColm + j) = m_computed(firstColm + i, firstColm + i);
m_computed(firstColm + j, firstColm) = Literal(0);
@@ -1056,7 +1177,7 @@
// acts on block from (firstCol+shift, firstCol+shift) to (lastCol+shift, lastCol+shift) [inclusive]
template <typename MatrixType>
-void BDCSVD<MatrixType>::deflation(Index firstCol, Index lastCol, Index k, Index firstRowW, Index firstColW, Index shift)
+void BDCSVD<MatrixType>::deflation(Eigen::Index firstCol, Eigen::Index lastCol, Eigen::Index k, Eigen::Index firstRowW, Eigen::Index firstColW, Eigen::Index shift)
{
using std::sqrt;
using std::abs;
@@ -1117,6 +1238,7 @@
#endif
#ifdef EIGEN_BDCSVD_DEBUG_VERBOSE
std::cout << "to be sorted: " << diag.transpose() << "\n\n";
+ std::cout << " : " << col0.transpose() << "\n\n";
#endif
{
// Check for total deflation
@@ -1207,7 +1329,7 @@
if( (diag(i) - diag(i-1)) < NumTraits<RealScalar>::epsilon()*maxDiag )
{
#ifdef EIGEN_BDCSVD_DEBUG_VERBOSE
- std::cout << "deflation 4.4 with i = " << i << " because " << (diag(i) - diag(i-1)) << " < " << NumTraits<RealScalar>::epsilon()*diag(i) << "\n";
+ std::cout << "deflation 4.4 with i = " << i << " because " << diag(i) << " - " << diag(i-1) << " == " << (diag(i) - diag(i-1)) << " < " << NumTraits<RealScalar>::epsilon()*/*diag(i)*/maxDiag << "\n";
#endif
eigen_internal_assert(abs(diag(i) - diag(i-1))<epsilon_coarse && " diagonal entries are not properly sorted");
deflation44(firstCol, firstCol + shift, firstRowW, firstColW, i-1, i, length);
@@ -1226,7 +1348,6 @@
#endif
}//end deflation
-#ifndef __CUDACC__
/** \svd_module
*
* \return the singular value decomposition of \c *this computed by Divide & Conquer algorithm
@@ -1239,7 +1360,6 @@
{
return BDCSVD<PlainObject>(*this, computationOptions);
}
-#endif
} // end namespace Eigen
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/SVD/JacobiSVD.h b/wpimath/src/main/native/eigeninclude/Eigen/src/SVD/JacobiSVD.h
index 43488b1..9d95acd 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/SVD/JacobiSVD.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/SVD/JacobiSVD.h
@@ -112,12 +112,12 @@
ColsAtCompileTime = MatrixType::ColsAtCompileTime,
MaxRowsAtCompileTime = MatrixType::MaxRowsAtCompileTime,
MaxColsAtCompileTime = MatrixType::MaxColsAtCompileTime,
- TrOptions = RowsAtCompileTime==1 ? (MatrixType::Options & ~(RowMajor))
- : ColsAtCompileTime==1 ? (MatrixType::Options | RowMajor)
- : MatrixType::Options
+ Options = MatrixType::Options
};
- typedef Matrix<Scalar, ColsAtCompileTime, RowsAtCompileTime, TrOptions, MaxColsAtCompileTime, MaxRowsAtCompileTime>
- TransposeTypeWithSameStorageOrder;
+
+ typedef typename internal::make_proper_matrix_type<
+ Scalar, ColsAtCompileTime, RowsAtCompileTime, Options, MaxColsAtCompileTime, MaxRowsAtCompileTime
+ >::type TransposeTypeWithSameStorageOrder;
void allocate(const JacobiSVD<MatrixType, FullPivHouseholderQRPreconditioner>& svd)
{
@@ -202,13 +202,12 @@
ColsAtCompileTime = MatrixType::ColsAtCompileTime,
MaxRowsAtCompileTime = MatrixType::MaxRowsAtCompileTime,
MaxColsAtCompileTime = MatrixType::MaxColsAtCompileTime,
- TrOptions = RowsAtCompileTime==1 ? (MatrixType::Options & ~(RowMajor))
- : ColsAtCompileTime==1 ? (MatrixType::Options | RowMajor)
- : MatrixType::Options
+ Options = MatrixType::Options
};
- typedef Matrix<Scalar, ColsAtCompileTime, RowsAtCompileTime, TrOptions, MaxColsAtCompileTime, MaxRowsAtCompileTime>
- TransposeTypeWithSameStorageOrder;
+ typedef typename internal::make_proper_matrix_type<
+ Scalar, ColsAtCompileTime, RowsAtCompileTime, Options, MaxColsAtCompileTime, MaxRowsAtCompileTime
+ >::type TransposeTypeWithSameStorageOrder;
void allocate(const JacobiSVD<MatrixType, ColPivHouseholderQRPreconditioner>& svd)
{
@@ -303,8 +302,9 @@
Options = MatrixType::Options
};
- typedef Matrix<Scalar, ColsAtCompileTime, RowsAtCompileTime, Options, MaxColsAtCompileTime, MaxRowsAtCompileTime>
- TransposeTypeWithSameStorageOrder;
+ typedef typename internal::make_proper_matrix_type<
+ Scalar, ColsAtCompileTime, RowsAtCompileTime, Options, MaxColsAtCompileTime, MaxRowsAtCompileTime
+ >::type TransposeTypeWithSameStorageOrder;
void allocate(const JacobiSVD<MatrixType, HouseholderQRPreconditioner>& svd)
{
@@ -425,6 +425,7 @@
template<typename _MatrixType, int QRPreconditioner>
struct traits<JacobiSVD<_MatrixType,QRPreconditioner> >
+ : traits<_MatrixType>
{
typedef _MatrixType MatrixType;
};
@@ -584,6 +585,7 @@
using Base::m_matrixU;
using Base::m_matrixV;
using Base::m_singularValues;
+ using Base::m_info;
using Base::m_isInitialized;
using Base::m_isAllocated;
using Base::m_usePrescribedThreshold;
@@ -610,7 +612,7 @@
};
template<typename MatrixType, int QRPreconditioner>
-void JacobiSVD<MatrixType, QRPreconditioner>::allocate(Index rows, Index cols, unsigned int computationOptions)
+void JacobiSVD<MatrixType, QRPreconditioner>::allocate(Eigen::Index rows, Eigen::Index cols, unsigned int computationOptions)
{
eigen_assert(rows >= 0 && cols >= 0);
@@ -624,6 +626,7 @@
m_rows = rows;
m_cols = cols;
+ m_info = Success;
m_isInitialized = false;
m_isAllocated = true;
m_computationOptions = computationOptions;
@@ -673,7 +676,12 @@
const RealScalar considerAsZero = (std::numeric_limits<RealScalar>::min)();
// Scaling factor to reduce over/under-flows
- RealScalar scale = matrix.cwiseAbs().maxCoeff();
+ RealScalar scale = matrix.cwiseAbs().template maxCoeff<PropagateNaN>();
+ if (!(numext::isfinite)(scale)) {
+ m_isInitialized = true;
+ m_info = InvalidInput;
+ return *this;
+ }
if(scale==RealScalar(0)) scale = RealScalar(1);
/*** step 1. The R-SVD step: we use a QR decomposition to reduce to the case of a square matrix */
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/SVD/SVDBase.h b/wpimath/src/main/native/eigeninclude/Eigen/src/SVD/SVDBase.h
index 3d1ef37..bc7ab88 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/SVD/SVDBase.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/SVD/SVDBase.h
@@ -17,6 +17,18 @@
#define EIGEN_SVDBASE_H
namespace Eigen {
+
+namespace internal {
+template<typename Derived> struct traits<SVDBase<Derived> >
+ : traits<Derived>
+{
+ typedef MatrixXpr XprKind;
+ typedef SolverStorage StorageKind;
+ typedef int StorageIndex;
+ enum { Flags = 0 };
+};
+}
+
/** \ingroup SVD_Module
*
*
@@ -39,20 +51,26 @@
* smaller value among \a n and \a p, there are only \a m singular vectors; the remaining columns of \a U and \a V do not correspond to actual
* singular vectors. Asking for \em thin \a U or \a V means asking for only their \a m first columns to be formed. So \a U is then a n-by-m matrix,
* and \a V is then a p-by-m matrix. Notice that thin \a U and \a V are all you need for (least squares) solving.
+ *
+ * The status of the computation can be retrieved using the \a info() method. Unless \a info() returns \a Success, the results should not be
+ * considered well defined.
*
- * If the input matrix has inf or nan coefficients, the result of the computation is undefined, but the computation is guaranteed to
+ * If the input matrix has inf or nan coefficients, the result of the computation is undefined, and \a info() will return \a InvalidInput, but the computation is guaranteed to
* terminate in finite (and reasonable) time.
* \sa class BDCSVD, class JacobiSVD
*/
-template<typename Derived>
-class SVDBase
+template<typename Derived> class SVDBase
+ : public SolverBase<SVDBase<Derived> >
{
+public:
+
+ template<typename Derived_>
+ friend struct internal::solve_assertion;
-public:
typedef typename internal::traits<Derived>::MatrixType MatrixType;
typedef typename MatrixType::Scalar Scalar;
typedef typename NumTraits<typename MatrixType::Scalar>::Real RealScalar;
- typedef typename MatrixType::StorageIndex StorageIndex;
+ typedef typename Eigen::internal::traits<SVDBase>::StorageIndex StorageIndex;
typedef Eigen::Index Index; ///< \deprecated since Eigen 3.3
enum {
RowsAtCompileTime = MatrixType::RowsAtCompileTime,
@@ -82,7 +100,7 @@
*/
const MatrixUType& matrixU() const
{
- eigen_assert(m_isInitialized && "SVD is not initialized.");
+ _check_compute_assertions();
eigen_assert(computeU() && "This SVD decomposition didn't compute U. Did you ask for it?");
return m_matrixU;
}
@@ -98,7 +116,7 @@
*/
const MatrixVType& matrixV() const
{
- eigen_assert(m_isInitialized && "SVD is not initialized.");
+ _check_compute_assertions();
eigen_assert(computeV() && "This SVD decomposition didn't compute V. Did you ask for it?");
return m_matrixV;
}
@@ -110,14 +128,14 @@
*/
const SingularValuesType& singularValues() const
{
- eigen_assert(m_isInitialized && "SVD is not initialized.");
+ _check_compute_assertions();
return m_singularValues;
}
/** \returns the number of singular values that are not exactly 0 */
Index nonzeroSingularValues() const
{
- eigen_assert(m_isInitialized && "SVD is not initialized.");
+ _check_compute_assertions();
return m_nonzeroSingularValues;
}
@@ -130,7 +148,7 @@
inline Index rank() const
{
using std::abs;
- eigen_assert(m_isInitialized && "JacobiSVD is not initialized.");
+ _check_compute_assertions();
if(m_singularValues.size()==0) return 0;
RealScalar premultiplied_threshold = numext::maxi<RealScalar>(m_singularValues.coeff(0) * threshold(), (std::numeric_limits<RealScalar>::min)());
Index i = m_nonzeroSingularValues-1;
@@ -183,7 +201,7 @@
// this temporary is needed to workaround a MSVC issue
Index diagSize = (std::max<Index>)(1,m_diagSize);
return m_usePrescribedThreshold ? m_prescribedThreshold
- : diagSize*NumTraits<Scalar>::epsilon();
+ : RealScalar(diagSize)*NumTraits<Scalar>::epsilon();
}
/** \returns true if \a U (full or thin) is asked for in this SVD decomposition */
@@ -194,6 +212,7 @@
inline Index rows() const { return m_rows; }
inline Index cols() const { return m_cols; }
+ #ifdef EIGEN_PARSED_BY_DOXYGEN
/** \returns a (least squares) solution of \f$ A x = b \f$ using the current SVD decomposition of A.
*
* \param b the right-hand-side of the equation to solve.
@@ -205,32 +224,55 @@
*/
template<typename Rhs>
inline const Solve<Derived, Rhs>
- solve(const MatrixBase<Rhs>& b) const
+ solve(const MatrixBase<Rhs>& b) const;
+ #endif
+
+
+ /** \brief Reports whether previous computation was successful.
+ *
+ * \returns \c Success if computation was successful.
+ */
+ EIGEN_DEVICE_FUNC
+ ComputationInfo info() const
{
eigen_assert(m_isInitialized && "SVD is not initialized.");
- eigen_assert(computeU() && computeV() && "SVD::solve() requires both unitaries U and V to be computed (thin unitaries suffice).");
- return Solve<Derived, Rhs>(derived(), b.derived());
+ return m_info;
}
-
+
#ifndef EIGEN_PARSED_BY_DOXYGEN
template<typename RhsType, typename DstType>
- EIGEN_DEVICE_FUNC
void _solve_impl(const RhsType &rhs, DstType &dst) const;
+
+ template<bool Conjugate, typename RhsType, typename DstType>
+ void _solve_impl_transposed(const RhsType &rhs, DstType &dst) const;
#endif
protected:
-
+
static void check_template_parameters()
{
EIGEN_STATIC_ASSERT_NON_INTEGER(Scalar);
}
-
+
+ void _check_compute_assertions() const {
+ eigen_assert(m_isInitialized && "SVD is not initialized.");
+ }
+
+ template<bool Transpose_, typename Rhs>
+ void _check_solve_assertion(const Rhs& b) const {
+ EIGEN_ONLY_USED_FOR_DEBUG(b);
+ _check_compute_assertions();
+ eigen_assert(computeU() && computeV() && "SVDBase::solve(): Both unitaries U and V are required to be computed (thin unitaries suffice).");
+ eigen_assert((Transpose_?cols():rows())==b.rows() && "SVDBase::solve(): invalid number of rows of the right hand side matrix b");
+ }
+
// return true if already allocated
bool allocate(Index rows, Index cols, unsigned int computationOptions) ;
MatrixUType m_matrixU;
MatrixVType m_matrixV;
SingularValuesType m_singularValues;
+ ComputationInfo m_info;
bool m_isInitialized, m_isAllocated, m_usePrescribedThreshold;
bool m_computeFullU, m_computeThinU;
bool m_computeFullV, m_computeThinV;
@@ -243,9 +285,14 @@
* Default constructor of SVDBase
*/
SVDBase()
- : m_isInitialized(false),
+ : m_info(Success),
+ m_isInitialized(false),
m_isAllocated(false),
m_usePrescribedThreshold(false),
+ m_computeFullU(false),
+ m_computeThinU(false),
+ m_computeFullV(false),
+ m_computeThinV(false),
m_computationOptions(0),
m_rows(-1), m_cols(-1), m_diagSize(0)
{
@@ -260,17 +307,30 @@
template<typename RhsType, typename DstType>
void SVDBase<Derived>::_solve_impl(const RhsType &rhs, DstType &dst) const
{
- eigen_assert(rhs.rows() == rows());
-
// A = U S V^*
// So A^{-1} = V S^{-1} U^*
- Matrix<Scalar, Dynamic, RhsType::ColsAtCompileTime, 0, MatrixType::MaxRowsAtCompileTime, RhsType::MaxColsAtCompileTime> tmp;
+ Matrix<typename RhsType::Scalar, Dynamic, RhsType::ColsAtCompileTime, 0, MatrixType::MaxRowsAtCompileTime, RhsType::MaxColsAtCompileTime> tmp;
Index l_rank = rank();
tmp.noalias() = m_matrixU.leftCols(l_rank).adjoint() * rhs;
tmp = m_singularValues.head(l_rank).asDiagonal().inverse() * tmp;
dst = m_matrixV.leftCols(l_rank) * tmp;
}
+
+template<typename Derived>
+template<bool Conjugate, typename RhsType, typename DstType>
+void SVDBase<Derived>::_solve_impl_transposed(const RhsType &rhs, DstType &dst) const
+{
+ // A = U S V^*
+ // So A^{-*} = U S^{-1} V^*
+ // And A^{-T} = U_conj S^{-1} V^T
+ Matrix<typename RhsType::Scalar, Dynamic, RhsType::ColsAtCompileTime, 0, MatrixType::MaxRowsAtCompileTime, RhsType::MaxColsAtCompileTime> tmp;
+ Index l_rank = rank();
+
+ tmp.noalias() = m_matrixV.leftCols(l_rank).transpose().template conjugateIf<Conjugate>() * rhs;
+ tmp = m_singularValues.head(l_rank).asDiagonal().inverse() * tmp;
+ dst = m_matrixU.template conjugateIf<!Conjugate>().leftCols(l_rank) * tmp;
+}
#endif
template<typename MatrixType>
@@ -288,6 +348,7 @@
m_rows = rows;
m_cols = cols;
+ m_info = Success;
m_isInitialized = false;
m_isAllocated = true;
m_computationOptions = computationOptions;
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/SVD/UpperBidiagonalization.h b/wpimath/src/main/native/eigeninclude/Eigen/src/SVD/UpperBidiagonalization.h
index 11ac847..997defc 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/SVD/UpperBidiagonalization.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/SVD/UpperBidiagonalization.h
@@ -127,7 +127,7 @@
.makeHouseholderInPlace(mat.coeffRef(k,k+1), upper_diagonal[k]);
// apply householder transform to remaining part of mat on the left
mat.bottomRightCorner(remainingRows-1, remainingCols)
- .applyHouseholderOnTheRight(mat.row(k).tail(remainingCols-1).transpose(), mat.coeff(k,k+1), tempData);
+ .applyHouseholderOnTheRight(mat.row(k).tail(remainingCols-1).adjoint(), mat.coeff(k,k+1), tempData);
}
}
@@ -202,7 +202,7 @@
{
SubColumnType y_k( Y.col(k).tail(remainingCols) );
- // let's use the begining of column k of Y as a temporary vector
+ // let's use the beginning of column k of Y as a temporary vector
SubColumnType tmp( Y.col(k).head(k) );
y_k.noalias() = A.block(k,k+1, remainingRows,remainingCols).adjoint() * v_k; // bottleneck
tmp.noalias() = V_k1.adjoint() * v_k;
@@ -231,7 +231,7 @@
{
SubColumnType x_k ( X.col(k).tail(remainingRows-1) );
- // let's use the begining of column k of X as a temporary vectors
+ // let's use the beginning of column k of X as temporary vectors
// note that tmp0 and tmp1 overlaps
SubColumnType tmp0 ( X.col(k).head(k) ),
tmp1 ( X.col(k).head(k+1) );
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/StlSupport/StdDeque.h b/wpimath/src/main/native/eigeninclude/Eigen/src/StlSupport/StdDeque.h
index cf1fedf..6d47e75 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/StlSupport/StdDeque.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/StlSupport/StdDeque.h
@@ -36,7 +36,7 @@
deque(InputIterator first, InputIterator last, const allocator_type& a = allocator_type()) : deque_base(first, last, a) {} \
deque(const deque& c) : deque_base(c) {} \
explicit deque(size_type num, const value_type& val = value_type()) : deque_base(num, val) {} \
- deque(iterator start, iterator end) : deque_base(start, end) {} \
+ deque(iterator start_, iterator end_) : deque_base(start_, end_) {} \
deque& operator=(const deque& x) { \
deque_base::operator=(x); \
return *this; \
@@ -62,7 +62,7 @@
: deque_base(first, last, a) {} \
deque(const deque& c) : deque_base(c) {} \
explicit deque(size_type num, const value_type& val = value_type()) : deque_base(num, val) {} \
- deque(iterator start, iterator end) : deque_base(start, end) {} \
+ deque(iterator start_, iterator end_) : deque_base(start_, end_) {} \
deque& operator=(const deque& x) { \
deque_base::operator=(x); \
return *this; \
@@ -98,17 +98,7 @@
{ return deque_base::insert(position,x); }
void insert(const_iterator position, size_type new_size, const value_type& x)
{ deque_base::insert(position, new_size, x); }
-#elif defined(_GLIBCXX_DEQUE) && EIGEN_GNUC_AT_LEAST(4,2)
- // workaround GCC std::deque implementation
- void resize(size_type new_size, const value_type& x)
- {
- if (new_size < deque_base::size())
- deque_base::_M_erase_at_end(this->_M_impl._M_start + new_size);
- else
- deque_base::insert(deque_base::end(), new_size - deque_base::size(), x);
- }
#else
- // either GCC 4.1 or non-GCC
// default implementation which should always work.
void resize(size_type new_size, const value_type& x)
{
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/StlSupport/StdList.h b/wpimath/src/main/native/eigeninclude/Eigen/src/StlSupport/StdList.h
index e1eba49..8ba3fad 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/StlSupport/StdList.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/StlSupport/StdList.h
@@ -35,7 +35,7 @@
list(InputIterator first, InputIterator last, const allocator_type& a = allocator_type()) : list_base(first, last, a) {} \
list(const list& c) : list_base(c) {} \
explicit list(size_type num, const value_type& val = value_type()) : list_base(num, val) {} \
- list(iterator start, iterator end) : list_base(start, end) {} \
+ list(iterator start_, iterator end_) : list_base(start_, end_) {} \
list& operator=(const list& x) { \
list_base::operator=(x); \
return *this; \
@@ -62,7 +62,7 @@
: list_base(first, last, a) {} \
list(const list& c) : list_base(c) {} \
explicit list(size_type num, const value_type& val = value_type()) : list_base(num, val) {} \
- list(iterator start, iterator end) : list_base(start, end) {} \
+ list(iterator start_, iterator end_) : list_base(start_, end_) {} \
list& operator=(const list& x) { \
list_base::operator=(x); \
return *this; \
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/StlSupport/StdVector.h b/wpimath/src/main/native/eigeninclude/Eigen/src/StlSupport/StdVector.h
index ec22821..9fcf19b 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/StlSupport/StdVector.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/StlSupport/StdVector.h
@@ -36,7 +36,7 @@
vector(InputIterator first, InputIterator last, const allocator_type& a = allocator_type()) : vector_base(first, last, a) {} \
vector(const vector& c) : vector_base(c) {} \
explicit vector(size_type num, const value_type& val = value_type()) : vector_base(num, val) {} \
- vector(iterator start, iterator end) : vector_base(start, end) {} \
+ vector(iterator start_, iterator end_) : vector_base(start_, end_) {} \
vector& operator=(const vector& x) { \
vector_base::operator=(x); \
return *this; \
@@ -62,7 +62,7 @@
: vector_base(first, last, a) {} \
vector(const vector& c) : vector_base(c) {} \
explicit vector(size_type num, const value_type& val = value_type()) : vector_base(num, val) {} \
- vector(iterator start, iterator end) : vector_base(start, end) {} \
+ vector(iterator start_, iterator end_) : vector_base(start_, end_) {} \
vector& operator=(const vector& x) { \
vector_base::operator=(x); \
return *this; \
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/plugins/ArrayCwiseBinaryOps.h b/wpimath/src/main/native/eigeninclude/Eigen/src/plugins/ArrayCwiseBinaryOps.h
index 1f8a531..0e5d544 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/plugins/ArrayCwiseBinaryOps.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/plugins/ArrayCwiseBinaryOps.h
@@ -75,6 +75,32 @@
return (max)(Derived::PlainObject::Constant(rows(), cols(), other));
}
+/** \returns an expression of the coefficient-wise absdiff of \c *this and \a other
+ *
+ * Example: \include Cwise_absolute_difference.cpp
+ * Output: \verbinclude Cwise_absolute_difference.out
+ *
+ * \sa absolute_difference()
+ */
+EIGEN_MAKE_CWISE_BINARY_OP(absolute_difference,absolute_difference)
+
+/** \returns an expression of the coefficient-wise absolute_difference of \c *this and scalar \a other
+ *
+ * \sa absolute_difference()
+ */
+EIGEN_DEVICE_FUNC
+EIGEN_STRONG_INLINE const CwiseBinaryOp<internal::scalar_absolute_difference_op<Scalar,Scalar>, const Derived,
+ const CwiseNullaryOp<internal::scalar_constant_op<Scalar>, PlainObject> >
+#ifdef EIGEN_PARSED_BY_DOXYGEN
+absolute_difference
+#else
+(absolute_difference)
+#endif
+(const Scalar &other) const
+{
+ return (absolute_difference)(Derived::PlainObject::Constant(rows(), cols(), other));
+}
+
/** \returns an expression of the coefficient-wise power of \c *this to the given array of \a exponents.
*
* This function computes the coefficient-wise power.
@@ -119,7 +145,7 @@
return this->OP(Derived::PlainObject::Constant(rows(), cols(), s)); \
} \
EIGEN_DEVICE_FUNC friend EIGEN_STRONG_INLINE const RCmp ## COMPARATOR ## ReturnType \
-OP(const Scalar& s, const Derived& d) { \
+OP(const Scalar& s, const EIGEN_CURRENT_STORAGE_BASE_CLASS<Derived>& d) { \
return Derived::PlainObject::Constant(d.rows(), d.cols(), s).OP(d); \
}
@@ -314,9 +340,9 @@
*
* It returns the Riemann zeta function of two arguments \c *this and \a q:
*
- * \param *this is the exposent, it must be > 1
* \param q is the shift, it must be > 0
*
+ * \note *this is the exponent, it must be > 1.
* \note This function supports only float and double scalar types. To support other scalar types, the user has
* to provide implementations of zeta(T,T) for any scalar type T to be supported.
*
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/plugins/ArrayCwiseUnaryOps.h b/wpimath/src/main/native/eigeninclude/Eigen/src/plugins/ArrayCwiseUnaryOps.h
index ebaa3f1..13c55f4 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/plugins/ArrayCwiseUnaryOps.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/plugins/ArrayCwiseUnaryOps.h
@@ -10,9 +10,11 @@
typedef CwiseUnaryOp<internal::scalar_boolean_not_op<Scalar>, const Derived> BooleanNotReturnType;
typedef CwiseUnaryOp<internal::scalar_exp_op<Scalar>, const Derived> ExpReturnType;
+typedef CwiseUnaryOp<internal::scalar_expm1_op<Scalar>, const Derived> Expm1ReturnType;
typedef CwiseUnaryOp<internal::scalar_log_op<Scalar>, const Derived> LogReturnType;
typedef CwiseUnaryOp<internal::scalar_log1p_op<Scalar>, const Derived> Log1pReturnType;
typedef CwiseUnaryOp<internal::scalar_log10_op<Scalar>, const Derived> Log10ReturnType;
+typedef CwiseUnaryOp<internal::scalar_log2_op<Scalar>, const Derived> Log2ReturnType;
typedef CwiseUnaryOp<internal::scalar_cos_op<Scalar>, const Derived> CosReturnType;
typedef CwiseUnaryOp<internal::scalar_sin_op<Scalar>, const Derived> SinReturnType;
typedef CwiseUnaryOp<internal::scalar_tan_op<Scalar>, const Derived> TanReturnType;
@@ -20,11 +22,18 @@
typedef CwiseUnaryOp<internal::scalar_asin_op<Scalar>, const Derived> AsinReturnType;
typedef CwiseUnaryOp<internal::scalar_atan_op<Scalar>, const Derived> AtanReturnType;
typedef CwiseUnaryOp<internal::scalar_tanh_op<Scalar>, const Derived> TanhReturnType;
+typedef CwiseUnaryOp<internal::scalar_logistic_op<Scalar>, const Derived> LogisticReturnType;
typedef CwiseUnaryOp<internal::scalar_sinh_op<Scalar>, const Derived> SinhReturnType;
+#if EIGEN_HAS_CXX11_MATH
+typedef CwiseUnaryOp<internal::scalar_atanh_op<Scalar>, const Derived> AtanhReturnType;
+typedef CwiseUnaryOp<internal::scalar_asinh_op<Scalar>, const Derived> AsinhReturnType;
+typedef CwiseUnaryOp<internal::scalar_acosh_op<Scalar>, const Derived> AcoshReturnType;
+#endif
typedef CwiseUnaryOp<internal::scalar_cosh_op<Scalar>, const Derived> CoshReturnType;
typedef CwiseUnaryOp<internal::scalar_square_op<Scalar>, const Derived> SquareReturnType;
typedef CwiseUnaryOp<internal::scalar_cube_op<Scalar>, const Derived> CubeReturnType;
typedef CwiseUnaryOp<internal::scalar_round_op<Scalar>, const Derived> RoundReturnType;
+typedef CwiseUnaryOp<internal::scalar_rint_op<Scalar>, const Derived> RintReturnType;
typedef CwiseUnaryOp<internal::scalar_floor_op<Scalar>, const Derived> FloorReturnType;
typedef CwiseUnaryOp<internal::scalar_ceil_op<Scalar>, const Derived> CeilReturnType;
typedef CwiseUnaryOp<internal::scalar_isnan_op<Scalar>, const Derived> IsNaNReturnType;
@@ -90,6 +99,20 @@
return ExpReturnType(derived());
}
+/** \returns an expression of the coefficient-wise exponential of *this minus 1.
+ *
+ * In exact arithmetic, \c x.expm1() is equivalent to \c x.exp() - 1,
+ * however, with finite precision, this function is much more accurate when \c x is close to zero.
+ *
+ * \sa <a href="group__CoeffwiseMathFunctions.html#cwisetable_expm1">Math functions</a>, exp()
+ */
+EIGEN_DEVICE_FUNC
+inline const Expm1ReturnType
+expm1() const
+{
+ return Expm1ReturnType(derived());
+}
+
/** \returns an expression of the coefficient-wise logarithm of *this.
*
* This function computes the coefficient-wise logarithm. The function MatrixBase::log() in the
@@ -98,7 +121,7 @@
* Example: \include Cwise_log.cpp
* Output: \verbinclude Cwise_log.out
*
- * \sa <a href="group__CoeffwiseMathFunctions.html#cwisetable_log">Math functions</a>, exp()
+ * \sa <a href="group__CoeffwiseMathFunctions.html#cwisetable_log">Math functions</a>, log()
*/
EIGEN_DEVICE_FUNC
inline const LogReturnType
@@ -137,6 +160,18 @@
return Log10ReturnType(derived());
}
+/** \returns an expression of the coefficient-wise base-2 logarithm of *this.
+ *
+ * This function computes the coefficient-wise base-2 logarithm.
+ *
+ */
+EIGEN_DEVICE_FUNC
+inline const Log2ReturnType
+log2() const
+{
+ return Log2ReturnType(derived());
+}
+
/** \returns an expression of the coefficient-wise square root of *this.
*
* This function computes the coefficient-wise square root. The function MatrixBase::sqrt() in the
@@ -311,7 +346,7 @@
* Example: \include Cwise_cosh.cpp
* Output: \verbinclude Cwise_cosh.out
*
- * \sa <a href="group__CoeffwiseMathFunctions.html#cwisetable_cosh">Math functions</a>, tan(), sinh(), cosh()
+ * \sa <a href="group__CoeffwiseMathFunctions.html#cwisetable_cosh">Math functions</a>, tanh(), sinh(), cosh()
*/
EIGEN_DEVICE_FUNC
inline const CoshReturnType
@@ -320,6 +355,50 @@
return CoshReturnType(derived());
}
+#if EIGEN_HAS_CXX11_MATH
+/** \returns an expression of the coefficient-wise inverse hyperbolic tan of *this.
+ *
+ * \sa <a href="group__CoeffwiseMathFunctions.html#cwisetable_atanh">Math functions</a>, atanh(), asinh(), acosh()
+ */
+EIGEN_DEVICE_FUNC
+inline const AtanhReturnType
+atanh() const
+{
+ return AtanhReturnType(derived());
+}
+
+/** \returns an expression of the coefficient-wise inverse hyperbolic sin of *this.
+ *
+ * \sa <a href="group__CoeffwiseMathFunctions.html#cwisetable_asinh">Math functions</a>, atanh(), asinh(), acosh()
+ */
+EIGEN_DEVICE_FUNC
+inline const AsinhReturnType
+asinh() const
+{
+ return AsinhReturnType(derived());
+}
+
+/** \returns an expression of the coefficient-wise inverse hyperbolic cos of *this.
+ *
+ * \sa <a href="group__CoeffwiseMathFunctions.html#cwisetable_acosh">Math functions</a>, atanh(), asinh(), acosh()
+ */
+EIGEN_DEVICE_FUNC
+inline const AcoshReturnType
+acosh() const
+{
+ return AcoshReturnType(derived());
+}
+#endif
+
+/** \returns an expression of the coefficient-wise logistic function of *this.
+ */
+EIGEN_DEVICE_FUNC
+inline const LogisticReturnType
+logistic() const
+{
+ return LogisticReturnType(derived());
+}
+
/** \returns an expression of the coefficient-wise inverse of *this.
*
* Example: \include Cwise_inverse.cpp
@@ -362,6 +441,20 @@
return CubeReturnType(derived());
}
+/** \returns an expression of the coefficient-wise rint of *this.
+ *
+ * Example: \include Cwise_rint.cpp
+ * Output: \verbinclude Cwise_rint.out
+ *
+ * \sa <a href="group__CoeffwiseMathFunctions.html#cwisetable_rint">Math functions</a>, ceil(), floor()
+ */
+EIGEN_DEVICE_FUNC
+inline const RintReturnType
+rint() const
+{
+ return RintReturnType(derived());
+}
+
/** \returns an expression of the coefficient-wise round of *this.
*
* Example: \include Cwise_round.cpp
@@ -404,6 +497,45 @@
return CeilReturnType(derived());
}
+template<int N> struct ShiftRightXpr {
+ typedef CwiseUnaryOp<internal::scalar_shift_right_op<Scalar, N>, const Derived> Type;
+};
+
+/** \returns an expression of \c *this with the \a Scalar type arithmetically
+ * shifted right by \a N bit positions.
+ *
+ * The template parameter \a N specifies the number of bit positions to shift.
+ *
+ * \sa shiftLeft()
+ */
+template<int N>
+EIGEN_DEVICE_FUNC
+typename ShiftRightXpr<N>::Type
+shiftRight() const
+{
+ return typename ShiftRightXpr<N>::Type(derived());
+}
+
+
+template<int N> struct ShiftLeftXpr {
+ typedef CwiseUnaryOp<internal::scalar_shift_left_op<Scalar, N>, const Derived> Type;
+};
+
+/** \returns an expression of \c *this with the \a Scalar type logically
+ * shifted left by \a N bit positions.
+ *
+ * The template parameter \a N specifies the number of bit positions to shift.
+ *
+ * \sa shiftRight()
+ */
+template<int N>
+EIGEN_DEVICE_FUNC
+typename ShiftLeftXpr<N>::Type
+shiftLeft() const
+{
+ return typename ShiftLeftXpr<N>::Type(derived());
+}
+
/** \returns an expression of the coefficient-wise isnan of *this.
*
* Example: \include Cwise_isNaN.cpp
@@ -471,14 +603,12 @@
typedef CwiseUnaryOp<internal::scalar_digamma_op<Scalar>, const Derived> DigammaReturnType;
typedef CwiseUnaryOp<internal::scalar_erf_op<Scalar>, const Derived> ErfReturnType;
typedef CwiseUnaryOp<internal::scalar_erfc_op<Scalar>, const Derived> ErfcReturnType;
+typedef CwiseUnaryOp<internal::scalar_ndtri_op<Scalar>, const Derived> NdtriReturnType;
/** \cpp11 \returns an expression of the coefficient-wise ln(|gamma(*this)|).
*
* \specialfunctions_module
*
- * Example: \include Cwise_lgamma.cpp
- * Output: \verbinclude Cwise_lgamma.out
- *
* \note This function supports only float and double scalar types in c++11 mode. To support other scalar types,
* or float/double in non c++11 mode, the user has to provide implementations of lgamma(T) for any scalar
* type T to be supported.
@@ -514,9 +644,6 @@
*
* \specialfunctions_module
*
- * Example: \include Cwise_erf.cpp
- * Output: \verbinclude Cwise_erf.out
- *
* \note This function supports only float and double scalar types in c++11 mode. To support other scalar types,
* or float/double in non c++11 mode, the user has to provide implementations of erf(T) for any scalar
* type T to be supported.
@@ -535,9 +662,6 @@
*
* \specialfunctions_module
*
- * Example: \include Cwise_erfc.cpp
- * Output: \verbinclude Cwise_erfc.out
- *
* \note This function supports only float and double scalar types in c++11 mode. To support other scalar types,
* or float/double in non c++11 mode, the user has to provide implementations of erfc(T) for any scalar
* type T to be supported.
@@ -550,3 +674,23 @@
{
return ErfcReturnType(derived());
}
+
+/** \returns an expression of the coefficient-wise inverse of the Normal distribution's cumulative
+ * distribution function (CDF), evaluated at *this.
+ *
+ * \specialfunctions_module
+ *
+ * In other words, considering `x = ndtri(y)`, it returns the argument, x, for which the area under the
+ * Gaussian probability density function (integrated from minus infinity to x) is equal to y.
+ *
+ * \note This function supports only float and double scalar types. To support other scalar types,
+ * the user has to provide implementations of ndtri(T) for any scalar type T to be supported.
+ *
+ * \sa <a href="group__CoeffwiseMathFunctions.html#cwisetable_ndtri">Math functions</a>
+ */
+EIGEN_DEVICE_FUNC
+inline const NdtriReturnType
+ndtri() const
+{
+ return NdtriReturnType(derived());
+}
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/plugins/BlockMethods.h b/wpimath/src/main/native/eigeninclude/Eigen/src/plugins/BlockMethods.h
index ac35a00..63a52a6 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/plugins/BlockMethods.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/plugins/BlockMethods.h
@@ -40,68 +40,126 @@
template<int Size> struct FixedSegmentReturnType { typedef VectorBlock<Derived, Size> Type; };
template<int Size> struct ConstFixedSegmentReturnType { typedef const VectorBlock<const Derived, Size> Type; };
+/// \internal inner-vector
+typedef Block<Derived,IsRowMajor?1:Dynamic,IsRowMajor?Dynamic:1,true> InnerVectorReturnType;
+typedef Block<const Derived,IsRowMajor?1:Dynamic,IsRowMajor?Dynamic:1,true> ConstInnerVectorReturnType;
+
+/// \internal set of inner-vectors
+typedef Block<Derived,Dynamic,Dynamic,true> InnerVectorsReturnType;
+typedef Block<const Derived,Dynamic,Dynamic,true> ConstInnerVectorsReturnType;
+
#endif // not EIGEN_PARSED_BY_DOXYGEN
-/// \returns a dynamic-size expression of a block in *this.
+/// \returns an expression of a block in \c *this with either dynamic or fixed sizes.
///
-/// \param startRow the first row in the block
-/// \param startCol the first column in the block
-/// \param blockRows the number of rows in the block
-/// \param blockCols the number of columns in the block
+/// \param startRow the first row in the block
+/// \param startCol the first column in the block
+/// \param blockRows number of rows in the block, specified at either run-time or compile-time
+/// \param blockCols number of columns in the block, specified at either run-time or compile-time
+/// \tparam NRowsType the type of the value handling the number of rows in the block, typically Index.
+/// \tparam NColsType the type of the value handling the number of columns in the block, typically Index.
///
-/// Example: \include MatrixBase_block_int_int_int_int.cpp
+/// Example using runtime (aka dynamic) sizes: \include MatrixBase_block_int_int_int_int.cpp
/// Output: \verbinclude MatrixBase_block_int_int_int_int.out
///
-/// \note Even though the returned expression has dynamic size, in the case
+/// \newin{3.4}:
+///
+/// The number of rows \a blockRows and columns \a blockCols can also be specified at compile-time by passing Eigen::fix<N>,
+/// or Eigen::fix<N>(n) as arguments. In the latter case, \c n plays the role of a runtime fallback value in case \c N equals Eigen::Dynamic.
+/// Here is an example with a fixed number of rows \c NRows and dynamic number of columns \c cols:
+/// \code
+/// mat.block(i,j,fix<NRows>,cols)
+/// \endcode
+///
+/// This function thus fully covers the features offered by the following overloads block<NRows,NCols>(Index, Index),
+/// and block<NRows,NCols>(Index, Index, Index, Index) that are thus obsolete. Indeed, this generic version avoids
+/// redundancy, it preserves the argument order, and prevents the need to rely on the template keyword in templated code.
+///
+/// It also seamlessly enables hybrid fixed/dynamic sizes, e.g. a compile-time number of rows combined
+/// with a run-time number of columns, as in the example above.
+///
+/// \note Even though the returned expression may have dynamic size, in the case
/// when it is applied to a fixed-size matrix, it inherits a fixed maximal size,
/// which means that evaluating it does not cause a dynamic memory allocation.
///
EIGEN_DOC_BLOCK_ADDONS_NOT_INNER_PANEL
///
-/// \sa class Block, block(Index,Index)
+/// \sa class Block, fix, fix<N>(int)
///
-EIGEN_DEVICE_FUNC
-inline BlockXpr block(Index startRow, Index startCol, Index blockRows, Index blockCols)
+template<typename NRowsType, typename NColsType>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+#ifndef EIGEN_PARSED_BY_DOXYGEN
+typename FixedBlockXpr<internal::get_fixed_value<NRowsType>::value,internal::get_fixed_value<NColsType>::value>::Type
+#else
+typename FixedBlockXpr<...,...>::Type
+#endif
+block(Index startRow, Index startCol, NRowsType blockRows, NColsType blockCols)
{
- return BlockXpr(derived(), startRow, startCol, blockRows, blockCols);
+ return typename FixedBlockXpr<internal::get_fixed_value<NRowsType>::value,internal::get_fixed_value<NColsType>::value>::Type(
+ derived(), startRow, startCol, internal::get_runtime_value(blockRows), internal::get_runtime_value(blockCols));
}
-/// This is the const version of block(Index,Index,Index,Index). */
-EIGEN_DEVICE_FUNC
-inline const ConstBlockXpr block(Index startRow, Index startCol, Index blockRows, Index blockCols) const
+/// This is the const version of block(Index,Index,NRowsType,NColsType)
+template<typename NRowsType, typename NColsType>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+#ifndef EIGEN_PARSED_BY_DOXYGEN
+const typename ConstFixedBlockXpr<internal::get_fixed_value<NRowsType>::value,internal::get_fixed_value<NColsType>::value>::Type
+#else
+const typename ConstFixedBlockXpr<...,...>::Type
+#endif
+block(Index startRow, Index startCol, NRowsType blockRows, NColsType blockCols) const
{
- return ConstBlockXpr(derived(), startRow, startCol, blockRows, blockCols);
+ return typename ConstFixedBlockXpr<internal::get_fixed_value<NRowsType>::value,internal::get_fixed_value<NColsType>::value>::Type(
+ derived(), startRow, startCol, internal::get_runtime_value(blockRows), internal::get_runtime_value(blockCols));
}
-
-/// \returns a dynamic-size expression of a top-right corner of *this.
+/// \returns an expression of a top-right corner of \c *this with either dynamic or fixed sizes.
///
/// \param cRows the number of rows in the corner
/// \param cCols the number of columns in the corner
+/// \tparam NRowsType the type of the value handling the number of rows in the block, typically Index.
+/// \tparam NColsType the type of the value handling the number of columns in the block, typically Index.
///
-/// Example: \include MatrixBase_topRightCorner_int_int.cpp
+/// Example with dynamic sizes: \include MatrixBase_topRightCorner_int_int.cpp
/// Output: \verbinclude MatrixBase_topRightCorner_int_int.out
///
+/// The number of rows \a cRows and columns \a cCols can also be specified at compile-time by passing Eigen::fix<N>,
+/// or Eigen::fix<N>(n) as arguments. See \link block(Index,Index,NRowsType,NColsType) block() \endlink for the details.
+///
EIGEN_DOC_BLOCK_ADDONS_NOT_INNER_PANEL
///
-/// \sa class Block, block(Index,Index,Index,Index)
+/// \sa block(Index,Index,NRowsType,NColsType), class Block
///
-EIGEN_DEVICE_FUNC
-inline BlockXpr topRightCorner(Index cRows, Index cCols)
+template<typename NRowsType, typename NColsType>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+#ifndef EIGEN_PARSED_BY_DOXYGEN
+typename FixedBlockXpr<internal::get_fixed_value<NRowsType>::value,internal::get_fixed_value<NColsType>::value>::Type
+#else
+typename FixedBlockXpr<...,...>::Type
+#endif
+topRightCorner(NRowsType cRows, NColsType cCols)
{
- return BlockXpr(derived(), 0, cols() - cCols, cRows, cCols);
+ return typename FixedBlockXpr<internal::get_fixed_value<NRowsType>::value,internal::get_fixed_value<NColsType>::value>::Type
+ (derived(), 0, cols() - internal::get_runtime_value(cCols), internal::get_runtime_value(cRows), internal::get_runtime_value(cCols));
}
-/// This is the const version of topRightCorner(Index, Index).
-EIGEN_DEVICE_FUNC
-inline const ConstBlockXpr topRightCorner(Index cRows, Index cCols) const
+/// This is the const version of topRightCorner(NRowsType, NColsType).
+template<typename NRowsType, typename NColsType>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+#ifndef EIGEN_PARSED_BY_DOXYGEN
+const typename ConstFixedBlockXpr<internal::get_fixed_value<NRowsType>::value,internal::get_fixed_value<NColsType>::value>::Type
+#else
+const typename ConstFixedBlockXpr<...,...>::Type
+#endif
+topRightCorner(NRowsType cRows, NColsType cCols) const
{
- return ConstBlockXpr(derived(), 0, cols() - cCols, cRows, cCols);
+ return typename ConstFixedBlockXpr<internal::get_fixed_value<NRowsType>::value,internal::get_fixed_value<NColsType>::value>::Type
+ (derived(), 0, cols() - internal::get_runtime_value(cCols), internal::get_runtime_value(cRows), internal::get_runtime_value(cCols));
}
-/// \returns an expression of a fixed-size top-right corner of *this.
+/// \returns an expression of a fixed-size top-right corner of \c *this.
///
/// \tparam CRows the number of rows in the corner
/// \tparam CCols the number of columns in the corner
@@ -114,21 +172,21 @@
/// \sa class Block, block<int,int>(Index,Index)
///
template<int CRows, int CCols>
-EIGEN_DEVICE_FUNC
-inline typename FixedBlockXpr<CRows,CCols>::Type topRightCorner()
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+typename FixedBlockXpr<CRows,CCols>::Type topRightCorner()
{
return typename FixedBlockXpr<CRows,CCols>::Type(derived(), 0, cols() - CCols);
}
/// This is the const version of topRightCorner<int, int>().
template<int CRows, int CCols>
-EIGEN_DEVICE_FUNC
-inline const typename ConstFixedBlockXpr<CRows,CCols>::Type topRightCorner() const
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+const typename ConstFixedBlockXpr<CRows,CCols>::Type topRightCorner() const
{
return typename ConstFixedBlockXpr<CRows,CCols>::Type(derived(), 0, cols() - CCols);
}
-/// \returns an expression of a top-right corner of *this.
+/// \returns an expression of a top-right corner of \c *this.
///
/// \tparam CRows number of rows in corner as specified at compile-time
/// \tparam CCols number of columns in corner as specified at compile-time
@@ -148,46 +206,67 @@
/// \sa class Block
///
template<int CRows, int CCols>
-inline typename FixedBlockXpr<CRows,CCols>::Type topRightCorner(Index cRows, Index cCols)
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+typename FixedBlockXpr<CRows,CCols>::Type topRightCorner(Index cRows, Index cCols)
{
return typename FixedBlockXpr<CRows,CCols>::Type(derived(), 0, cols() - cCols, cRows, cCols);
}
/// This is the const version of topRightCorner<int, int>(Index, Index).
template<int CRows, int CCols>
-inline const typename ConstFixedBlockXpr<CRows,CCols>::Type topRightCorner(Index cRows, Index cCols) const
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+const typename ConstFixedBlockXpr<CRows,CCols>::Type topRightCorner(Index cRows, Index cCols) const
{
return typename ConstFixedBlockXpr<CRows,CCols>::Type(derived(), 0, cols() - cCols, cRows, cCols);
}
-/// \returns a dynamic-size expression of a top-left corner of *this.
+/// \returns an expression of a top-left corner of \c *this with either dynamic or fixed sizes.
///
/// \param cRows the number of rows in the corner
/// \param cCols the number of columns in the corner
+/// \tparam NRowsType the type of the value handling the number of rows in the block, typically Index.
+/// \tparam NColsType the type of the value handling the number of columns in the block, typically Index.
///
/// Example: \include MatrixBase_topLeftCorner_int_int.cpp
/// Output: \verbinclude MatrixBase_topLeftCorner_int_int.out
///
+/// The number of rows \a cRows and columns \a cCols can also be specified at compile-time by passing Eigen::fix<N>,
+/// or Eigen::fix<N>(n) as arguments. See \link block(Index,Index,NRowsType,NColsType) block() \endlink for the details.
+///
EIGEN_DOC_BLOCK_ADDONS_NOT_INNER_PANEL
///
-/// \sa class Block, block(Index,Index,Index,Index)
+/// \sa block(Index,Index,NRowsType,NColsType), class Block
///
-EIGEN_DEVICE_FUNC
-inline BlockXpr topLeftCorner(Index cRows, Index cCols)
+template<typename NRowsType, typename NColsType>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+#ifndef EIGEN_PARSED_BY_DOXYGEN
+typename FixedBlockXpr<internal::get_fixed_value<NRowsType>::value,internal::get_fixed_value<NColsType>::value>::Type
+#else
+typename FixedBlockXpr<...,...>::Type
+#endif
+topLeftCorner(NRowsType cRows, NColsType cCols)
{
- return BlockXpr(derived(), 0, 0, cRows, cCols);
+ return typename FixedBlockXpr<internal::get_fixed_value<NRowsType>::value,internal::get_fixed_value<NColsType>::value>::Type
+ (derived(), 0, 0, internal::get_runtime_value(cRows), internal::get_runtime_value(cCols));
}
/// This is the const version of topLeftCorner(Index, Index).
-EIGEN_DEVICE_FUNC
-inline const ConstBlockXpr topLeftCorner(Index cRows, Index cCols) const
+template<typename NRowsType, typename NColsType>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+#ifndef EIGEN_PARSED_BY_DOXYGEN
+const typename ConstFixedBlockXpr<internal::get_fixed_value<NRowsType>::value,internal::get_fixed_value<NColsType>::value>::Type
+#else
+const typename ConstFixedBlockXpr<...,...>::Type
+#endif
+topLeftCorner(NRowsType cRows, NColsType cCols) const
{
- return ConstBlockXpr(derived(), 0, 0, cRows, cCols);
+ return typename ConstFixedBlockXpr<internal::get_fixed_value<NRowsType>::value,internal::get_fixed_value<NColsType>::value>::Type
+ (derived(), 0, 0, internal::get_runtime_value(cRows), internal::get_runtime_value(cCols));
}
-/// \returns an expression of a fixed-size top-left corner of *this.
+/// \returns an expression of a fixed-size top-left corner of \c *this.
///
/// The template parameters CRows and CCols are the number of rows and columns in the corner.
///
@@ -196,24 +275,24 @@
///
EIGEN_DOC_BLOCK_ADDONS_NOT_INNER_PANEL
///
-/// \sa class Block, block(Index,Index,Index,Index)
+/// \sa block(Index,Index,NRowsType,NColsType), class Block
///
template<int CRows, int CCols>
-EIGEN_DEVICE_FUNC
-inline typename FixedBlockXpr<CRows,CCols>::Type topLeftCorner()
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+typename FixedBlockXpr<CRows,CCols>::Type topLeftCorner()
{
return typename FixedBlockXpr<CRows,CCols>::Type(derived(), 0, 0);
}
/// This is the const version of topLeftCorner<int, int>().
template<int CRows, int CCols>
-EIGEN_DEVICE_FUNC
-inline const typename ConstFixedBlockXpr<CRows,CCols>::Type topLeftCorner() const
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+const typename ConstFixedBlockXpr<CRows,CCols>::Type topLeftCorner() const
{
return typename ConstFixedBlockXpr<CRows,CCols>::Type(derived(), 0, 0);
}
-/// \returns an expression of a top-left corner of *this.
+/// \returns an expression of a top-left corner of \c *this.
///
/// \tparam CRows number of rows in corner as specified at compile-time
/// \tparam CCols number of columns in corner as specified at compile-time
@@ -233,46 +312,69 @@
/// \sa class Block
///
template<int CRows, int CCols>
-inline typename FixedBlockXpr<CRows,CCols>::Type topLeftCorner(Index cRows, Index cCols)
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+typename FixedBlockXpr<CRows,CCols>::Type topLeftCorner(Index cRows, Index cCols)
{
return typename FixedBlockXpr<CRows,CCols>::Type(derived(), 0, 0, cRows, cCols);
}
/// This is the const version of topLeftCorner<int, int>(Index, Index).
template<int CRows, int CCols>
-inline const typename ConstFixedBlockXpr<CRows,CCols>::Type topLeftCorner(Index cRows, Index cCols) const
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+const typename ConstFixedBlockXpr<CRows,CCols>::Type topLeftCorner(Index cRows, Index cCols) const
{
return typename ConstFixedBlockXpr<CRows,CCols>::Type(derived(), 0, 0, cRows, cCols);
}
-/// \returns a dynamic-size expression of a bottom-right corner of *this.
+/// \returns an expression of a bottom-right corner of \c *this with either dynamic or fixed sizes.
///
/// \param cRows the number of rows in the corner
/// \param cCols the number of columns in the corner
+/// \tparam NRowsType the type of the value handling the number of rows in the block, typically Index.
+/// \tparam NColsType the type of the value handling the number of columns in the block, typically Index.
///
/// Example: \include MatrixBase_bottomRightCorner_int_int.cpp
/// Output: \verbinclude MatrixBase_bottomRightCorner_int_int.out
///
+/// The number of rows \a cRows and columns \a cCols can also be specified at compile-time by passing Eigen::fix<N>,
+/// or Eigen::fix<N>(n) as arguments. See \link block(Index,Index,NRowsType,NColsType) block() \endlink for the details.
+///
EIGEN_DOC_BLOCK_ADDONS_NOT_INNER_PANEL
///
-/// \sa class Block, block(Index,Index,Index,Index)
+/// \sa block(Index,Index,NRowsType,NColsType), class Block
///
-EIGEN_DEVICE_FUNC
-inline BlockXpr bottomRightCorner(Index cRows, Index cCols)
+template<typename NRowsType, typename NColsType>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+#ifndef EIGEN_PARSED_BY_DOXYGEN
+typename FixedBlockXpr<internal::get_fixed_value<NRowsType>::value,internal::get_fixed_value<NColsType>::value>::Type
+#else
+typename FixedBlockXpr<...,...>::Type
+#endif
+bottomRightCorner(NRowsType cRows, NColsType cCols)
{
- return BlockXpr(derived(), rows() - cRows, cols() - cCols, cRows, cCols);
+ return typename FixedBlockXpr<internal::get_fixed_value<NRowsType>::value,internal::get_fixed_value<NColsType>::value>::Type
+ (derived(), rows() - internal::get_runtime_value(cRows), cols() - internal::get_runtime_value(cCols),
+ internal::get_runtime_value(cRows), internal::get_runtime_value(cCols));
}
-/// This is the const version of bottomRightCorner(Index, Index).
-EIGEN_DEVICE_FUNC
-inline const ConstBlockXpr bottomRightCorner(Index cRows, Index cCols) const
+/// This is the const version of bottomRightCorner(NRowsType, NColsType).
+template<typename NRowsType, typename NColsType>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+#ifndef EIGEN_PARSED_BY_DOXYGEN
+const typename ConstFixedBlockXpr<internal::get_fixed_value<NRowsType>::value,internal::get_fixed_value<NColsType>::value>::Type
+#else
+const typename ConstFixedBlockXpr<...,...>::Type
+#endif
+bottomRightCorner(NRowsType cRows, NColsType cCols) const
{
- return ConstBlockXpr(derived(), rows() - cRows, cols() - cCols, cRows, cCols);
+ return typename ConstFixedBlockXpr<internal::get_fixed_value<NRowsType>::value,internal::get_fixed_value<NColsType>::value>::Type
+ (derived(), rows() - internal::get_runtime_value(cRows), cols() - internal::get_runtime_value(cCols),
+ internal::get_runtime_value(cRows), internal::get_runtime_value(cCols));
}
-/// \returns an expression of a fixed-size bottom-right corner of *this.
+/// \returns an expression of a fixed-size bottom-right corner of \c *this.
///
/// The template parameters CRows and CCols are the number of rows and columns in the corner.
///
@@ -281,24 +383,24 @@
///
EIGEN_DOC_BLOCK_ADDONS_NOT_INNER_PANEL
///
-/// \sa class Block, block(Index,Index,Index,Index)
+/// \sa block(Index,Index,NRowsType,NColsType), class Block
///
template<int CRows, int CCols>
-EIGEN_DEVICE_FUNC
-inline typename FixedBlockXpr<CRows,CCols>::Type bottomRightCorner()
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+typename FixedBlockXpr<CRows,CCols>::Type bottomRightCorner()
{
return typename FixedBlockXpr<CRows,CCols>::Type(derived(), rows() - CRows, cols() - CCols);
}
/// This is the const version of bottomRightCorner<int, int>().
template<int CRows, int CCols>
-EIGEN_DEVICE_FUNC
-inline const typename ConstFixedBlockXpr<CRows,CCols>::Type bottomRightCorner() const
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+const typename ConstFixedBlockXpr<CRows,CCols>::Type bottomRightCorner() const
{
return typename ConstFixedBlockXpr<CRows,CCols>::Type(derived(), rows() - CRows, cols() - CCols);
}
-/// \returns an expression of a bottom-right corner of *this.
+/// \returns an expression of a bottom-right corner of \c *this.
///
/// \tparam CRows number of rows in corner as specified at compile-time
/// \tparam CCols number of columns in corner as specified at compile-time
@@ -318,46 +420,69 @@
/// \sa class Block
///
template<int CRows, int CCols>
-inline typename FixedBlockXpr<CRows,CCols>::Type bottomRightCorner(Index cRows, Index cCols)
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+typename FixedBlockXpr<CRows,CCols>::Type bottomRightCorner(Index cRows, Index cCols)
{
return typename FixedBlockXpr<CRows,CCols>::Type(derived(), rows() - cRows, cols() - cCols, cRows, cCols);
}
/// This is the const version of bottomRightCorner<int, int>(Index, Index).
template<int CRows, int CCols>
-inline const typename ConstFixedBlockXpr<CRows,CCols>::Type bottomRightCorner(Index cRows, Index cCols) const
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+const typename ConstFixedBlockXpr<CRows,CCols>::Type bottomRightCorner(Index cRows, Index cCols) const
{
return typename ConstFixedBlockXpr<CRows,CCols>::Type(derived(), rows() - cRows, cols() - cCols, cRows, cCols);
}
-/// \returns a dynamic-size expression of a bottom-left corner of *this.
+/// \returns an expression of a bottom-left corner of \c *this with either dynamic or fixed sizes.
///
/// \param cRows the number of rows in the corner
/// \param cCols the number of columns in the corner
+/// \tparam NRowsType the type of the value handling the number of rows in the block, typically Index.
+/// \tparam NColsType the type of the value handling the number of columns in the block, typically Index.
///
/// Example: \include MatrixBase_bottomLeftCorner_int_int.cpp
/// Output: \verbinclude MatrixBase_bottomLeftCorner_int_int.out
///
+/// The number of rows \a cRows and columns \a cCols can also be specified at compile-time by passing Eigen::fix<N>,
+/// or Eigen::fix<N>(n) as arguments. See \link block(Index,Index,NRowsType,NColsType) block() \endlink for the details.
+///
EIGEN_DOC_BLOCK_ADDONS_NOT_INNER_PANEL
///
-/// \sa class Block, block(Index,Index,Index,Index)
+/// \sa block(Index,Index,NRowsType,NColsType), class Block
///
-EIGEN_DEVICE_FUNC
-inline BlockXpr bottomLeftCorner(Index cRows, Index cCols)
+template<typename NRowsType, typename NColsType>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+#ifndef EIGEN_PARSED_BY_DOXYGEN
+typename FixedBlockXpr<internal::get_fixed_value<NRowsType>::value,internal::get_fixed_value<NColsType>::value>::Type
+#else
+typename FixedBlockXpr<...,...>::Type
+#endif
+bottomLeftCorner(NRowsType cRows, NColsType cCols)
{
- return BlockXpr(derived(), rows() - cRows, 0, cRows, cCols);
+ return typename FixedBlockXpr<internal::get_fixed_value<NRowsType>::value,internal::get_fixed_value<NColsType>::value>::Type
+ (derived(), rows() - internal::get_runtime_value(cRows), 0,
+ internal::get_runtime_value(cRows), internal::get_runtime_value(cCols));
}
-/// This is the const version of bottomLeftCorner(Index, Index).
-EIGEN_DEVICE_FUNC
-inline const ConstBlockXpr bottomLeftCorner(Index cRows, Index cCols) const
+/// This is the const version of bottomLeftCorner(NRowsType, NColsType).
+template<typename NRowsType, typename NColsType>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+#ifndef EIGEN_PARSED_BY_DOXYGEN
+const typename ConstFixedBlockXpr<internal::get_fixed_value<NRowsType>::value,internal::get_fixed_value<NColsType>::value>::Type
+#else
+const typename ConstFixedBlockXpr<...,...>::Type
+#endif
+bottomLeftCorner(NRowsType cRows, NColsType cCols) const
{
- return ConstBlockXpr(derived(), rows() - cRows, 0, cRows, cCols);
+ return typename ConstFixedBlockXpr<internal::get_fixed_value<NRowsType>::value,internal::get_fixed_value<NColsType>::value>::Type
+ (derived(), rows() - internal::get_runtime_value(cRows), 0,
+ internal::get_runtime_value(cRows), internal::get_runtime_value(cCols));
}
-/// \returns an expression of a fixed-size bottom-left corner of *this.
+/// \returns an expression of a fixed-size bottom-left corner of \c *this.
///
/// The template parameters CRows and CCols are the number of rows and columns in the corner.
///
@@ -366,24 +491,24 @@
///
EIGEN_DOC_BLOCK_ADDONS_NOT_INNER_PANEL
///
-/// \sa class Block, block(Index,Index,Index,Index)
+/// \sa block(Index,Index,NRowsType,NColsType), class Block
///
template<int CRows, int CCols>
-EIGEN_DEVICE_FUNC
-inline typename FixedBlockXpr<CRows,CCols>::Type bottomLeftCorner()
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+typename FixedBlockXpr<CRows,CCols>::Type bottomLeftCorner()
{
return typename FixedBlockXpr<CRows,CCols>::Type(derived(), rows() - CRows, 0);
}
/// This is the const version of bottomLeftCorner<int, int>().
template<int CRows, int CCols>
-EIGEN_DEVICE_FUNC
-inline const typename ConstFixedBlockXpr<CRows,CCols>::Type bottomLeftCorner() const
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+const typename ConstFixedBlockXpr<CRows,CCols>::Type bottomLeftCorner() const
{
return typename ConstFixedBlockXpr<CRows,CCols>::Type(derived(), rows() - CRows, 0);
}
-/// \returns an expression of a bottom-left corner of *this.
+/// \returns an expression of a bottom-left corner of \c *this.
///
/// \tparam CRows number of rows in corner as specified at compile-time
/// \tparam CCols number of columns in corner as specified at compile-time
@@ -403,45 +528,66 @@
/// \sa class Block
///
template<int CRows, int CCols>
-inline typename FixedBlockXpr<CRows,CCols>::Type bottomLeftCorner(Index cRows, Index cCols)
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+typename FixedBlockXpr<CRows,CCols>::Type bottomLeftCorner(Index cRows, Index cCols)
{
return typename FixedBlockXpr<CRows,CCols>::Type(derived(), rows() - cRows, 0, cRows, cCols);
}
/// This is the const version of bottomLeftCorner<int, int>(Index, Index).
template<int CRows, int CCols>
-inline const typename ConstFixedBlockXpr<CRows,CCols>::Type bottomLeftCorner(Index cRows, Index cCols) const
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+const typename ConstFixedBlockXpr<CRows,CCols>::Type bottomLeftCorner(Index cRows, Index cCols) const
{
return typename ConstFixedBlockXpr<CRows,CCols>::Type(derived(), rows() - cRows, 0, cRows, cCols);
}
-/// \returns a block consisting of the top rows of *this.
+/// \returns a block consisting of the top rows of \c *this.
///
/// \param n the number of rows in the block
+/// \tparam NRowsType the type of the value handling the number of rows in the block, typically Index.
///
/// Example: \include MatrixBase_topRows_int.cpp
/// Output: \verbinclude MatrixBase_topRows_int.out
///
+/// The number of rows \a n can also be specified at compile-time by passing Eigen::fix<N>,
+/// or Eigen::fix<N>(n) as arguments.
+/// See \link block(Index,Index,NRowsType,NColsType) block() \endlink for the details.
+///
EIGEN_DOC_BLOCK_ADDONS_INNER_PANEL_IF(row-major)
///
-/// \sa class Block, block(Index,Index,Index,Index)
+/// \sa block(Index,Index,NRowsType,NColsType), class Block
///
-EIGEN_DEVICE_FUNC
-inline RowsBlockXpr topRows(Index n)
+template<typename NRowsType>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+#ifndef EIGEN_PARSED_BY_DOXYGEN
+typename NRowsBlockXpr<internal::get_fixed_value<NRowsType>::value>::Type
+#else
+typename NRowsBlockXpr<...>::Type
+#endif
+topRows(NRowsType n)
{
- return RowsBlockXpr(derived(), 0, 0, n, cols());
+ return typename NRowsBlockXpr<internal::get_fixed_value<NRowsType>::value>::Type
+ (derived(), 0, 0, internal::get_runtime_value(n), cols());
}
-/// This is the const version of topRows(Index).
-EIGEN_DEVICE_FUNC
-inline ConstRowsBlockXpr topRows(Index n) const
+/// This is the const version of topRows(NRowsType).
+template<typename NRowsType>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+#ifndef EIGEN_PARSED_BY_DOXYGEN
+const typename ConstNRowsBlockXpr<internal::get_fixed_value<NRowsType>::value>::Type
+#else
+const typename ConstNRowsBlockXpr<...>::Type
+#endif
+topRows(NRowsType n) const
{
- return ConstRowsBlockXpr(derived(), 0, 0, n, cols());
+ return typename ConstNRowsBlockXpr<internal::get_fixed_value<NRowsType>::value>::Type
+ (derived(), 0, 0, internal::get_runtime_value(n), cols());
}
-/// \returns a block consisting of the top rows of *this.
+/// \returns a block consisting of the top rows of \c *this.
///
/// \tparam N the number of rows in the block as specified at compile-time
/// \param n the number of rows in the block as specified at run-time
@@ -454,50 +600,69 @@
///
EIGEN_DOC_BLOCK_ADDONS_INNER_PANEL_IF(row-major)
///
-/// \sa class Block, block(Index,Index,Index,Index)
+/// \sa block(Index,Index,NRowsType,NColsType), class Block
///
template<int N>
-EIGEN_DEVICE_FUNC
-inline typename NRowsBlockXpr<N>::Type topRows(Index n = N)
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+typename NRowsBlockXpr<N>::Type topRows(Index n = N)
{
return typename NRowsBlockXpr<N>::Type(derived(), 0, 0, n, cols());
}
/// This is the const version of topRows<int>().
template<int N>
-EIGEN_DEVICE_FUNC
-inline typename ConstNRowsBlockXpr<N>::Type topRows(Index n = N) const
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+typename ConstNRowsBlockXpr<N>::Type topRows(Index n = N) const
{
return typename ConstNRowsBlockXpr<N>::Type(derived(), 0, 0, n, cols());
}
-/// \returns a block consisting of the bottom rows of *this.
+/// \returns a block consisting of the bottom rows of \c *this.
///
/// \param n the number of rows in the block
+/// \tparam NRowsType the type of the value handling the number of rows in the block, typically Index.
///
/// Example: \include MatrixBase_bottomRows_int.cpp
/// Output: \verbinclude MatrixBase_bottomRows_int.out
///
+/// The number of rows \a n can also be specified at compile-time by passing Eigen::fix<N>,
+/// or Eigen::fix<N>(n) as arguments.
+/// See \link block(Index,Index,NRowsType,NColsType) block() \endlink for the details.
+///
EIGEN_DOC_BLOCK_ADDONS_INNER_PANEL_IF(row-major)
///
-/// \sa class Block, block(Index,Index,Index,Index)
+/// \sa block(Index,Index,NRowsType,NColsType), class Block
///
-EIGEN_DEVICE_FUNC
-inline RowsBlockXpr bottomRows(Index n)
+template<typename NRowsType>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+#ifndef EIGEN_PARSED_BY_DOXYGEN
+typename NRowsBlockXpr<internal::get_fixed_value<NRowsType>::value>::Type
+#else
+typename NRowsBlockXpr<...>::Type
+#endif
+bottomRows(NRowsType n)
{
- return RowsBlockXpr(derived(), rows() - n, 0, n, cols());
+ return typename NRowsBlockXpr<internal::get_fixed_value<NRowsType>::value>::Type
+ (derived(), rows() - internal::get_runtime_value(n), 0, internal::get_runtime_value(n), cols());
}
-/// This is the const version of bottomRows(Index).
-EIGEN_DEVICE_FUNC
-inline ConstRowsBlockXpr bottomRows(Index n) const
+/// This is the const version of bottomRows(NRowsType).
+template<typename NRowsType>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+#ifndef EIGEN_PARSED_BY_DOXYGEN
+const typename ConstNRowsBlockXpr<internal::get_fixed_value<NRowsType>::value>::Type
+#else
+const typename ConstNRowsBlockXpr<...>::Type
+#endif
+bottomRows(NRowsType n) const
{
- return ConstRowsBlockXpr(derived(), rows() - n, 0, n, cols());
+ return typename ConstNRowsBlockXpr<internal::get_fixed_value<NRowsType>::value>::Type
+ (derived(), rows() - internal::get_runtime_value(n), 0, internal::get_runtime_value(n), cols());
}
-/// \returns a block consisting of the bottom rows of *this.
+/// \returns a block consisting of the bottom rows of \c *this.
///
/// \tparam N the number of rows in the block as specified at compile-time
/// \param n the number of rows in the block as specified at run-time
@@ -510,51 +675,70 @@
///
EIGEN_DOC_BLOCK_ADDONS_INNER_PANEL_IF(row-major)
///
-/// \sa class Block, block(Index,Index,Index,Index)
+/// \sa block(Index,Index,NRowsType,NColsType), class Block
///
template<int N>
-EIGEN_DEVICE_FUNC
-inline typename NRowsBlockXpr<N>::Type bottomRows(Index n = N)
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+typename NRowsBlockXpr<N>::Type bottomRows(Index n = N)
{
return typename NRowsBlockXpr<N>::Type(derived(), rows() - n, 0, n, cols());
}
/// This is the const version of bottomRows<int>().
template<int N>
-EIGEN_DEVICE_FUNC
-inline typename ConstNRowsBlockXpr<N>::Type bottomRows(Index n = N) const
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+typename ConstNRowsBlockXpr<N>::Type bottomRows(Index n = N) const
{
return typename ConstNRowsBlockXpr<N>::Type(derived(), rows() - n, 0, n, cols());
}
-/// \returns a block consisting of a range of rows of *this.
+/// \returns a block consisting of a range of rows of \c *this.
///
/// \param startRow the index of the first row in the block
/// \param n the number of rows in the block
+/// \tparam NRowsType the type of the value handling the number of rows in the block, typically Index.
///
/// Example: \include DenseBase_middleRows_int.cpp
/// Output: \verbinclude DenseBase_middleRows_int.out
///
+/// The number of rows \a n can also be specified at compile-time by passing Eigen::fix<N>,
+/// or Eigen::fix<N>(n) as arguments.
+/// See \link block(Index,Index,NRowsType,NColsType) block() \endlink for the details.
+///
EIGEN_DOC_BLOCK_ADDONS_INNER_PANEL_IF(row-major)
///
-/// \sa class Block, block(Index,Index,Index,Index)
+/// \sa block(Index,Index,NRowsType,NColsType), class Block
///
-EIGEN_DEVICE_FUNC
-inline RowsBlockXpr middleRows(Index startRow, Index n)
+template<typename NRowsType>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+#ifndef EIGEN_PARSED_BY_DOXYGEN
+typename NRowsBlockXpr<internal::get_fixed_value<NRowsType>::value>::Type
+#else
+typename NRowsBlockXpr<...>::Type
+#endif
+middleRows(Index startRow, NRowsType n)
{
- return RowsBlockXpr(derived(), startRow, 0, n, cols());
+ return typename NRowsBlockXpr<internal::get_fixed_value<NRowsType>::value>::Type
+ (derived(), startRow, 0, internal::get_runtime_value(n), cols());
}
-/// This is the const version of middleRows(Index,Index).
-EIGEN_DEVICE_FUNC
-inline ConstRowsBlockXpr middleRows(Index startRow, Index n) const
+/// This is the const version of middleRows(Index,NRowsType).
+template<typename NRowsType>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+#ifndef EIGEN_PARSED_BY_DOXYGEN
+const typename ConstNRowsBlockXpr<internal::get_fixed_value<NRowsType>::value>::Type
+#else
+const typename ConstNRowsBlockXpr<...>::Type
+#endif
+middleRows(Index startRow, NRowsType n) const
{
- return ConstRowsBlockXpr(derived(), startRow, 0, n, cols());
+ return typename ConstNRowsBlockXpr<internal::get_fixed_value<NRowsType>::value>::Type
+ (derived(), startRow, 0, internal::get_runtime_value(n), cols());
}
-/// \returns a block consisting of a range of rows of *this.
+/// \returns a block consisting of a range of rows of \c *this.
///
/// \tparam N the number of rows in the block as specified at compile-time
/// \param startRow the index of the first row in the block
@@ -568,50 +752,69 @@
///
EIGEN_DOC_BLOCK_ADDONS_INNER_PANEL_IF(row-major)
///
-/// \sa class Block, block(Index,Index,Index,Index)
+/// \sa block(Index,Index,NRowsType,NColsType), class Block
///
template<int N>
-EIGEN_DEVICE_FUNC
-inline typename NRowsBlockXpr<N>::Type middleRows(Index startRow, Index n = N)
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+typename NRowsBlockXpr<N>::Type middleRows(Index startRow, Index n = N)
{
return typename NRowsBlockXpr<N>::Type(derived(), startRow, 0, n, cols());
}
/// This is the const version of middleRows<int>().
template<int N>
-EIGEN_DEVICE_FUNC
-inline typename ConstNRowsBlockXpr<N>::Type middleRows(Index startRow, Index n = N) const
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+typename ConstNRowsBlockXpr<N>::Type middleRows(Index startRow, Index n = N) const
{
return typename ConstNRowsBlockXpr<N>::Type(derived(), startRow, 0, n, cols());
}
-/// \returns a block consisting of the left columns of *this.
+/// \returns a block consisting of the left columns of \c *this.
///
/// \param n the number of columns in the block
+/// \tparam NColsType the type of the value handling the number of columns in the block, typically Index.
///
/// Example: \include MatrixBase_leftCols_int.cpp
/// Output: \verbinclude MatrixBase_leftCols_int.out
///
+/// The number of columns \a n can also be specified at compile-time by passing Eigen::fix<N>,
+/// or Eigen::fix<N>(n) as arguments.
+/// See \link block(Index,Index,NRowsType,NColsType) block() \endlink for the details.
+///
EIGEN_DOC_BLOCK_ADDONS_INNER_PANEL_IF(column-major)
///
-/// \sa class Block, block(Index,Index,Index,Index)
+/// \sa block(Index,Index,NRowsType,NColsType), class Block
///
-EIGEN_DEVICE_FUNC
-inline ColsBlockXpr leftCols(Index n)
+template<typename NColsType>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+#ifndef EIGEN_PARSED_BY_DOXYGEN
+typename NColsBlockXpr<internal::get_fixed_value<NColsType>::value>::Type
+#else
+typename NColsBlockXpr<...>::Type
+#endif
+leftCols(NColsType n)
{
- return ColsBlockXpr(derived(), 0, 0, rows(), n);
+ return typename NColsBlockXpr<internal::get_fixed_value<NColsType>::value>::Type
+ (derived(), 0, 0, rows(), internal::get_runtime_value(n));
}
-/// This is the const version of leftCols(Index).
-EIGEN_DEVICE_FUNC
-inline ConstColsBlockXpr leftCols(Index n) const
+/// This is the const version of leftCols(NColsType).
+template<typename NColsType>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+#ifndef EIGEN_PARSED_BY_DOXYGEN
+const typename ConstNColsBlockXpr<internal::get_fixed_value<NColsType>::value>::Type
+#else
+const typename ConstNColsBlockXpr<...>::Type
+#endif
+leftCols(NColsType n) const
{
- return ConstColsBlockXpr(derived(), 0, 0, rows(), n);
+ return typename ConstNColsBlockXpr<internal::get_fixed_value<NColsType>::value>::Type
+ (derived(), 0, 0, rows(), internal::get_runtime_value(n));
}
-/// \returns a block consisting of the left columns of *this.
+/// \returns a block consisting of the left columns of \c *this.
///
/// \tparam N the number of columns in the block as specified at compile-time
/// \param n the number of columns in the block as specified at run-time
@@ -624,50 +827,69 @@
///
EIGEN_DOC_BLOCK_ADDONS_INNER_PANEL_IF(column-major)
///
-/// \sa class Block, block(Index,Index,Index,Index)
+/// \sa block(Index,Index,NRowsType,NColsType), class Block
///
template<int N>
-EIGEN_DEVICE_FUNC
-inline typename NColsBlockXpr<N>::Type leftCols(Index n = N)
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+typename NColsBlockXpr<N>::Type leftCols(Index n = N)
{
return typename NColsBlockXpr<N>::Type(derived(), 0, 0, rows(), n);
}
/// This is the const version of leftCols<int>().
template<int N>
-EIGEN_DEVICE_FUNC
-inline typename ConstNColsBlockXpr<N>::Type leftCols(Index n = N) const
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+typename ConstNColsBlockXpr<N>::Type leftCols(Index n = N) const
{
return typename ConstNColsBlockXpr<N>::Type(derived(), 0, 0, rows(), n);
}
-/// \returns a block consisting of the right columns of *this.
+/// \returns a block consisting of the right columns of \c *this.
///
/// \param n the number of columns in the block
+/// \tparam NColsType the type of the value handling the number of columns in the block, typically Index.
///
/// Example: \include MatrixBase_rightCols_int.cpp
/// Output: \verbinclude MatrixBase_rightCols_int.out
///
+/// The number of columns \a n can also be specified at compile-time by passing Eigen::fix<N>,
+/// or Eigen::fix<N>(n) as arguments.
+/// See \link block(Index,Index,NRowsType,NColsType) block() \endlink for the details.
+///
EIGEN_DOC_BLOCK_ADDONS_INNER_PANEL_IF(column-major)
///
-/// \sa class Block, block(Index,Index,Index,Index)
+/// \sa block(Index,Index,NRowsType,NColsType), class Block
///
-EIGEN_DEVICE_FUNC
-inline ColsBlockXpr rightCols(Index n)
+template<typename NColsType>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+#ifndef EIGEN_PARSED_BY_DOXYGEN
+typename NColsBlockXpr<internal::get_fixed_value<NColsType>::value>::Type
+#else
+typename NColsBlockXpr<...>::Type
+#endif
+rightCols(NColsType n)
{
- return ColsBlockXpr(derived(), 0, cols() - n, rows(), n);
+ return typename NColsBlockXpr<internal::get_fixed_value<NColsType>::value>::Type
+ (derived(), 0, cols() - internal::get_runtime_value(n), rows(), internal::get_runtime_value(n));
}
-/// This is the const version of rightCols(Index).
-EIGEN_DEVICE_FUNC
-inline ConstColsBlockXpr rightCols(Index n) const
+/// This is the const version of rightCols(NColsType).
+template<typename NColsType>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+#ifndef EIGEN_PARSED_BY_DOXYGEN
+const typename ConstNColsBlockXpr<internal::get_fixed_value<NColsType>::value>::Type
+#else
+const typename ConstNColsBlockXpr<...>::Type
+#endif
+rightCols(NColsType n) const
{
- return ConstColsBlockXpr(derived(), 0, cols() - n, rows(), n);
+ return typename ConstNColsBlockXpr<internal::get_fixed_value<NColsType>::value>::Type
+ (derived(), 0, cols() - internal::get_runtime_value(n), rows(), internal::get_runtime_value(n));
}
-/// \returns a block consisting of the right columns of *this.
+/// \returns a block consisting of the right columns of \c *this.
///
/// \tparam N the number of columns in the block as specified at compile-time
/// \param n the number of columns in the block as specified at run-time
@@ -680,51 +902,70 @@
///
EIGEN_DOC_BLOCK_ADDONS_INNER_PANEL_IF(column-major)
///
-/// \sa class Block, block(Index,Index,Index,Index)
+/// \sa block(Index,Index,NRowsType,NColsType), class Block
///
template<int N>
-EIGEN_DEVICE_FUNC
-inline typename NColsBlockXpr<N>::Type rightCols(Index n = N)
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+typename NColsBlockXpr<N>::Type rightCols(Index n = N)
{
return typename NColsBlockXpr<N>::Type(derived(), 0, cols() - n, rows(), n);
}
/// This is the const version of rightCols<int>().
template<int N>
-EIGEN_DEVICE_FUNC
-inline typename ConstNColsBlockXpr<N>::Type rightCols(Index n = N) const
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+typename ConstNColsBlockXpr<N>::Type rightCols(Index n = N) const
{
return typename ConstNColsBlockXpr<N>::Type(derived(), 0, cols() - n, rows(), n);
}
-/// \returns a block consisting of a range of columns of *this.
+/// \returns a block consisting of a range of columns of \c *this.
///
/// \param startCol the index of the first column in the block
/// \param numCols the number of columns in the block
+/// \tparam NColsType the type of the value handling the number of columns in the block, typically Index.
///
/// Example: \include DenseBase_middleCols_int.cpp
/// Output: \verbinclude DenseBase_middleCols_int.out
///
+/// The number of columns \a numCols can also be specified at compile-time by passing Eigen::fix<N>,
+/// or Eigen::fix<N>(n) as arguments.
+/// See \link block(Index,Index,NRowsType,NColsType) block() \endlink for the details.
+///
EIGEN_DOC_BLOCK_ADDONS_INNER_PANEL_IF(column-major)
///
-/// \sa class Block, block(Index,Index,Index,Index)
+/// \sa block(Index,Index,NRowsType,NColsType), class Block
///
-EIGEN_DEVICE_FUNC
-inline ColsBlockXpr middleCols(Index startCol, Index numCols)
+template<typename NColsType>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+#ifndef EIGEN_PARSED_BY_DOXYGEN
+typename NColsBlockXpr<internal::get_fixed_value<NColsType>::value>::Type
+#else
+typename NColsBlockXpr<...>::Type
+#endif
+middleCols(Index startCol, NColsType numCols)
{
- return ColsBlockXpr(derived(), 0, startCol, rows(), numCols);
+ return typename NColsBlockXpr<internal::get_fixed_value<NColsType>::value>::Type
+ (derived(), 0, startCol, rows(), internal::get_runtime_value(numCols));
}
-/// This is the const version of middleCols(Index,Index).
-EIGEN_DEVICE_FUNC
-inline ConstColsBlockXpr middleCols(Index startCol, Index numCols) const
+/// This is the const version of middleCols(Index,NColsType).
+template<typename NColsType>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+#ifndef EIGEN_PARSED_BY_DOXYGEN
+const typename ConstNColsBlockXpr<internal::get_fixed_value<NColsType>::value>::Type
+#else
+const typename ConstNColsBlockXpr<...>::Type
+#endif
+middleCols(Index startCol, NColsType numCols) const
{
- return ConstColsBlockXpr(derived(), 0, startCol, rows(), numCols);
+ return typename ConstNColsBlockXpr<internal::get_fixed_value<NColsType>::value>::Type
+ (derived(), 0, startCol, rows(), internal::get_runtime_value(numCols));
}
-/// \returns a block consisting of a range of columns of *this.
+/// \returns a block consisting of a range of columns of \c *this.
///
/// \tparam N the number of columns in the block as specified at compile-time
/// \param startCol the index of the first column in the block
@@ -738,26 +979,26 @@
///
EIGEN_DOC_BLOCK_ADDONS_INNER_PANEL_IF(column-major)
///
-/// \sa class Block, block(Index,Index,Index,Index)
+/// \sa block(Index,Index,NRowsType,NColsType), class Block
///
template<int N>
-EIGEN_DEVICE_FUNC
-inline typename NColsBlockXpr<N>::Type middleCols(Index startCol, Index n = N)
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+typename NColsBlockXpr<N>::Type middleCols(Index startCol, Index n = N)
{
return typename NColsBlockXpr<N>::Type(derived(), 0, startCol, rows(), n);
}
/// This is the const version of middleCols<int>().
template<int N>
-EIGEN_DEVICE_FUNC
-inline typename ConstNColsBlockXpr<N>::Type middleCols(Index startCol, Index n = N) const
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+typename ConstNColsBlockXpr<N>::Type middleCols(Index startCol, Index n = N) const
{
return typename ConstNColsBlockXpr<N>::Type(derived(), 0, startCol, rows(), n);
}
-/// \returns a fixed-size expression of a block in *this.
+/// \returns a fixed-size expression of a block of \c *this.
///
/// The template parameters \a NRows and \a NCols are the number of
/// rows and columns in the block.
@@ -768,29 +1009,35 @@
/// Example: \include MatrixBase_block_int_int.cpp
/// Output: \verbinclude MatrixBase_block_int_int.out
///
+/// \note The usage of this overload is discouraged from %Eigen 3.4; better use the generic
+/// block(Index,Index,NRowsType,NColsType). Here is the one-to-one equivalence:
+/// \code
+/// mat.template block<NRows,NCols>(i,j) <--> mat.block(i,j,fix<NRows>,fix<NCols>)
+/// \endcode
+///
/// \note since block is a templated member, the keyword template has to be used
/// if the matrix type is also a template parameter: \code m.template block<3,3>(1,1); \endcode
///
EIGEN_DOC_BLOCK_ADDONS_NOT_INNER_PANEL
///
-/// \sa class Block, block(Index,Index,Index,Index)
+/// \sa block(Index,Index,NRowsType,NColsType), class Block
///
template<int NRows, int NCols>
-EIGEN_DEVICE_FUNC
-inline typename FixedBlockXpr<NRows,NCols>::Type block(Index startRow, Index startCol)
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+typename FixedBlockXpr<NRows,NCols>::Type block(Index startRow, Index startCol)
{
return typename FixedBlockXpr<NRows,NCols>::Type(derived(), startRow, startCol);
}
/// This is the const version of block<>(Index, Index). */
template<int NRows, int NCols>
-EIGEN_DEVICE_FUNC
-inline const typename ConstFixedBlockXpr<NRows,NCols>::Type block(Index startRow, Index startCol) const
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+const typename ConstFixedBlockXpr<NRows,NCols>::Type block(Index startRow, Index startCol) const
{
return typename ConstFixedBlockXpr<NRows,NCols>::Type(derived(), startRow, startCol);
}
-/// \returns an expression of a block in *this.
+/// \returns an expression of a block of \c *this.
///
/// \tparam NRows number of rows in block as specified at compile-time
/// \tparam NCols number of columns in block as specified at compile-time
@@ -805,14 +1052,25 @@
/// \a NRows is \a Dynamic, and the same for the number of columns.
///
/// Example: \include MatrixBase_template_int_int_block_int_int_int_int.cpp
-/// Output: \verbinclude MatrixBase_template_int_int_block_int_int_int_int.cpp
+/// Output: \verbinclude MatrixBase_template_int_int_block_int_int_int_int.out
+///
+/// \note The usage of this overload is discouraged from %Eigen 3.4; better use the generic
+/// block(Index,Index,NRowsType,NColsType). Here is the one-to-one complete equivalence:
+/// \code
+/// mat.template block<NRows,NCols>(i,j,rows,cols) <--> mat.block(i,j,fix<NRows>(rows),fix<NCols>(cols))
+/// \endcode
+/// If we know that, e.g., NRows==Dynamic and NCols!=Dynamic, then the equivalence becomes:
+/// \code
+/// mat.template block<Dynamic,NCols>(i,j,rows,NCols) <--> mat.block(i,j,rows,fix<NCols>)
+/// \endcode
///
EIGEN_DOC_BLOCK_ADDONS_NOT_INNER_PANEL
///
-/// \sa class Block, block(Index,Index,Index,Index)
+/// \sa block(Index,Index,NRowsType,NColsType), class Block
///
template<int NRows, int NCols>
-inline typename FixedBlockXpr<NRows,NCols>::Type block(Index startRow, Index startCol,
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+typename FixedBlockXpr<NRows,NCols>::Type block(Index startRow, Index startCol,
Index blockRows, Index blockCols)
{
return typename FixedBlockXpr<NRows,NCols>::Type(derived(), startRow, startCol, blockRows, blockCols);
@@ -820,13 +1078,14 @@
/// This is the const version of block<>(Index, Index, Index, Index).
template<int NRows, int NCols>
-inline const typename ConstFixedBlockXpr<NRows,NCols>::Type block(Index startRow, Index startCol,
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+const typename ConstFixedBlockXpr<NRows,NCols>::Type block(Index startRow, Index startCol,
Index blockRows, Index blockCols) const
{
return typename ConstFixedBlockXpr<NRows,NCols>::Type(derived(), startRow, startCol, blockRows, blockCols);
}
-/// \returns an expression of the \a i-th column of *this. Note that the numbering starts at 0.
+/// \returns an expression of the \a i-th column of \c *this. Note that the numbering starts at 0.
///
/// Example: \include MatrixBase_col.cpp
/// Output: \verbinclude MatrixBase_col.out
@@ -834,20 +1093,20 @@
EIGEN_DOC_BLOCK_ADDONS_INNER_PANEL_IF(column-major)
/**
* \sa row(), class Block */
-EIGEN_DEVICE_FUNC
-inline ColXpr col(Index i)
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ColXpr col(Index i)
{
return ColXpr(derived(), i);
}
/// This is the const version of col().
-EIGEN_DEVICE_FUNC
-inline ConstColXpr col(Index i) const
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ConstColXpr col(Index i) const
{
return ConstColXpr(derived(), i);
}
-/// \returns an expression of the \a i-th row of *this. Note that the numbering starts at 0.
+/// \returns an expression of the \a i-th row of \c *this. Note that the numbering starts at 0.
///
/// Example: \include MatrixBase_row.cpp
/// Output: \verbinclude MatrixBase_row.out
@@ -855,109 +1114,166 @@
EIGEN_DOC_BLOCK_ADDONS_INNER_PANEL_IF(row-major)
/**
* \sa col(), class Block */
-EIGEN_DEVICE_FUNC
-inline RowXpr row(Index i)
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+RowXpr row(Index i)
{
return RowXpr(derived(), i);
}
/// This is the const version of row(). */
-EIGEN_DEVICE_FUNC
-inline ConstRowXpr row(Index i) const
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ConstRowXpr row(Index i) const
{
return ConstRowXpr(derived(), i);
}
-/// \returns a dynamic-size expression of a segment (i.e. a vector block) in *this.
+/// \returns an expression of a segment (i.e. a vector block) in \c *this with either dynamic or fixed sizes.
///
/// \only_for_vectors
///
/// \param start the first coefficient in the segment
/// \param n the number of coefficients in the segment
+/// \tparam NType the type of the value handling the number of coefficients in the segment, typically Index.
///
/// Example: \include MatrixBase_segment_int_int.cpp
/// Output: \verbinclude MatrixBase_segment_int_int.out
///
-/// \note Even though the returned expression has dynamic size, in the case
+/// The number of coefficients \a n can also be specified at compile-time by passing Eigen::fix<N>,
+/// or Eigen::fix<N>(n) as arguments.
+/// See \link block(Index,Index,NRowsType,NColsType) block() \endlink for the details.
+///
+/// \note Even if the returned expression has dynamic size, in the case
/// when it is applied to a fixed-size vector, it inherits a fixed maximal size,
/// which means that evaluating it does not cause a dynamic memory allocation.
///
-/// \sa class Block, segment(Index)
+/// \sa block(Index,Index,NRowsType,NColsType), fix<N>, fix<N>(int), class Block
///
-EIGEN_DEVICE_FUNC
-inline SegmentReturnType segment(Index start, Index n)
+template<typename NType>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+#ifndef EIGEN_PARSED_BY_DOXYGEN
+typename FixedSegmentReturnType<internal::get_fixed_value<NType>::value>::Type
+#else
+typename FixedSegmentReturnType<...>::Type
+#endif
+segment(Index start, NType n)
{
EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
- return SegmentReturnType(derived(), start, n);
+ return typename FixedSegmentReturnType<internal::get_fixed_value<NType>::value>::Type
+ (derived(), start, internal::get_runtime_value(n));
}
-/// This is the const version of segment(Index,Index).
-EIGEN_DEVICE_FUNC
-inline ConstSegmentReturnType segment(Index start, Index n) const
+/// This is the const version of segment(Index,NType).
+template<typename NType>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+#ifndef EIGEN_PARSED_BY_DOXYGEN
+const typename ConstFixedSegmentReturnType<internal::get_fixed_value<NType>::value>::Type
+#else
+const typename ConstFixedSegmentReturnType<...>::Type
+#endif
+segment(Index start, NType n) const
{
EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
- return ConstSegmentReturnType(derived(), start, n);
+ return typename ConstFixedSegmentReturnType<internal::get_fixed_value<NType>::value>::Type
+ (derived(), start, internal::get_runtime_value(n));
}
-/// \returns a dynamic-size expression of the first coefficients of *this.
+/// \returns an expression of the first coefficients of \c *this with either dynamic or fixed sizes.
///
/// \only_for_vectors
///
/// \param n the number of coefficients in the segment
+/// \tparam NType the type of the value handling the number of coefficients in the segment, typically Index.
///
/// Example: \include MatrixBase_start_int.cpp
/// Output: \verbinclude MatrixBase_start_int.out
///
-/// \note Even though the returned expression has dynamic size, in the case
+/// The number of coefficients \a n can also be specified at compile-time by passing Eigen::fix<N>,
+/// or Eigen::fix<N>(n) as arguments.
+/// See \link block(Index,Index,NRowsType,NColsType) block() \endlink for the details.
+///
+/// \note Even if the returned expression has dynamic size, in the case
/// when it is applied to a fixed-size vector, it inherits a fixed maximal size,
/// which means that evaluating it does not cause a dynamic memory allocation.
///
/// \sa class Block, block(Index,Index)
///
-EIGEN_DEVICE_FUNC
-inline SegmentReturnType head(Index n)
+template<typename NType>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+#ifndef EIGEN_PARSED_BY_DOXYGEN
+typename FixedSegmentReturnType<internal::get_fixed_value<NType>::value>::Type
+#else
+typename FixedSegmentReturnType<...>::Type
+#endif
+head(NType n)
{
EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
- return SegmentReturnType(derived(), 0, n);
+ return typename FixedSegmentReturnType<internal::get_fixed_value<NType>::value>::Type
+ (derived(), 0, internal::get_runtime_value(n));
}
-/// This is the const version of head(Index).
-EIGEN_DEVICE_FUNC
-inline ConstSegmentReturnType head(Index n) const
+/// This is the const version of head(NType).
+template<typename NType>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+#ifndef EIGEN_PARSED_BY_DOXYGEN
+const typename ConstFixedSegmentReturnType<internal::get_fixed_value<NType>::value>::Type
+#else
+const typename ConstFixedSegmentReturnType<...>::Type
+#endif
+head(NType n) const
{
EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
- return ConstSegmentReturnType(derived(), 0, n);
+ return typename ConstFixedSegmentReturnType<internal::get_fixed_value<NType>::value>::Type
+ (derived(), 0, internal::get_runtime_value(n));
}
-/// \returns a dynamic-size expression of the last coefficients of *this.
+/// \returns an expression of the last coefficients of \c *this with either dynamic or fixed sizes.
///
/// \only_for_vectors
///
/// \param n the number of coefficients in the segment
+/// \tparam NType the type of the value handling the number of coefficients in the segment, typically Index.
///
/// Example: \include MatrixBase_end_int.cpp
/// Output: \verbinclude MatrixBase_end_int.out
///
-/// \note Even though the returned expression has dynamic size, in the case
+/// The number of coefficients \a n can also be specified at compile-time by passing Eigen::fix<N>,
+/// or Eigen::fix<N>(n) as arguments.
+/// See \link block(Index,Index,NRowsType,NColsType) block() \endlink for the details.
+///
+/// \note Even if the returned expression has dynamic size, in the case
/// when it is applied to a fixed-size vector, it inherits a fixed maximal size,
/// which means that evaluating it does not cause a dynamic memory allocation.
///
/// \sa class Block, block(Index,Index)
///
-EIGEN_DEVICE_FUNC
-inline SegmentReturnType tail(Index n)
+template<typename NType>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+#ifndef EIGEN_PARSED_BY_DOXYGEN
+typename FixedSegmentReturnType<internal::get_fixed_value<NType>::value>::Type
+#else
+typename FixedSegmentReturnType<...>::Type
+#endif
+tail(NType n)
{
EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
- return SegmentReturnType(derived(), this->size() - n, n);
+ return typename FixedSegmentReturnType<internal::get_fixed_value<NType>::value>::Type
+ (derived(), this->size() - internal::get_runtime_value(n), internal::get_runtime_value(n));
}
/// This is the const version of tail(Index).
-EIGEN_DEVICE_FUNC
-inline ConstSegmentReturnType tail(Index n) const
+template<typename NType>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+#ifndef EIGEN_PARSED_BY_DOXYGEN
+const typename ConstFixedSegmentReturnType<internal::get_fixed_value<NType>::value>::Type
+#else
+const typename ConstFixedSegmentReturnType<...>::Type
+#endif
+tail(NType n) const
{
EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
- return ConstSegmentReturnType(derived(), this->size() - n, n);
+ return typename ConstFixedSegmentReturnType<internal::get_fixed_value<NType>::value>::Type
+ (derived(), this->size() - internal::get_runtime_value(n), internal::get_runtime_value(n));
}
/// \returns a fixed-size expression of a segment (i.e. a vector block) in \c *this
@@ -974,11 +1290,11 @@
/// Example: \include MatrixBase_template_int_segment.cpp
/// Output: \verbinclude MatrixBase_template_int_segment.out
///
-/// \sa class Block
+/// \sa segment(Index,NType), class Block
///
template<int N>
-EIGEN_DEVICE_FUNC
-inline typename FixedSegmentReturnType<N>::Type segment(Index start, Index n = N)
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+typename FixedSegmentReturnType<N>::Type segment(Index start, Index n = N)
{
EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
return typename FixedSegmentReturnType<N>::Type(derived(), start, n);
@@ -986,14 +1302,14 @@
/// This is the const version of segment<int>(Index).
template<int N>
-EIGEN_DEVICE_FUNC
-inline typename ConstFixedSegmentReturnType<N>::Type segment(Index start, Index n = N) const
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+typename ConstFixedSegmentReturnType<N>::Type segment(Index start, Index n = N) const
{
EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
return typename ConstFixedSegmentReturnType<N>::Type(derived(), start, n);
}
-/// \returns a fixed-size expression of the first coefficients of *this.
+/// \returns a fixed-size expression of the first coefficients of \c *this.
///
/// \only_for_vectors
///
@@ -1006,11 +1322,11 @@
/// Example: \include MatrixBase_template_int_start.cpp
/// Output: \verbinclude MatrixBase_template_int_start.out
///
-/// \sa class Block
+/// \sa head(NType), class Block
///
template<int N>
-EIGEN_DEVICE_FUNC
-inline typename FixedSegmentReturnType<N>::Type head(Index n = N)
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+typename FixedSegmentReturnType<N>::Type head(Index n = N)
{
EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
return typename FixedSegmentReturnType<N>::Type(derived(), 0, n);
@@ -1018,14 +1334,14 @@
/// This is the const version of head<int>().
template<int N>
-EIGEN_DEVICE_FUNC
-inline typename ConstFixedSegmentReturnType<N>::Type head(Index n = N) const
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+typename ConstFixedSegmentReturnType<N>::Type head(Index n = N) const
{
EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
return typename ConstFixedSegmentReturnType<N>::Type(derived(), 0, n);
}
-/// \returns a fixed-size expression of the last coefficients of *this.
+/// \returns a fixed-size expression of the last coefficients of \c *this.
///
/// \only_for_vectors
///
@@ -1038,11 +1354,11 @@
/// Example: \include MatrixBase_template_int_end.cpp
/// Output: \verbinclude MatrixBase_template_int_end.out
///
-/// \sa class Block
+/// \sa tail(NType), class Block
///
template<int N>
-EIGEN_DEVICE_FUNC
-inline typename FixedSegmentReturnType<N>::Type tail(Index n = N)
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+typename FixedSegmentReturnType<N>::Type tail(Index n = N)
{
EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
return typename FixedSegmentReturnType<N>::Type(derived(), size() - n);
@@ -1050,9 +1366,77 @@
/// This is the const version of tail<int>.
template<int N>
-EIGEN_DEVICE_FUNC
-inline typename ConstFixedSegmentReturnType<N>::Type tail(Index n = N) const
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+typename ConstFixedSegmentReturnType<N>::Type tail(Index n = N) const
{
EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
return typename ConstFixedSegmentReturnType<N>::Type(derived(), size() - n);
}
+
+/// \returns the \a outer -th column (resp. row) of the matrix \c *this if \c *this
+/// is col-major (resp. row-major).
+///
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+InnerVectorReturnType innerVector(Index outer)
+{ return InnerVectorReturnType(derived(), outer); }
+
+/// \returns the \a outer -th column (resp. row) of the matrix \c *this if \c *this
+/// is col-major (resp. row-major). Read-only.
+///
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+const ConstInnerVectorReturnType innerVector(Index outer) const
+{ return ConstInnerVectorReturnType(derived(), outer); }
+
+/// \returns the \a outerSize consecutive inner vectors starting at \a outerStart, i.e. a block of the
+/// columns (resp. rows) of the matrix \c *this if \c *this is col-major (resp. row-major).
+///
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+InnerVectorsReturnType
+innerVectors(Index outerStart, Index outerSize)
+{
+ return Block<Derived,Dynamic,Dynamic,true>(derived(),
+ IsRowMajor ? outerStart : 0, IsRowMajor ? 0 : outerStart,
+ IsRowMajor ? outerSize : rows(), IsRowMajor ? cols() : outerSize);
+
+}
+
+/// \returns the \a outerSize consecutive inner vectors starting at \a outerStart, i.e. a block of the columns
+/// (resp. rows) of the matrix \c *this if \c *this is col-major (resp. row-major). Read-only.
+///
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+const ConstInnerVectorsReturnType
+innerVectors(Index outerStart, Index outerSize) const
+{
+ return Block<const Derived,Dynamic,Dynamic,true>(derived(),
+ IsRowMajor ? outerStart : 0, IsRowMajor ? 0 : outerStart,
+ IsRowMajor ? outerSize : rows(), IsRowMajor ? cols() : outerSize);
+
+}
+
+/** \returns the i-th subvector (column or row) according to the \c Direction (Vertical selects a column, otherwise a row)
+ * \sa subVectors()
+ */
+template<DirectionType Direction>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+typename internal::conditional<Direction==Vertical,ColXpr,RowXpr>::type
+subVector(Index i)
+{
+ return typename internal::conditional<Direction==Vertical,ColXpr,RowXpr>::type(derived(),i);
+}
+
+/** This is the const version of subVector(Index) */
+template<DirectionType Direction>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+typename internal::conditional<Direction==Vertical,ConstColXpr,ConstRowXpr>::type
+subVector(Index i) const
+{
+ return typename internal::conditional<Direction==Vertical,ConstColXpr,ConstRowXpr>::type(derived(),i);
+}
+
+/** \returns the number of subvectors (rows or columns) in the direction \c Direction
+ * \sa subVector(Index)
+ */
+template<DirectionType Direction>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_CONSTEXPR
+Index subVectors() const
+{ return (Direction==Vertical)?cols():rows(); }
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/plugins/CommonCwiseUnaryOps.h b/wpimath/src/main/native/eigeninclude/Eigen/src/plugins/CommonCwiseUnaryOps.h
index 89f4faa..5418dc4 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/plugins/CommonCwiseUnaryOps.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/plugins/CommonCwiseUnaryOps.h
@@ -76,6 +76,20 @@
return ConjugateReturnType(derived());
}
+/// \returns an expression of the complex conjugate of \c *this if Cond==true, returns derived() otherwise.
+///
+EIGEN_DOC_UNARY_ADDONS(conjugate,complex conjugate)
+///
+/// \sa conjugate()
+template<bool Cond>
+EIGEN_DEVICE_FUNC
+inline typename internal::conditional<Cond,ConjugateReturnType,const Derived&>::type
+conjugateIf() const
+{
+ typedef typename internal::conditional<Cond,ConjugateReturnType,const Derived&>::type ReturnType;
+ return ReturnType(derived());
+}
+
/// \returns a read-only expression of the real part of \c *this.
///
EIGEN_DOC_UNARY_ADDONS(real,real part function)
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/plugins/IndexedViewMethods.h b/wpimath/src/main/native/eigeninclude/Eigen/src/plugins/IndexedViewMethods.h
new file mode 100644
index 0000000..5bfb19a
--- /dev/null
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/plugins/IndexedViewMethods.h
@@ -0,0 +1,262 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2017 Gael Guennebaud <gael.guennebaud@inria.fr>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#if !defined(EIGEN_PARSED_BY_DOXYGEN)
+
+// This file is automatically included twice to generate const and non-const versions
+
+#ifndef EIGEN_INDEXED_VIEW_METHOD_2ND_PASS
+#define EIGEN_INDEXED_VIEW_METHOD_CONST const
+#define EIGEN_INDEXED_VIEW_METHOD_TYPE ConstIndexedViewType
+#else
+#define EIGEN_INDEXED_VIEW_METHOD_CONST
+#define EIGEN_INDEXED_VIEW_METHOD_TYPE IndexedViewType
+#endif
+
+#ifndef EIGEN_INDEXED_VIEW_METHOD_2ND_PASS
+protected:
+
+// define some aliases to ease readability
+
+template<typename Indices>
+struct IvcRowType : public internal::IndexedViewCompatibleType<Indices,RowsAtCompileTime> {};
+
+template<typename Indices>
+struct IvcColType : public internal::IndexedViewCompatibleType<Indices,ColsAtCompileTime> {};
+
+template<typename Indices>
+struct IvcType : public internal::IndexedViewCompatibleType<Indices,SizeAtCompileTime> {};
+
+typedef typename internal::IndexedViewCompatibleType<Index,1>::type IvcIndex;
+
+template<typename Indices>
+typename IvcRowType<Indices>::type
+ivcRow(const Indices& indices) const {
+ return internal::makeIndexedViewCompatible(indices, internal::variable_if_dynamic<Index,RowsAtCompileTime>(derived().rows()),Specialized);
+}
+
+template<typename Indices>
+typename IvcColType<Indices>::type
+ivcCol(const Indices& indices) const {
+ return internal::makeIndexedViewCompatible(indices, internal::variable_if_dynamic<Index,ColsAtCompileTime>(derived().cols()),Specialized);
+}
+
+template<typename Indices>
+typename IvcType<Indices>::type // size-based: matches the SizeAtCompileTime used below and the caller's IvcType local
+ivcSize(const Indices& indices) const {
+ return internal::makeIndexedViewCompatible(indices, internal::variable_if_dynamic<Index,SizeAtCompileTime>(derived().size()),Specialized);
+}
+
+public:
+
+#endif
+
+template<typename RowIndices, typename ColIndices>
+struct EIGEN_INDEXED_VIEW_METHOD_TYPE {
+ typedef IndexedView<EIGEN_INDEXED_VIEW_METHOD_CONST Derived,
+ typename IvcRowType<RowIndices>::type,
+ typename IvcColType<ColIndices>::type> type;
+};
+
+// This is the generic version
+
+template<typename RowIndices, typename ColIndices>
+typename internal::enable_if<internal::valid_indexed_view_overload<RowIndices,ColIndices>::value
+ && internal::traits<typename EIGEN_INDEXED_VIEW_METHOD_TYPE<RowIndices,ColIndices>::type>::ReturnAsIndexedView,
+ typename EIGEN_INDEXED_VIEW_METHOD_TYPE<RowIndices,ColIndices>::type >::type
+operator()(const RowIndices& rowIndices, const ColIndices& colIndices) EIGEN_INDEXED_VIEW_METHOD_CONST
+{
+ return typename EIGEN_INDEXED_VIEW_METHOD_TYPE<RowIndices,ColIndices>::type
+ (derived(), ivcRow(rowIndices), ivcCol(colIndices));
+}
+
+// The following overload returns a Block<> object
+
+template<typename RowIndices, typename ColIndices>
+typename internal::enable_if<internal::valid_indexed_view_overload<RowIndices,ColIndices>::value
+ && internal::traits<typename EIGEN_INDEXED_VIEW_METHOD_TYPE<RowIndices,ColIndices>::type>::ReturnAsBlock,
+ typename internal::traits<typename EIGEN_INDEXED_VIEW_METHOD_TYPE<RowIndices,ColIndices>::type>::BlockType>::type
+operator()(const RowIndices& rowIndices, const ColIndices& colIndices) EIGEN_INDEXED_VIEW_METHOD_CONST
+{
+ typedef typename internal::traits<typename EIGEN_INDEXED_VIEW_METHOD_TYPE<RowIndices,ColIndices>::type>::BlockType BlockType;
+ typename IvcRowType<RowIndices>::type actualRowIndices = ivcRow(rowIndices);
+ typename IvcColType<ColIndices>::type actualColIndices = ivcCol(colIndices);
+ return BlockType(derived(),
+ internal::first(actualRowIndices),
+ internal::first(actualColIndices),
+ internal::size(actualRowIndices),
+ internal::size(actualColIndices));
+}
+
+// The following overload returns a Scalar
+
+template<typename RowIndices, typename ColIndices>
+typename internal::enable_if<internal::valid_indexed_view_overload<RowIndices,ColIndices>::value
+ && internal::traits<typename EIGEN_INDEXED_VIEW_METHOD_TYPE<RowIndices,ColIndices>::type>::ReturnAsScalar,
+ CoeffReturnType >::type
+operator()(const RowIndices& rowIndices, const ColIndices& colIndices) EIGEN_INDEXED_VIEW_METHOD_CONST
+{
+ return Base::operator()(internal::eval_expr_given_size(rowIndices,rows()),internal::eval_expr_given_size(colIndices,cols()));
+}
+
+#if EIGEN_HAS_STATIC_ARRAY_TEMPLATE
+
+// The following three overloads are needed to handle raw Index[N] arrays.
+
+template<typename RowIndicesT, std::size_t RowIndicesN, typename ColIndices>
+IndexedView<EIGEN_INDEXED_VIEW_METHOD_CONST Derived,const RowIndicesT (&)[RowIndicesN],typename IvcColType<ColIndices>::type>
+operator()(const RowIndicesT (&rowIndices)[RowIndicesN], const ColIndices& colIndices) EIGEN_INDEXED_VIEW_METHOD_CONST
+{
+ return IndexedView<EIGEN_INDEXED_VIEW_METHOD_CONST Derived,const RowIndicesT (&)[RowIndicesN],typename IvcColType<ColIndices>::type>
+ (derived(), rowIndices, ivcCol(colIndices));
+}
+
+template<typename RowIndices, typename ColIndicesT, std::size_t ColIndicesN>
+IndexedView<EIGEN_INDEXED_VIEW_METHOD_CONST Derived,typename IvcRowType<RowIndices>::type, const ColIndicesT (&)[ColIndicesN]>
+operator()(const RowIndices& rowIndices, const ColIndicesT (&colIndices)[ColIndicesN]) EIGEN_INDEXED_VIEW_METHOD_CONST
+{
+ return IndexedView<EIGEN_INDEXED_VIEW_METHOD_CONST Derived,typename IvcRowType<RowIndices>::type,const ColIndicesT (&)[ColIndicesN]>
+ (derived(), ivcRow(rowIndices), colIndices);
+}
+
+template<typename RowIndicesT, std::size_t RowIndicesN, typename ColIndicesT, std::size_t ColIndicesN>
+IndexedView<EIGEN_INDEXED_VIEW_METHOD_CONST Derived,const RowIndicesT (&)[RowIndicesN], const ColIndicesT (&)[ColIndicesN]>
+operator()(const RowIndicesT (&rowIndices)[RowIndicesN], const ColIndicesT (&colIndices)[ColIndicesN]) EIGEN_INDEXED_VIEW_METHOD_CONST
+{
+ return IndexedView<EIGEN_INDEXED_VIEW_METHOD_CONST Derived,const RowIndicesT (&)[RowIndicesN],const ColIndicesT (&)[ColIndicesN]>
+ (derived(), rowIndices, colIndices);
+}
+
+#endif // EIGEN_HAS_STATIC_ARRAY_TEMPLATE
+
+// Overloads for 1D vectors/arrays
+
+template<typename Indices>
+typename internal::enable_if<
+ IsRowMajor && (!(internal::get_compile_time_incr<typename IvcType<Indices>::type>::value==1 || internal::is_valid_index_type<Indices>::value)),
+ IndexedView<EIGEN_INDEXED_VIEW_METHOD_CONST Derived,IvcIndex,typename IvcType<Indices>::type> >::type
+operator()(const Indices& indices) EIGEN_INDEXED_VIEW_METHOD_CONST
+{
+ EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
+ return IndexedView<EIGEN_INDEXED_VIEW_METHOD_CONST Derived,IvcIndex,typename IvcType<Indices>::type>
+ (derived(), IvcIndex(0), ivcCol(indices));
+}
+
+template<typename Indices>
+typename internal::enable_if<
+ (!IsRowMajor) && (!(internal::get_compile_time_incr<typename IvcType<Indices>::type>::value==1 || internal::is_valid_index_type<Indices>::value)),
+ IndexedView<EIGEN_INDEXED_VIEW_METHOD_CONST Derived,typename IvcType<Indices>::type,IvcIndex> >::type
+operator()(const Indices& indices) EIGEN_INDEXED_VIEW_METHOD_CONST
+{
+ EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
+ return IndexedView<EIGEN_INDEXED_VIEW_METHOD_CONST Derived,typename IvcType<Indices>::type,IvcIndex>
+ (derived(), ivcRow(indices), IvcIndex(0));
+}
+
+template<typename Indices>
+typename internal::enable_if<
+ (internal::get_compile_time_incr<typename IvcType<Indices>::type>::value==1) && (!internal::is_valid_index_type<Indices>::value) && (!symbolic::is_symbolic<Indices>::value),
+ VectorBlock<EIGEN_INDEXED_VIEW_METHOD_CONST Derived,internal::array_size<Indices>::value> >::type
+operator()(const Indices& indices) EIGEN_INDEXED_VIEW_METHOD_CONST
+{
+ EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
+ typename IvcType<Indices>::type actualIndices = ivcSize(indices);
+ return VectorBlock<EIGEN_INDEXED_VIEW_METHOD_CONST Derived,internal::array_size<Indices>::value>
+ (derived(), internal::first(actualIndices), internal::size(actualIndices));
+}
+
+template<typename IndexType>
+typename internal::enable_if<symbolic::is_symbolic<IndexType>::value, CoeffReturnType >::type
+operator()(const IndexType& id) EIGEN_INDEXED_VIEW_METHOD_CONST
+{
+ return Base::operator()(internal::eval_expr_given_size(id,size()));
+}
+
+#if EIGEN_HAS_STATIC_ARRAY_TEMPLATE
+
+template<typename IndicesT, std::size_t IndicesN>
+typename internal::enable_if<IsRowMajor,
+ IndexedView<EIGEN_INDEXED_VIEW_METHOD_CONST Derived,IvcIndex,const IndicesT (&)[IndicesN]> >::type
+operator()(const IndicesT (&indices)[IndicesN]) EIGEN_INDEXED_VIEW_METHOD_CONST
+{
+ EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
+ return IndexedView<EIGEN_INDEXED_VIEW_METHOD_CONST Derived,IvcIndex,const IndicesT (&)[IndicesN]>
+ (derived(), IvcIndex(0), indices);
+}
+
+template<typename IndicesT, std::size_t IndicesN>
+typename internal::enable_if<!IsRowMajor,
+ IndexedView<EIGEN_INDEXED_VIEW_METHOD_CONST Derived,const IndicesT (&)[IndicesN],IvcIndex> >::type
+operator()(const IndicesT (&indices)[IndicesN]) EIGEN_INDEXED_VIEW_METHOD_CONST
+{
+ EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
+ return IndexedView<EIGEN_INDEXED_VIEW_METHOD_CONST Derived,const IndicesT (&)[IndicesN],IvcIndex>
+ (derived(), indices, IvcIndex(0));
+}
+
+#endif // EIGEN_HAS_STATIC_ARRAY_TEMPLATE
+
+#undef EIGEN_INDEXED_VIEW_METHOD_CONST
+#undef EIGEN_INDEXED_VIEW_METHOD_TYPE
+
+#ifndef EIGEN_INDEXED_VIEW_METHOD_2ND_PASS
+#define EIGEN_INDEXED_VIEW_METHOD_2ND_PASS
+#include "IndexedViewMethods.h"
+#undef EIGEN_INDEXED_VIEW_METHOD_2ND_PASS
+#endif
+
+#else // EIGEN_PARSED_BY_DOXYGEN
+
+/**
+ * \returns a generic submatrix view defined by the rows and columns indexed \a rowIndices and \a colIndices respectively.
+ *
+ * Each parameter must either be:
+ * - An integer indexing a single row or column
+ * - Eigen::all indexing the full set of respective rows or columns in increasing order
+ * - An ArithmeticSequence as returned by the Eigen::seq and Eigen::seqN functions
+ * - Any %Eigen's vector/array of integers or expressions
+ * - Plain C arrays: \c int[N]
+ * - And more generally any type exposing the following two member functions:
+ * \code
+ * <integral type> operator[](<integral type>) const;
+ * <integral type> size() const;
+ * \endcode
+ * where \c <integral \c type> stands for any integer type compatible with Eigen::Index (i.e. \c std::ptrdiff_t).
+ *
+ * The last statement implies compatibility with \c std::vector, \c std::valarray, \c std::array, many of Range-v3's ranges, etc.
+ *
+ * If the submatrix can be represented using a starting position \c (i,j) and positive sizes \c (rows,columns), then this
+ * method will return a Block object after extraction of the relevant information from the passed arguments. This is the case
+ * when all arguments are either:
+ * - An integer
+ * - Eigen::all
+ * - An ArithmeticSequence with compile-time increment strictly equal to 1, as returned by Eigen::seq(a,b), and Eigen::seqN(a,N).
+ *
+ * Otherwise a more general IndexedView<Derived,RowIndices',ColIndices'> object will be returned, after conversion of the inputs
+ * to more suitable types \c RowIndices' and \c ColIndices'.
+ *
+ * For 1D vectors and arrays, prefer the operator()(const Indices&) overload, which behaves the same way but takes a single parameter.
+ *
+ * See also this <a href="https://stackoverflow.com/questions/46110917/eigen-replicate-items-along-one-dimension-without-useless-allocations">question</a> and its answer for an example of how to duplicate coefficients.
+ *
+ * \sa operator()(const Indices&), class Block, class IndexedView, DenseBase::block(Index,Index,Index,Index)
+ */
+template<typename RowIndices, typename ColIndices>
+IndexedView_or_Block
+operator()(const RowIndices& rowIndices, const ColIndices& colIndices);
+
+/** This is an overload of operator()(const RowIndices&, const ColIndices&) for 1D vectors or arrays
+ *
+ * \only_for_vectors
+ */
+template<typename Indices>
+IndexedView_or_VectorBlock
+operator()(const Indices& indices);
+
+#endif // EIGEN_PARSED_BY_DOXYGEN
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/plugins/MatrixCwiseBinaryOps.h b/wpimath/src/main/native/eigeninclude/Eigen/src/plugins/MatrixCwiseBinaryOps.h
index f1084ab..a0feef8 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/plugins/MatrixCwiseBinaryOps.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/plugins/MatrixCwiseBinaryOps.h
@@ -39,10 +39,10 @@
*/
template<typename OtherDerived>
EIGEN_DEVICE_FUNC
-inline const CwiseBinaryOp<std::equal_to<Scalar>, const Derived, const OtherDerived>
+inline const CwiseBinaryOp<numext::equal_to<Scalar>, const Derived, const OtherDerived>
cwiseEqual(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived> &other) const
{
- return CwiseBinaryOp<std::equal_to<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
+ return CwiseBinaryOp<numext::equal_to<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
}
/** \returns an expression of the coefficient-wise != operator of *this and \a other
@@ -59,10 +59,10 @@
*/
template<typename OtherDerived>
EIGEN_DEVICE_FUNC
-inline const CwiseBinaryOp<std::not_equal_to<Scalar>, const Derived, const OtherDerived>
+inline const CwiseBinaryOp<numext::not_equal_to<Scalar>, const Derived, const OtherDerived>
cwiseNotEqual(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived> &other) const
{
- return CwiseBinaryOp<std::not_equal_to<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
+ return CwiseBinaryOp<numext::not_equal_to<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
}
/** \returns an expression of the coefficient-wise min of *this and \a other
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/plugins/MatrixCwiseUnaryOps.h b/wpimath/src/main/native/eigeninclude/Eigen/src/plugins/MatrixCwiseUnaryOps.h
index b1be3d5..0514d8f 100644
--- a/wpimath/src/main/native/eigeninclude/Eigen/src/plugins/MatrixCwiseUnaryOps.h
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/plugins/MatrixCwiseUnaryOps.h
@@ -14,6 +14,7 @@
typedef CwiseUnaryOp<internal::scalar_abs_op<Scalar>, const Derived> CwiseAbsReturnType;
typedef CwiseUnaryOp<internal::scalar_abs2_op<Scalar>, const Derived> CwiseAbs2ReturnType;
+typedef CwiseUnaryOp<internal::scalar_arg_op<Scalar>, const Derived> CwiseArgReturnType;
typedef CwiseUnaryOp<internal::scalar_sqrt_op<Scalar>, const Derived> CwiseSqrtReturnType;
typedef CwiseUnaryOp<internal::scalar_sign_op<Scalar>, const Derived> CwiseSignReturnType;
typedef CwiseUnaryOp<internal::scalar_inverse_op<Scalar>, const Derived> CwiseInverseReturnType;
@@ -82,4 +83,13 @@
inline const CwiseInverseReturnType
cwiseInverse() const { return CwiseInverseReturnType(derived()); }
+/// \returns an expression of the coefficient-wise phase angle of \c *this
+///
+/// Example: \include MatrixBase_cwiseArg.cpp
+/// Output: \verbinclude MatrixBase_cwiseArg.out
+///
+EIGEN_DOC_UNARY_ADDONS(cwiseArg,arg)
+EIGEN_DEVICE_FUNC
+inline const CwiseArgReturnType
+cwiseArg() const { return CwiseArgReturnType(derived()); }
diff --git a/wpimath/src/main/native/eigeninclude/Eigen/src/plugins/ReshapedMethods.h b/wpimath/src/main/native/eigeninclude/Eigen/src/plugins/ReshapedMethods.h
new file mode 100644
index 0000000..482a6b0
--- /dev/null
+++ b/wpimath/src/main/native/eigeninclude/Eigen/src/plugins/ReshapedMethods.h
@@ -0,0 +1,149 @@
+
+#ifdef EIGEN_PARSED_BY_DOXYGEN
+
+/// \returns an expression of \c *this with reshaped sizes.
+///
+/// \param nRows the number of rows in the reshaped expression, specified at either run-time or compile-time, or AutoSize
+/// \param nCols the number of columns in the reshaped expression, specified at either run-time or compile-time, or AutoSize
+/// \tparam Order specifies whether the coefficients should be processed in column-major-order (ColMajor), in row-major-order (RowMajor),
+/// or following the \em natural order of the nested expression (AutoOrder). The default is ColMajor.
+/// \tparam NRowsType the type of the value handling the number of rows, typically Index.
+/// \tparam NColsType the type of the value handling the number of columns, typically Index.
+///
+/// Dynamic size example: \include MatrixBase_reshaped_int_int.cpp
+/// Output: \verbinclude MatrixBase_reshaped_int_int.out
+///
+/// The number of rows \a nRows and columns \a nCols can also be specified at compile-time by passing Eigen::fix<N>,
+/// or Eigen::fix<N>(n) as arguments. In the latter case, \c n plays the role of a runtime fallback value in case \c N equals Eigen::Dynamic.
+/// Here is an example with a fixed number of rows and columns:
+/// \include MatrixBase_reshaped_fixed.cpp
+/// Output: \verbinclude MatrixBase_reshaped_fixed.out
+///
+/// Finally, one of the size parameters can be automatically deduced from the other one by passing AutoSize as in the following example:
+/// \include MatrixBase_reshaped_auto.cpp
+/// Output: \verbinclude MatrixBase_reshaped_auto.out
+/// AutoSize does preserve compile-time sizes when possible, i.e., when the sizes of the input are known at compile time \b and
+/// the other size is passed at compile-time using Eigen::fix<N> as above.
+///
+/// \sa class Reshaped, fix, fix<N>(int)
+///
+template<int Order = ColMajor, typename NRowsType, typename NColsType>
+EIGEN_DEVICE_FUNC
+inline Reshaped<Derived,...>
+reshaped(NRowsType nRows, NColsType nCols);
+
+/// This is the const version of reshaped(NRowsType,NColsType).
+template<int Order = ColMajor, typename NRowsType, typename NColsType>
+EIGEN_DEVICE_FUNC
+inline const Reshaped<const Derived,...>
+reshaped(NRowsType nRows, NColsType nCols) const;
+
+/// \returns an expression of \c *this with columns (or rows) stacked to a linear column vector
+///
+/// \tparam Order specifies whether the coefficients should be processed in column-major-order (ColMajor), in row-major-order (RowMajor),
+/// or following the \em natural order of the nested expression (AutoOrder). The default is ColMajor.
+///
+/// This overload is essentially a shortcut for `A.reshaped<Order>(AutoSize,fix<1>)`.
+///
+/// - If `Order==ColMajor` (the default), then it returns a column-vector from the stacked columns of \c *this.
+/// - If `Order==RowMajor`, then it returns a column-vector from the stacked rows of \c *this.
+/// - If `Order==AutoOrder`, then it returns a column-vector with elements stacked following the storage order of \c *this.
+/// This mode is the recommended one when the particular ordering of the elements is not relevant.
+///
+/// Example:
+/// \include MatrixBase_reshaped_to_vector.cpp
+/// Output: \verbinclude MatrixBase_reshaped_to_vector.out
+///
+/// If you want more control, you can still fall back to reshaped(NRowsType,NColsType).
+///
+/// \sa reshaped(NRowsType,NColsType), class Reshaped
+///
+template<int Order = ColMajor>
+EIGEN_DEVICE_FUNC
+inline Reshaped<Derived,...>
+reshaped();
+
+/// This is the const version of reshaped().
+template<int Order = ColMajor>
+EIGEN_DEVICE_FUNC
+inline const Reshaped<const Derived,...>
+reshaped() const;
+
+#else
+
+// This file is automatically included twice to generate const and non-const versions
+
+#ifndef EIGEN_RESHAPED_METHOD_2ND_PASS
+#define EIGEN_RESHAPED_METHOD_CONST const
+#else
+#define EIGEN_RESHAPED_METHOD_CONST
+#endif
+
+#ifndef EIGEN_RESHAPED_METHOD_2ND_PASS
+
+// This part is included once
+
+#endif
+
+template<typename NRowsType, typename NColsType>
+EIGEN_DEVICE_FUNC
+inline Reshaped<EIGEN_RESHAPED_METHOD_CONST Derived,
+ internal::get_compiletime_reshape_size<NRowsType,NColsType,SizeAtCompileTime>::value,
+ internal::get_compiletime_reshape_size<NColsType,NRowsType,SizeAtCompileTime>::value>
+reshaped(NRowsType nRows, NColsType nCols) EIGEN_RESHAPED_METHOD_CONST
+{
+ return Reshaped<EIGEN_RESHAPED_METHOD_CONST Derived,
+ internal::get_compiletime_reshape_size<NRowsType,NColsType,SizeAtCompileTime>::value,
+ internal::get_compiletime_reshape_size<NColsType,NRowsType,SizeAtCompileTime>::value>
+ (derived(),
+ internal::get_runtime_reshape_size(nRows,internal::get_runtime_value(nCols),size()),
+ internal::get_runtime_reshape_size(nCols,internal::get_runtime_value(nRows),size()));
+}
+
+template<int Order, typename NRowsType, typename NColsType>
+EIGEN_DEVICE_FUNC
+inline Reshaped<EIGEN_RESHAPED_METHOD_CONST Derived,
+ internal::get_compiletime_reshape_size<NRowsType,NColsType,SizeAtCompileTime>::value,
+ internal::get_compiletime_reshape_size<NColsType,NRowsType,SizeAtCompileTime>::value,
+ internal::get_compiletime_reshape_order<Flags,Order>::value>
+reshaped(NRowsType nRows, NColsType nCols) EIGEN_RESHAPED_METHOD_CONST
+{
+ return Reshaped<EIGEN_RESHAPED_METHOD_CONST Derived,
+ internal::get_compiletime_reshape_size<NRowsType,NColsType,SizeAtCompileTime>::value,
+ internal::get_compiletime_reshape_size<NColsType,NRowsType,SizeAtCompileTime>::value,
+ internal::get_compiletime_reshape_order<Flags,Order>::value>
+ (derived(),
+ internal::get_runtime_reshape_size(nRows,internal::get_runtime_value(nCols),size()),
+ internal::get_runtime_reshape_size(nCols,internal::get_runtime_value(nRows),size()));
+}
+
+// Views as linear vectors
+
+EIGEN_DEVICE_FUNC
+inline Reshaped<EIGEN_RESHAPED_METHOD_CONST Derived,SizeAtCompileTime,1>
+reshaped() EIGEN_RESHAPED_METHOD_CONST
+{
+ return Reshaped<EIGEN_RESHAPED_METHOD_CONST Derived,SizeAtCompileTime,1>(derived(),size(),1);
+}
+
+template<int Order>
+EIGEN_DEVICE_FUNC
+inline Reshaped<EIGEN_RESHAPED_METHOD_CONST Derived, SizeAtCompileTime, 1,
+ internal::get_compiletime_reshape_order<Flags,Order>::value>
+reshaped() EIGEN_RESHAPED_METHOD_CONST
+{
+ EIGEN_STATIC_ASSERT(Order==RowMajor || Order==ColMajor || Order==AutoOrder, INVALID_TEMPLATE_PARAMETER);
+ return Reshaped<EIGEN_RESHAPED_METHOD_CONST Derived, SizeAtCompileTime, 1,
+ internal::get_compiletime_reshape_order<Flags,Order>::value>
+ (derived(), size(), 1);
+}
+
+#undef EIGEN_RESHAPED_METHOD_CONST
+
+#ifndef EIGEN_RESHAPED_METHOD_2ND_PASS
+#define EIGEN_RESHAPED_METHOD_2ND_PASS
+#include "ReshapedMethods.h"
+#undef EIGEN_RESHAPED_METHOD_2ND_PASS
+#endif
+
+#endif // EIGEN_PARSED_BY_DOXYGEN
diff --git a/wpimath/src/main/native/include/unsupported/Eigen/AutoDiff b/wpimath/src/main/native/eigeninclude/unsupported/Eigen/AutoDiff
similarity index 89%
rename from wpimath/src/main/native/include/unsupported/Eigen/AutoDiff
rename to wpimath/src/main/native/eigeninclude/unsupported/Eigen/AutoDiff
index abf5b7d..7a4ff46 100644
--- a/wpimath/src/main/native/include/unsupported/Eigen/AutoDiff
+++ b/wpimath/src/main/native/eigeninclude/unsupported/Eigen/AutoDiff
@@ -28,11 +28,17 @@
//@{
}
+#include "../../Eigen/src/Core/util/DisableStupidWarnings.h"
+
#include "src/AutoDiff/AutoDiffScalar.h"
// #include "src/AutoDiff/AutoDiffVector.h"
#include "src/AutoDiff/AutoDiffJacobian.h"
+#include "../../Eigen/src/Core/util/ReenableStupidWarnings.h"
+
+
+
namespace Eigen {
//@}
}
diff --git a/wpimath/src/main/native/include/unsupported/Eigen/MatrixFunctions b/wpimath/src/main/native/eigeninclude/unsupported/Eigen/MatrixFunctions
similarity index 98%
rename from wpimath/src/main/native/include/unsupported/Eigen/MatrixFunctions
rename to wpimath/src/main/native/eigeninclude/unsupported/Eigen/MatrixFunctions
index 60dc0a6..20c23d1 100644
--- a/wpimath/src/main/native/include/unsupported/Eigen/MatrixFunctions
+++ b/wpimath/src/main/native/eigeninclude/unsupported/Eigen/MatrixFunctions
@@ -14,9 +14,9 @@
#include <cfloat>
#include <list>
-#include <Eigen/Core>
-#include <Eigen/LU>
-#include <Eigen/Eigenvalues>
+#include "../../Eigen/Core"
+#include "../../Eigen/LU"
+#include "../../Eigen/Eigenvalues"
/**
* \defgroup MatrixFunctions_Module Matrix functions module
@@ -53,12 +53,16 @@
*
*/
+#include "../../Eigen/src/Core/util/DisableStupidWarnings.h"
+
#include "src/MatrixFunctions/MatrixExponential.h"
#include "src/MatrixFunctions/MatrixFunction.h"
#include "src/MatrixFunctions/MatrixSquareRoot.h"
#include "src/MatrixFunctions/MatrixLogarithm.h"
#include "src/MatrixFunctions/MatrixPower.h"
+#include "../../Eigen/src/Core/util/ReenableStupidWarnings.h"
+
/**
\page matrixbaseextra_page
diff --git a/wpimath/src/main/native/include/unsupported/Eigen/src/AutoDiff/AutoDiffJacobian.h b/wpimath/src/main/native/eigeninclude/unsupported/Eigen/src/AutoDiff/AutoDiffJacobian.h
similarity index 100%
rename from wpimath/src/main/native/include/unsupported/Eigen/src/AutoDiff/AutoDiffJacobian.h
rename to wpimath/src/main/native/eigeninclude/unsupported/Eigen/src/AutoDiff/AutoDiffJacobian.h
diff --git a/wpimath/src/main/native/include/unsupported/Eigen/src/AutoDiff/AutoDiffScalar.h b/wpimath/src/main/native/eigeninclude/unsupported/Eigen/src/AutoDiff/AutoDiffScalar.h
similarity index 86%
rename from wpimath/src/main/native/include/unsupported/Eigen/src/AutoDiff/AutoDiffScalar.h
rename to wpimath/src/main/native/eigeninclude/unsupported/Eigen/src/AutoDiff/AutoDiffScalar.h
index 2f50e99..0f166e3 100644
--- a/wpimath/src/main/native/include/unsupported/Eigen/src/AutoDiff/AutoDiffScalar.h
+++ b/wpimath/src/main/native/eigeninclude/unsupported/Eigen/src/AutoDiff/AutoDiffScalar.h
@@ -26,11 +26,11 @@
make_coherent_impl<A,B>::run(a.const_cast_derived(), b.const_cast_derived());
}
-template<typename _DerType, bool Enable> struct auto_diff_special_op;
+template<typename DerivativeType, bool Enable> struct auto_diff_special_op;
} // end namespace internal
-template<typename _DerType> class AutoDiffScalar;
+template<typename DerivativeType> class AutoDiffScalar;
template<typename NewDerType>
inline AutoDiffScalar<NewDerType> MakeAutoDiffScalar(const typename NewDerType::Scalar& value, const NewDerType &der) {
@@ -38,16 +38,16 @@
}
/** \class AutoDiffScalar
- * \brief A scalar type replacement with automatic differentation capability
+ * \brief A scalar type replacement with automatic differentiation capability
*
- * \param _DerType the vector type used to store/represent the derivatives. The base scalar type
+ * \param DerivativeType the vector type used to store/represent the derivatives. The base scalar type
* as well as the number of derivatives to compute are determined from this type.
* Typical choices include, e.g., \c Vector4f for 4 derivatives, or \c VectorXf
* if the number of derivatives is not known at compile time, and/or, the number
* of derivatives is large.
- * Note that _DerType can also be a reference (e.g., \c VectorXf&) to wrap a
+ * Note that DerivativeType can also be a reference (e.g., \c VectorXf&) to wrap a
* existing vector into an AutoDiffScalar.
- * Finally, _DerType can also be any Eigen compatible expression.
+ * Finally, DerivativeType can also be any Eigen compatible expression.
*
* This class represents a scalar value while tracking its respective derivatives using Eigen's expression
* template mechanism.
@@ -63,17 +63,17 @@
*
*/
-template<typename _DerType>
+template<typename DerivativeType>
class AutoDiffScalar
: public internal::auto_diff_special_op
- <_DerType, !internal::is_same<typename internal::traits<typename internal::remove_all<_DerType>::type>::Scalar,
- typename NumTraits<typename internal::traits<typename internal::remove_all<_DerType>::type>::Scalar>::Real>::value>
+ <DerivativeType, !internal::is_same<typename internal::traits<typename internal::remove_all<DerivativeType>::type>::Scalar,
+ typename NumTraits<typename internal::traits<typename internal::remove_all<DerivativeType>::type>::Scalar>::Real>::value>
{
public:
typedef internal::auto_diff_special_op
- <_DerType, !internal::is_same<typename internal::traits<typename internal::remove_all<_DerType>::type>::Scalar,
- typename NumTraits<typename internal::traits<typename internal::remove_all<_DerType>::type>::Scalar>::Real>::value> Base;
- typedef typename internal::remove_all<_DerType>::type DerType;
+ <DerivativeType, !internal::is_same<typename internal::traits<typename internal::remove_all<DerivativeType>::type>::Scalar,
+ typename NumTraits<typename internal::traits<typename internal::remove_all<DerivativeType>::type>::Scalar>::Real>::value> Base;
+ typedef typename internal::remove_all<DerivativeType>::type DerType;
typedef typename internal::traits<DerType>::Scalar Scalar;
typedef typename NumTraits<Scalar>::Real Real;
@@ -382,16 +382,16 @@
namespace internal {
-template<typename _DerType>
-struct auto_diff_special_op<_DerType, true>
-// : auto_diff_scalar_op<_DerType, typename NumTraits<Scalar>::Real,
+template<typename DerivativeType>
+struct auto_diff_special_op<DerivativeType, true>
+// : auto_diff_scalar_op<DerivativeType, typename NumTraits<Scalar>::Real,
// is_same<Scalar,typename NumTraits<Scalar>::Real>::value>
{
- typedef typename remove_all<_DerType>::type DerType;
+ typedef typename remove_all<DerivativeType>::type DerType;
typedef typename traits<DerType>::Scalar Scalar;
typedef typename NumTraits<Scalar>::Real Real;
-// typedef auto_diff_scalar_op<_DerType, typename NumTraits<Scalar>::Real,
+// typedef auto_diff_scalar_op<DerivativeType, typename NumTraits<Scalar>::Real,
// is_same<Scalar,typename NumTraits<Scalar>::Real>::value> Base;
// using Base::operator+;
@@ -401,8 +401,8 @@
// using Base::operator*;
// using Base::operator*=;
- const AutoDiffScalar<_DerType>& derived() const { return *static_cast<const AutoDiffScalar<_DerType>*>(this); }
- AutoDiffScalar<_DerType>& derived() { return *static_cast<AutoDiffScalar<_DerType>*>(this); }
+ const AutoDiffScalar<DerivativeType>& derived() const { return *static_cast<const AutoDiffScalar<DerivativeType>*>(this); }
+ AutoDiffScalar<DerivativeType>& derived() { return *static_cast<AutoDiffScalar<DerivativeType>*>(this); }
inline const AutoDiffScalar<DerType&> operator+(const Real& other) const
@@ -410,12 +410,12 @@
return AutoDiffScalar<DerType&>(derived().value() + other, derived().derivatives());
}
- friend inline const AutoDiffScalar<DerType&> operator+(const Real& a, const AutoDiffScalar<_DerType>& b)
+ friend inline const AutoDiffScalar<DerType&> operator+(const Real& a, const AutoDiffScalar<DerivativeType>& b)
{
return AutoDiffScalar<DerType&>(a + b.value(), b.derivatives());
}
- inline AutoDiffScalar<_DerType>& operator+=(const Real& other)
+ inline AutoDiffScalar<DerivativeType>& operator+=(const Real& other)
{
derived().value() += other;
return derived();
@@ -431,28 +431,46 @@
}
friend inline const AutoDiffScalar<typename CwiseUnaryOp<bind1st_op<scalar_product_op<Real,Scalar> >, DerType>::Type >
- operator*(const Real& other, const AutoDiffScalar<_DerType>& a)
+ operator*(const Real& other, const AutoDiffScalar<DerivativeType>& a)
{
return AutoDiffScalar<typename CwiseUnaryOp<bind1st_op<scalar_product_op<Real,Scalar> >, DerType>::Type >(
a.value() * other,
a.derivatives() * other);
}
- inline AutoDiffScalar<_DerType>& operator*=(const Scalar& other)
+ inline AutoDiffScalar<DerivativeType>& operator*=(const Scalar& other)
{
*this = *this * other;
return derived();
}
};
-template<typename _DerType>
-struct auto_diff_special_op<_DerType, false>
+template<typename DerivativeType>
+struct auto_diff_special_op<DerivativeType, false>
{
void operator*() const;
void operator-() const;
void operator+() const;
};
+template<typename BinOp, typename A, typename B, typename RefType>
+void make_coherent_expression(CwiseBinaryOp<BinOp,A,B> xpr, const RefType &ref)
+{
+ make_coherent(xpr.const_cast_derived().lhs(), ref);
+ make_coherent(xpr.const_cast_derived().rhs(), ref);
+}
+
+template<typename UnaryOp, typename A, typename RefType>
+void make_coherent_expression(const CwiseUnaryOp<UnaryOp,A> &xpr, const RefType &ref)
+{
+ make_coherent(xpr.nestedExpression().const_cast_derived(), ref);
+}
+
+// needed for compilation only
+template<typename UnaryOp, typename A, typename RefType>
+void make_coherent_expression(const CwiseNullaryOp<UnaryOp,A> &, const RefType &)
+{}
+
template<typename A_Scalar, int A_Rows, int A_Cols, int A_Options, int A_MaxRows, int A_MaxCols, typename B>
struct make_coherent_impl<Matrix<A_Scalar, A_Rows, A_Cols, A_Options, A_MaxRows, A_MaxCols>, B> {
typedef Matrix<A_Scalar, A_Rows, A_Cols, A_Options, A_MaxRows, A_MaxCols> A;
@@ -462,6 +480,10 @@
a.resize(b.size());
a.setZero();
}
+ else if (B::SizeAtCompileTime==Dynamic && a.size()!=0 && b.size()==0)
+ {
+ make_coherent_expression(b,a);
+ }
}
};
@@ -474,13 +496,17 @@
b.resize(a.size());
b.setZero();
}
+ else if (A::SizeAtCompileTime==Dynamic && b.size()!=0 && a.size()==0)
+ {
+ make_coherent_expression(a,b);
+ }
}
};
template<typename A_Scalar, int A_Rows, int A_Cols, int A_Options, int A_MaxRows, int A_MaxCols,
typename B_Scalar, int B_Rows, int B_Cols, int B_Options, int B_MaxRows, int B_MaxCols>
struct make_coherent_impl<Matrix<A_Scalar, A_Rows, A_Cols, A_Options, A_MaxRows, A_MaxCols>,
- Matrix<B_Scalar, B_Rows, B_Cols, B_Options, B_MaxRows, B_MaxCols> > {
+ Matrix<B_Scalar, B_Rows, B_Cols, B_Options, B_MaxRows, B_MaxCols> > {
typedef Matrix<A_Scalar, A_Rows, A_Cols, A_Options, A_MaxRows, A_MaxCols> A;
typedef Matrix<B_Scalar, B_Rows, B_Cols, B_Options, B_MaxRows, B_MaxCols> B;
static void run(A& a, B& b) {
@@ -540,37 +566,42 @@
}
template<typename DerType>
+struct CleanedUpDerType {
+ typedef AutoDiffScalar<typename Eigen::internal::remove_all<DerType>::type::PlainObject> type;
+};
+
+template<typename DerType>
inline const AutoDiffScalar<DerType>& conj(const AutoDiffScalar<DerType>& x) { return x; }
template<typename DerType>
inline const AutoDiffScalar<DerType>& real(const AutoDiffScalar<DerType>& x) { return x; }
template<typename DerType>
inline typename DerType::Scalar imag(const AutoDiffScalar<DerType>&) { return 0.; }
template<typename DerType, typename T>
-inline AutoDiffScalar<typename Eigen::internal::remove_all<DerType>::type::PlainObject> (min)(const AutoDiffScalar<DerType>& x, const T& y) {
- typedef AutoDiffScalar<typename Eigen::internal::remove_all<DerType>::type::PlainObject> ADS;
+inline typename CleanedUpDerType<DerType>::type (min)(const AutoDiffScalar<DerType>& x, const T& y) {
+ typedef typename CleanedUpDerType<DerType>::type ADS;
return (x <= y ? ADS(x) : ADS(y));
}
template<typename DerType, typename T>
-inline AutoDiffScalar<typename Eigen::internal::remove_all<DerType>::type::PlainObject> (max)(const AutoDiffScalar<DerType>& x, const T& y) {
- typedef AutoDiffScalar<typename Eigen::internal::remove_all<DerType>::type::PlainObject> ADS;
+inline typename CleanedUpDerType<DerType>::type (max)(const AutoDiffScalar<DerType>& x, const T& y) {
+ typedef typename CleanedUpDerType<DerType>::type ADS;
return (x >= y ? ADS(x) : ADS(y));
}
template<typename DerType, typename T>
-inline AutoDiffScalar<typename Eigen::internal::remove_all<DerType>::type::PlainObject> (min)(const T& x, const AutoDiffScalar<DerType>& y) {
- typedef AutoDiffScalar<typename Eigen::internal::remove_all<DerType>::type::PlainObject> ADS;
+inline typename CleanedUpDerType<DerType>::type (min)(const T& x, const AutoDiffScalar<DerType>& y) {
+ typedef typename CleanedUpDerType<DerType>::type ADS;
return (x < y ? ADS(x) : ADS(y));
}
template<typename DerType, typename T>
-inline AutoDiffScalar<typename Eigen::internal::remove_all<DerType>::type::PlainObject> (max)(const T& x, const AutoDiffScalar<DerType>& y) {
- typedef AutoDiffScalar<typename Eigen::internal::remove_all<DerType>::type::PlainObject> ADS;
+inline typename CleanedUpDerType<DerType>::type (max)(const T& x, const AutoDiffScalar<DerType>& y) {
+ typedef typename CleanedUpDerType<DerType>::type ADS;
return (x > y ? ADS(x) : ADS(y));
}
template<typename DerType>
-inline AutoDiffScalar<typename Eigen::internal::remove_all<DerType>::type::PlainObject> (min)(const AutoDiffScalar<DerType>& x, const AutoDiffScalar<DerType>& y) {
+inline typename CleanedUpDerType<DerType>::type (min)(const AutoDiffScalar<DerType>& x, const AutoDiffScalar<DerType>& y) {
return (x.value() < y.value() ? x : y);
}
template<typename DerType>
-inline AutoDiffScalar<typename Eigen::internal::remove_all<DerType>::type::PlainObject> (max)(const AutoDiffScalar<DerType>& x, const AutoDiffScalar<DerType>& y) {
+inline typename CleanedUpDerType<DerType>::type (max)(const AutoDiffScalar<DerType>& x, const AutoDiffScalar<DerType>& y) {
return (x.value() >= y.value() ? x : y);
}
@@ -685,10 +716,15 @@
}
namespace std {
+
template <typename T>
class numeric_limits<Eigen::AutoDiffScalar<T> >
: public numeric_limits<typename T::Scalar> {};
+template <typename T>
+class numeric_limits<Eigen::AutoDiffScalar<T&> >
+ : public numeric_limits<typename T::Scalar> {};
+
} // namespace std
#endif // EIGEN_AUTODIFF_SCALAR_H
diff --git a/wpimath/src/main/native/include/unsupported/Eigen/src/AutoDiff/AutoDiffVector.h b/wpimath/src/main/native/eigeninclude/unsupported/Eigen/src/AutoDiff/AutoDiffVector.h
similarity index 100%
rename from wpimath/src/main/native/include/unsupported/Eigen/src/AutoDiff/AutoDiffVector.h
rename to wpimath/src/main/native/eigeninclude/unsupported/Eigen/src/AutoDiff/AutoDiffVector.h
diff --git a/wpimath/src/main/native/include/unsupported/Eigen/src/MatrixFunctions/MatrixExponential.h b/wpimath/src/main/native/eigeninclude/unsupported/Eigen/src/MatrixFunctions/MatrixExponential.h
similarity index 98%
rename from wpimath/src/main/native/include/unsupported/Eigen/src/MatrixFunctions/MatrixExponential.h
rename to wpimath/src/main/native/eigeninclude/unsupported/Eigen/src/MatrixFunctions/MatrixExponential.h
index e5ebbcf..02284b0 100644
--- a/wpimath/src/main/native/include/unsupported/Eigen/src/MatrixFunctions/MatrixExponential.h
+++ b/wpimath/src/main/native/eigeninclude/unsupported/Eigen/src/MatrixFunctions/MatrixExponential.h
@@ -314,7 +314,7 @@
matrix_exp_pade17(A, U, V);
}
-#elif LDBL_MANT_DIG <= 112 // quadruple precison
+#elif LDBL_MANT_DIG <= 113 // quadruple precision
if (l1norm < 1.639394610288918690547467954466970e-005L) {
matrix_exp_pade3(arg, U, V);
@@ -347,7 +347,7 @@
template<typename T> struct is_exp_known_type : false_type {};
template<> struct is_exp_known_type<float> : true_type {};
template<> struct is_exp_known_type<double> : true_type {};
-#if LDBL_MANT_DIG <= 112
+#if LDBL_MANT_DIG <= 113
template<> struct is_exp_known_type<long double> : true_type {};
#endif
@@ -396,7 +396,6 @@
template<typename Derived> struct MatrixExponentialReturnValue
: public ReturnByValue<MatrixExponentialReturnValue<Derived> >
{
- typedef typename Derived::Index Index;
public:
/** \brief Constructor.
*
@@ -412,7 +411,7 @@
inline void evalTo(ResultType& result) const
{
const typename internal::nested_eval<Derived, 10>::type tmp(m_src);
- internal::matrix_exp_compute(tmp, result, internal::is_exp_known_type<typename Derived::Scalar>());
+ internal::matrix_exp_compute(tmp, result, internal::is_exp_known_type<typename Derived::RealScalar>());
}
Index rows() const { return m_src.rows(); }
diff --git a/wpimath/src/main/native/include/unsupported/Eigen/src/MatrixFunctions/MatrixFunction.h b/wpimath/src/main/native/eigeninclude/unsupported/Eigen/src/MatrixFunctions/MatrixFunction.h
similarity index 94%
rename from wpimath/src/main/native/include/unsupported/Eigen/src/MatrixFunctions/MatrixFunction.h
rename to wpimath/src/main/native/eigeninclude/unsupported/Eigen/src/MatrixFunctions/MatrixFunction.h
index 3df8239..cc12ab6 100644
--- a/wpimath/src/main/native/include/unsupported/Eigen/src/MatrixFunctions/MatrixFunction.h
+++ b/wpimath/src/main/native/eigeninclude/unsupported/Eigen/src/MatrixFunctions/MatrixFunction.h
@@ -53,7 +53,7 @@
typename NumTraits<typename MatrixType::Scalar>::Real matrix_function_compute_mu(const MatrixType& A)
{
typedef typename plain_col_type<MatrixType>::type VectorType;
- typename MatrixType::Index rows = A.rows();
+ Index rows = A.rows();
const MatrixType N = MatrixType::Identity(rows, rows) - A;
VectorType e = VectorType::Ones(rows);
N.template triangularView<Upper>().solveInPlace(e);
@@ -65,7 +65,6 @@
{
// TODO: Use that A is upper triangular
typedef typename NumTraits<Scalar>::Real RealScalar;
- typedef typename MatrixType::Index Index;
Index rows = A.rows();
Scalar avgEival = A.trace() / Scalar(RealScalar(rows));
MatrixType Ashifted = A - avgEival * MatrixType::Identity(rows, rows);
@@ -73,10 +72,10 @@
MatrixType F = m_f(avgEival, 0) * MatrixType::Identity(rows, rows);
MatrixType P = Ashifted;
MatrixType Fincr;
- for (Index s = 1; s < 1.1 * rows + 10; s++) { // upper limit is fairly arbitrary
+ for (Index s = 1; double(s) < 1.1 * double(rows) + 10.0; s++) { // upper limit is fairly arbitrary
Fincr = m_f(avgEival, static_cast<int>(s)) * P;
F += Fincr;
- P = Scalar(RealScalar(1.0/(s + 1))) * P * Ashifted;
+ P = Scalar(RealScalar(1)/RealScalar(s + 1)) * P * Ashifted;
// test whether Taylor series converged
const RealScalar F_norm = F.cwiseAbs().rowwise().sum().maxCoeff();
@@ -131,7 +130,6 @@
template <typename EivalsType, typename Cluster>
void matrix_function_partition_eigenvalues(const EivalsType& eivals, std::list<Cluster>& clusters)
{
- typedef typename EivalsType::Index Index;
typedef typename EivalsType::RealScalar RealScalar;
for (Index i=0; i<eivals.rows(); ++i) {
// Find cluster containing i-th ei'val, adding a new cluster if necessary
@@ -179,7 +177,7 @@
{
blockStart.resize(clusterSize.rows());
blockStart(0) = 0;
- for (typename VectorType::Index i = 1; i < clusterSize.rows(); i++) {
+ for (Index i = 1; i < clusterSize.rows(); i++) {
blockStart(i) = blockStart(i-1) + clusterSize(i-1);
}
}
@@ -188,7 +186,6 @@
template <typename EivalsType, typename ListOfClusters, typename VectorType>
void matrix_function_compute_map(const EivalsType& eivals, const ListOfClusters& clusters, VectorType& eivalToCluster)
{
- typedef typename EivalsType::Index Index;
eivalToCluster.resize(eivals.rows());
Index clusterIndex = 0;
for (typename ListOfClusters::const_iterator cluster = clusters.begin(); cluster != clusters.end(); ++cluster) {
@@ -205,7 +202,6 @@
template <typename DynVectorType, typename VectorType>
void matrix_function_compute_permutation(const DynVectorType& blockStart, const DynVectorType& eivalToCluster, VectorType& permutation)
{
- typedef typename VectorType::Index Index;
DynVectorType indexNextEntry = blockStart;
permutation.resize(eivalToCluster.rows());
for (Index i = 0; i < eivalToCluster.rows(); i++) {
@@ -219,7 +215,6 @@
template <typename VectorType, typename MatrixType>
void matrix_function_permute_schur(VectorType& permutation, MatrixType& U, MatrixType& T)
{
- typedef typename VectorType::Index Index;
for (Index i = 0; i < permutation.rows() - 1; i++) {
Index j;
for (j = i; j < permutation.rows(); j++) {
@@ -247,7 +242,7 @@
void matrix_function_compute_block_atomic(const MatrixType& T, AtomicType& atomic, const VectorType& blockStart, const VectorType& clusterSize, MatrixType& fT)
{
fT.setZero(T.rows(), T.cols());
- for (typename VectorType::Index i = 0; i < clusterSize.rows(); ++i) {
+ for (Index i = 0; i < clusterSize.rows(); ++i) {
fT.block(blockStart(i), blockStart(i), clusterSize(i), clusterSize(i))
= atomic.compute(T.block(blockStart(i), blockStart(i), clusterSize(i), clusterSize(i)));
}
@@ -285,7 +280,6 @@
eigen_assert(C.rows() == A.rows());
eigen_assert(C.cols() == B.rows());
- typedef typename MatrixType::Index Index;
typedef typename MatrixType::Scalar Scalar;
Index m = A.rows();
@@ -330,11 +324,8 @@
{
typedef internal::traits<MatrixType> Traits;
typedef typename MatrixType::Scalar Scalar;
- typedef typename MatrixType::Index Index;
- static const int RowsAtCompileTime = Traits::RowsAtCompileTime;
- static const int ColsAtCompileTime = Traits::ColsAtCompileTime;
static const int Options = MatrixType::Options;
- typedef Matrix<Scalar, Dynamic, Dynamic, Options, RowsAtCompileTime, ColsAtCompileTime> DynMatrixType;
+ typedef Matrix<Scalar, Dynamic, Dynamic, Options, Traits::RowsAtCompileTime, Traits::ColsAtCompileTime> DynMatrixType;
for (Index k = 1; k < clusterSize.rows(); k++) {
for (Index i = 0; i < clusterSize.rows() - k; i++) {
@@ -428,7 +419,8 @@
typedef internal::traits<MatrixType> Traits;
// compute Schur decomposition of A
- const ComplexSchur<MatrixType> schurOfA(A);
+ const ComplexSchur<MatrixType> schurOfA(A);
+ eigen_assert(schurOfA.info()==Success);
MatrixType T = schurOfA.matrixT();
MatrixType U = schurOfA.matrixU();
@@ -480,7 +472,6 @@
{
public:
typedef typename Derived::Scalar Scalar;
- typedef typename Derived::Index Index;
typedef typename internal::stem_function<Scalar>::type StemFunction;
protected:
@@ -505,10 +496,8 @@
typedef typename internal::nested_eval<Derived, 10>::type NestedEvalType;
typedef typename internal::remove_all<NestedEvalType>::type NestedEvalTypeClean;
typedef internal::traits<NestedEvalTypeClean> Traits;
- static const int RowsAtCompileTime = Traits::RowsAtCompileTime;
- static const int ColsAtCompileTime = Traits::ColsAtCompileTime;
typedef std::complex<typename NumTraits<Scalar>::Real> ComplexScalar;
- typedef Matrix<ComplexScalar, Dynamic, Dynamic, 0, RowsAtCompileTime, ColsAtCompileTime> DynMatrixType;
+ typedef Matrix<ComplexScalar, Dynamic, Dynamic, 0, Traits::RowsAtCompileTime, Traits::ColsAtCompileTime> DynMatrixType;
typedef internal::MatrixFunctionAtomic<DynMatrixType> AtomicType;
AtomicType atomic(m_f);
diff --git a/wpimath/src/main/native/include/unsupported/Eigen/src/MatrixFunctions/MatrixLogarithm.h b/wpimath/src/main/native/eigeninclude/unsupported/Eigen/src/MatrixFunctions/MatrixLogarithm.h
similarity index 95%
rename from wpimath/src/main/native/include/unsupported/Eigen/src/MatrixFunctions/MatrixLogarithm.h
rename to wpimath/src/main/native/eigeninclude/unsupported/Eigen/src/MatrixFunctions/MatrixLogarithm.h
index cf5fffa..e917013 100644
--- a/wpimath/src/main/native/include/unsupported/Eigen/src/MatrixFunctions/MatrixLogarithm.h
+++ b/wpimath/src/main/native/eigeninclude/unsupported/Eigen/src/MatrixFunctions/MatrixLogarithm.h
@@ -62,8 +62,8 @@
else
{
// computation in previous branch is inaccurate if A(1,1) \approx A(0,0)
- int unwindingNumber = static_cast<int>(ceil((imag(logA11 - logA00) - RealScalar(EIGEN_PI)) / RealScalar(2*EIGEN_PI)));
- result(0,1) = A(0,1) * (numext::log1p(y/A(0,0)) + Scalar(0,2*EIGEN_PI*unwindingNumber)) / y;
+ RealScalar unwindingNumber = ceil((imag(logA11 - logA00) - RealScalar(EIGEN_PI)) / RealScalar(2*EIGEN_PI));
+ result(0,1) = A(0,1) * (numext::log1p(y/A(0,0)) + Scalar(0,RealScalar(2*EIGEN_PI)*unwindingNumber)) / y;
}
}
@@ -135,7 +135,8 @@
const int minPadeDegree = 3;
const int maxPadeDegree = 11;
assert(degree >= minPadeDegree && degree <= maxPadeDegree);
-
+ // FIXME this creates float-conversion-warnings if these are enabled.
+ // Either manually convert each value, or disable the warning locally
const RealScalar nodes[][maxPadeDegree] = {
{ 0.1127016653792583114820734600217600L, 0.5000000000000000000000000000000000L, // degree 3
0.8872983346207416885179265399782400L },
@@ -232,12 +233,13 @@
int degree;
MatrixType T = A, sqrtT;
- int maxPadeDegree = matrix_log_max_pade_degree<Scalar>::value;
- const RealScalar maxNormForPade = maxPadeDegree<= 5? 5.3149729967117310e-1L: // single precision
+ const int maxPadeDegree = matrix_log_max_pade_degree<Scalar>::value;
+ const RealScalar maxNormForPade = RealScalar(
+ maxPadeDegree<= 5? 5.3149729967117310e-1L: // single precision
maxPadeDegree<= 7? 2.6429608311114350e-1L: // double precision
maxPadeDegree<= 8? 2.32777776523703892094e-1L: // extended precision
maxPadeDegree<=10? 1.05026503471351080481093652651105e-1L: // double-double
- 1.1880960220216759245467951592883642e-1L; // quadruple precision
+ 1.1880960220216759245467951592883642e-1L); // quadruple precision
while (true) {
RealScalar normTminusI = (T - MatrixType::Identity(T.rows(), T.rows())).cwiseAbs().colwise().sum().maxCoeff();
@@ -254,7 +256,7 @@
}
matrix_log_compute_pade(result, T, degree);
- result *= pow(RealScalar(2), numberOfSquareRoots);
+ result *= pow(RealScalar(2), RealScalar(numberOfSquareRoots)); // TODO replace by bitshift if possible
}
/** \ingroup MatrixFunctions_Module
@@ -332,10 +334,8 @@
typedef typename internal::nested_eval<Derived, 10>::type DerivedEvalType;
typedef typename internal::remove_all<DerivedEvalType>::type DerivedEvalTypeClean;
typedef internal::traits<DerivedEvalTypeClean> Traits;
- static const int RowsAtCompileTime = Traits::RowsAtCompileTime;
- static const int ColsAtCompileTime = Traits::ColsAtCompileTime;
typedef std::complex<typename NumTraits<Scalar>::Real> ComplexScalar;
- typedef Matrix<ComplexScalar, Dynamic, Dynamic, 0, RowsAtCompileTime, ColsAtCompileTime> DynMatrixType;
+ typedef Matrix<ComplexScalar, Dynamic, Dynamic, 0, Traits::RowsAtCompileTime, Traits::ColsAtCompileTime> DynMatrixType;
typedef internal::MatrixLogarithmAtomic<DynMatrixType> AtomicType;
AtomicType atomic;
diff --git a/wpimath/src/main/native/include/unsupported/Eigen/src/MatrixFunctions/MatrixPower.h b/wpimath/src/main/native/eigeninclude/unsupported/Eigen/src/MatrixFunctions/MatrixPower.h
similarity index 96%
rename from wpimath/src/main/native/include/unsupported/Eigen/src/MatrixFunctions/MatrixPower.h
rename to wpimath/src/main/native/eigeninclude/unsupported/Eigen/src/MatrixFunctions/MatrixPower.h
index a3273da..d7672d7 100644
--- a/wpimath/src/main/native/include/unsupported/Eigen/src/MatrixFunctions/MatrixPower.h
+++ b/wpimath/src/main/native/eigeninclude/unsupported/Eigen/src/MatrixFunctions/MatrixPower.h
@@ -40,7 +40,6 @@
{
public:
typedef typename MatrixType::RealScalar RealScalar;
- typedef typename MatrixType::Index Index;
/**
* \brief Constructor.
@@ -81,7 +80,7 @@
*
* \note Currently this class is only used by MatrixPower. One may
* insist that this be nested into MatrixPower. This class is here to
- * faciliate future development of triangular matrix functions.
+ * facilitate future development of triangular matrix functions.
*/
template<typename MatrixType>
class MatrixPowerAtomic : internal::noncopyable
@@ -94,7 +93,6 @@
typedef typename MatrixType::Scalar Scalar;
typedef typename MatrixType::RealScalar RealScalar;
typedef std::complex<RealScalar> ComplexScalar;
- typedef typename MatrixType::Index Index;
typedef Block<MatrixType,Dynamic,Dynamic> ResultType;
const MatrixType& m_A;
@@ -162,11 +160,11 @@
void MatrixPowerAtomic<MatrixType>::computePade(int degree, const MatrixType& IminusT, ResultType& res) const
{
int i = 2*degree;
- res = (m_p-degree) / (2*i-2) * IminusT;
+ res = (m_p-RealScalar(degree)) / RealScalar(2*i-2) * IminusT;
for (--i; i; --i) {
res = (MatrixType::Identity(IminusT.rows(), IminusT.cols()) + res).template triangularView<Upper>()
- .solve((i==1 ? -m_p : i&1 ? (-m_p-i/2)/(2*i) : (m_p-i/2)/(2*i-2)) * IminusT).eval();
+ .solve((i==1 ? -m_p : i&1 ? (-m_p-RealScalar(i/2))/RealScalar(2*i) : (m_p-RealScalar(i/2))/RealScalar(2*i-2)) * IminusT).eval();
}
res += MatrixType::Identity(IminusT.rows(), IminusT.cols());
}
@@ -196,11 +194,12 @@
{
using std::ldexp;
const int digits = std::numeric_limits<RealScalar>::digits;
- const RealScalar maxNormForPade = digits <= 24? 4.3386528e-1L // single precision
+ const RealScalar maxNormForPade = RealScalar(
+ digits <= 24? 4.3386528e-1L // single precision
: digits <= 53? 2.789358995219730e-1L // double precision
: digits <= 64? 2.4471944416607995472e-1L // extended precision
: digits <= 106? 1.1016843812851143391275867258512e-1L // double-double
- : 9.134603732914548552537150753385375e-2L; // quadruple precision
+ : 9.134603732914548552537150753385375e-2L); // quadruple precision
MatrixType IminusT, sqrtT, T = m_A.template triangularView<Upper>();
RealScalar normIminusT;
int degree, degree2, numberOfSquareRoots = 0;
@@ -298,8 +297,8 @@
ComplexScalar logCurr = log(curr);
ComplexScalar logPrev = log(prev);
- int unwindingNumber = ceil((numext::imag(logCurr - logPrev) - RealScalar(EIGEN_PI)) / RealScalar(2*EIGEN_PI));
- ComplexScalar w = numext::log1p((curr-prev)/prev)/RealScalar(2) + ComplexScalar(0, EIGEN_PI*unwindingNumber);
+ RealScalar unwindingNumber = ceil((numext::imag(logCurr - logPrev) - RealScalar(EIGEN_PI)) / RealScalar(2*EIGEN_PI));
+ ComplexScalar w = numext::log1p((curr-prev)/prev)/RealScalar(2) + ComplexScalar(0, RealScalar(EIGEN_PI)*unwindingNumber);
return RealScalar(2) * exp(RealScalar(0.5) * p * (logCurr + logPrev)) * sinh(p * w) / (curr - prev);
}
@@ -340,7 +339,6 @@
private:
typedef typename MatrixType::Scalar Scalar;
typedef typename MatrixType::RealScalar RealScalar;
- typedef typename MatrixType::Index Index;
public:
/**
@@ -600,7 +598,6 @@
public:
typedef typename Derived::PlainObject PlainObject;
typedef typename Derived::RealScalar RealScalar;
- typedef typename Derived::Index Index;
/**
* \brief Constructor.
@@ -648,7 +645,6 @@
public:
typedef typename Derived::PlainObject PlainObject;
typedef typename std::complex<typename Derived::RealScalar> ComplexScalar;
- typedef typename Derived::Index Index;
/**
* \brief Constructor.
diff --git a/wpimath/src/main/native/include/unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h b/wpimath/src/main/native/eigeninclude/unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h
similarity index 93%
rename from wpimath/src/main/native/include/unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h
rename to wpimath/src/main/native/eigeninclude/unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h
index 2e5abda..e363e77 100644
--- a/wpimath/src/main/native/include/unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h
+++ b/wpimath/src/main/native/eigeninclude/unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h
@@ -17,7 +17,7 @@
// pre: T.block(i,i,2,2) has complex conjugate eigenvalues
// post: sqrtT.block(i,i,2,2) is square root of T.block(i,i,2,2)
template <typename MatrixType, typename ResultType>
-void matrix_sqrt_quasi_triangular_2x2_diagonal_block(const MatrixType& T, typename MatrixType::Index i, ResultType& sqrtT)
+void matrix_sqrt_quasi_triangular_2x2_diagonal_block(const MatrixType& T, Index i, ResultType& sqrtT)
{
// TODO: This case (2-by-2 blocks with complex conjugate eigenvalues) is probably hidden somewhere
// in EigenSolver. If we expose it, we could call it directly from here.
@@ -32,7 +32,7 @@
// all blocks of sqrtT to left of and below (i,j) are correct
// post: sqrtT(i,j) has the correct value
template <typename MatrixType, typename ResultType>
-void matrix_sqrt_quasi_triangular_1x1_off_diagonal_block(const MatrixType& T, typename MatrixType::Index i, typename MatrixType::Index j, ResultType& sqrtT)
+void matrix_sqrt_quasi_triangular_1x1_off_diagonal_block(const MatrixType& T, Index i, Index j, ResultType& sqrtT)
{
typedef typename traits<MatrixType>::Scalar Scalar;
Scalar tmp = (sqrtT.row(i).segment(i+1,j-i-1) * sqrtT.col(j).segment(i+1,j-i-1)).value();
@@ -41,7 +41,7 @@
// similar to compute1x1offDiagonalBlock()
template <typename MatrixType, typename ResultType>
-void matrix_sqrt_quasi_triangular_1x2_off_diagonal_block(const MatrixType& T, typename MatrixType::Index i, typename MatrixType::Index j, ResultType& sqrtT)
+void matrix_sqrt_quasi_triangular_1x2_off_diagonal_block(const MatrixType& T, Index i, Index j, ResultType& sqrtT)
{
typedef typename traits<MatrixType>::Scalar Scalar;
Matrix<Scalar,1,2> rhs = T.template block<1,2>(i,j);
@@ -54,7 +54,7 @@
// similar to compute1x1offDiagonalBlock()
template <typename MatrixType, typename ResultType>
-void matrix_sqrt_quasi_triangular_2x1_off_diagonal_block(const MatrixType& T, typename MatrixType::Index i, typename MatrixType::Index j, ResultType& sqrtT)
+void matrix_sqrt_quasi_triangular_2x1_off_diagonal_block(const MatrixType& T, Index i, Index j, ResultType& sqrtT)
{
typedef typename traits<MatrixType>::Scalar Scalar;
Matrix<Scalar,2,1> rhs = T.template block<2,1>(i,j);
@@ -101,7 +101,7 @@
// similar to compute1x1offDiagonalBlock()
template <typename MatrixType, typename ResultType>
-void matrix_sqrt_quasi_triangular_2x2_off_diagonal_block(const MatrixType& T, typename MatrixType::Index i, typename MatrixType::Index j, ResultType& sqrtT)
+void matrix_sqrt_quasi_triangular_2x2_off_diagonal_block(const MatrixType& T, Index i, Index j, ResultType& sqrtT)
{
typedef typename traits<MatrixType>::Scalar Scalar;
Matrix<Scalar,2,2> A = sqrtT.template block<2,2>(i,i);
@@ -204,7 +204,7 @@
void matrix_sqrt_triangular(const MatrixType &arg, ResultType &result)
{
using std::sqrt;
- typedef typename MatrixType::Scalar Scalar;
+ typedef typename MatrixType::Scalar Scalar;
eigen_assert(arg.rows() == arg.cols());
@@ -253,18 +253,19 @@
template <typename MatrixType>
struct matrix_sqrt_compute<MatrixType, 0>
{
+ typedef typename MatrixType::PlainObject PlainType;
template <typename ResultType>
static void run(const MatrixType &arg, ResultType &result)
{
eigen_assert(arg.rows() == arg.cols());
// Compute Schur decomposition of arg
- const RealSchur<MatrixType> schurOfA(arg);
- const MatrixType& T = schurOfA.matrixT();
- const MatrixType& U = schurOfA.matrixU();
+ const RealSchur<PlainType> schurOfA(arg);
+ const PlainType& T = schurOfA.matrixT();
+ const PlainType& U = schurOfA.matrixU();
// Compute square root of T
- MatrixType sqrtT = MatrixType::Zero(arg.rows(), arg.cols());
+ PlainType sqrtT = PlainType::Zero(arg.rows(), arg.cols());
matrix_sqrt_quasi_triangular(T, sqrtT);
// Compute square root of arg
@@ -278,18 +279,19 @@
template <typename MatrixType>
struct matrix_sqrt_compute<MatrixType, 1>
{
+ typedef typename MatrixType::PlainObject PlainType;
template <typename ResultType>
static void run(const MatrixType &arg, ResultType &result)
{
eigen_assert(arg.rows() == arg.cols());
// Compute Schur decomposition of arg
- const ComplexSchur<MatrixType> schurOfA(arg);
- const MatrixType& T = schurOfA.matrixT();
- const MatrixType& U = schurOfA.matrixU();
+ const ComplexSchur<PlainType> schurOfA(arg);
+ const PlainType& T = schurOfA.matrixT();
+ const PlainType& U = schurOfA.matrixU();
// Compute square root of T
- MatrixType sqrtT;
+ PlainType sqrtT;
matrix_sqrt_triangular(T, sqrtT);
// Compute square root of arg
diff --git a/wpimath/src/main/native/include/unsupported/Eigen/src/MatrixFunctions/StemFunction.h b/wpimath/src/main/native/eigeninclude/unsupported/Eigen/src/MatrixFunctions/StemFunction.h
similarity index 100%
rename from wpimath/src/main/native/include/unsupported/Eigen/src/MatrixFunctions/StemFunction.h
rename to wpimath/src/main/native/eigeninclude/unsupported/Eigen/src/MatrixFunctions/StemFunction.h
diff --git a/wpimath/src/main/native/include/drake/common/drake_assert.h b/wpimath/src/main/native/include/drake/common/drake_assert.h
index 21e7bd1..88587fa 100644
--- a/wpimath/src/main/native/include/drake/common/drake_assert.h
+++ b/wpimath/src/main/native/include/drake/common/drake_assert.h
@@ -83,12 +83,19 @@
namespace drake {
namespace internal {
// Abort the program with an error message.
-[[noreturn]]
-void Abort(const char* condition, const char* func, const char* file, int line);
+[[noreturn]] void Abort(const char* condition, const char* func,
+ const char* file, int line);
// Report an assertion failure; will either Abort(...) or throw.
-[[noreturn]]
-void AssertionFailed(
- const char* condition, const char* func, const char* file, int line);
+[[noreturn]] void AssertionFailed(const char* condition, const char* func,
+ const char* file, int line);
+template <bool>
+constexpr void DrakeAssertWasUsedWithRawPointer() {}
+template<>
+[[deprecated("\nDRAKE DEPRECATED: When using DRAKE_ASSERT or DRAKE_DEMAND on"
+" a raw pointer, always write out DRAKE_ASSERT(foo != nullptr), do not write"
+" DRAKE_ASSERT(foo) and rely on implicit pointer-to-bool conversion."
+"\nThe deprecated code will be removed from Drake on or after 2021-12-01.")]]
+constexpr void DrakeAssertWasUsedWithRawPointer<true>() {}
} // namespace internal
namespace assert {
// Allows for specialization of how to bool-convert Conditions used in
@@ -98,7 +105,7 @@
// require special handling.
template <typename Condition>
struct ConditionTraits {
- static constexpr bool is_valid = std::is_convertible<Condition, bool>::value;
+ static constexpr bool is_valid = std::is_convertible_v<Condition, bool>;
static bool Evaluate(const Condition& value) {
return value;
}
@@ -113,8 +120,10 @@
#define DRAKE_DEMAND(condition) \
do { \
typedef ::drake::assert::ConditionTraits< \
- typename std::remove_cv<decltype(condition)>::type> Trait; \
+ typename std::remove_cv_t<decltype(condition)>> Trait; \
static_assert(Trait::is_valid, "Condition should be bool-convertible."); \
+ ::drake::internal::DrakeAssertWasUsedWithRawPointer< \
+ std::is_pointer_v<decltype(condition)>>(); \
if (!Trait::Evaluate(condition)) { \
::drake::internal::AssertionFailed( \
#condition, __func__, __FILE__, __LINE__); \
@@ -130,7 +139,7 @@
# define DRAKE_ASSERT(condition) DRAKE_DEMAND(condition)
# define DRAKE_ASSERT_VOID(expression) do { \
static_assert( \
- std::is_convertible<decltype(expression), void>::value, \
+ std::is_convertible_v<decltype(expression), void>, \
"Expression should be void."); \
expression; \
} while (0)
@@ -140,12 +149,16 @@
constexpr bool kDrakeAssertIsArmed = false;
constexpr bool kDrakeAssertIsDisarmed = true;
} // namespace drake
-# define DRAKE_ASSERT(condition) static_assert( \
- ::drake::assert::ConditionTraits< \
- typename std::remove_cv<decltype(condition)>::type>::is_valid, \
- "Condition should be bool-convertible.");
+# define DRAKE_ASSERT(condition) do { \
+ static_assert( \
+ ::drake::assert::ConditionTraits< \
+ typename std::remove_cv_t<decltype(condition)>>::is_valid, \
+ "Condition should be bool-convertible."); \
+ ::drake::internal::DrakeAssertWasUsedWithRawPointer< \
+ std::is_pointer_v<decltype(condition)>>(); \
+ } while (0)
# define DRAKE_ASSERT_VOID(expression) static_assert( \
- std::is_convertible<decltype(expression), void>::value, \
+ std::is_convertible_v<decltype(expression), void>, \
"Expression should be void.")
#endif
diff --git a/wpimath/src/main/native/include/drake/common/drake_assertion_error.h b/wpimath/src/main/native/include/drake/common/drake_assertion_error.h
index 541b118..b428474 100644
--- a/wpimath/src/main/native/include/drake/common/drake_assertion_error.h
+++ b/wpimath/src/main/native/include/drake/common/drake_assertion_error.h
@@ -6,8 +6,8 @@
namespace drake {
namespace internal {
-/// This is what DRAKE_ASSERT and DRAKE_DEMAND throw when our assertions are
-/// configured to throw.
+// This is what DRAKE_ASSERT and DRAKE_DEMAND throw when our assertions are
+// configured to throw.
class assertion_error : public std::runtime_error {
public:
explicit assertion_error(const std::string& what_arg)
diff --git a/wpimath/src/main/native/include/drake/common/drake_copyable.h b/wpimath/src/main/native/include/drake/common/drake_copyable.h
index 086f1f7..a96a6fb 100644
--- a/wpimath/src/main/native/include/drake/common/drake_copyable.h
+++ b/wpimath/src/main/native/include/drake/common/drake_copyable.h
@@ -19,8 +19,8 @@
/** DRAKE_NO_COPY_NO_MOVE_NO_ASSIGN deletes the special member functions for
copy-construction, copy-assignment, move-construction, and move-assignment.
Drake's Doxygen is customized to render the deletions in detail, with
-appropriate comments. Invoke this this macro in the public section of the
-class declaration, e.g.:
+appropriate comments. Invoke this macro in the public section of the class
+declaration, e.g.:
<pre>
class Foo {
public:
@@ -43,8 +43,8 @@
functions could conceivably still be ill-formed, in which case they will
effectively not be declared or used -- but because the copy constructor exists
the type will still be MoveConstructible. Drake's Doxygen is customized to
-render the functions in detail, with appropriate comments. Invoke this this
-macro in the public section of the class declaration, e.g.:
+render the functions in detail, with appropriate comments. Typically, you
+should invoke this macro in the public section of the class declaration, e.g.:
<pre>
class Foo {
public:
@@ -53,60 +53,38 @@
// ...
};
</pre>
+
+However, if Foo has a virtual destructor (i.e., is subclassable), then
+typically you should use either DRAKE_NO_COPY_NO_MOVE_NO_ASSIGN in the
+public section or else DRAKE_DEFAULT_COPY_AND_MOVE_AND_ASSIGN in the
+protected section, to prevent
+<a href="https://en.wikipedia.org/wiki/Object_slicing">object slicing</a>.
+
+The macro contains a built-in self-check that copy-assignment is well-formed.
+This self-check proves that the Classname is CopyConstructible, CopyAssignable,
+MoveConstructible, and MoveAssignable (in all but the most arcane cases where a
+member of the Classname is somehow CopyAssignable but not CopyConstructible).
+Therefore, classes that use this macro typically will not need to have any unit
+tests that check for the presence or correctness of these functions.
+
+However, the self-check does not provide any checks of the runtime efficiency
+of the functions. If it is important for performance that the move functions
+actually move (instead of making a copy), then you should consider capturing
+that in a unit test.
*/
#define DRAKE_DEFAULT_COPY_AND_MOVE_AND_ASSIGN(Classname) \
Classname(const Classname&) = default; \
Classname& operator=(const Classname&) = default; \
Classname(Classname&&) = default; \
Classname& operator=(Classname&&) = default; \
- /* Fails at compile-time if default-copy doesn't work. */ \
- static void DRAKE_COPYABLE_DEMAND_COPY_CAN_COMPILE() { \
- (void) static_cast<Classname& (Classname::*)( \
- const Classname&)>(&Classname::operator=); \
- }
-
-/** DRAKE_DECLARE_COPY_AND_MOVE_AND_ASSIGN declares (but does not define) the
-special member functions for copy-construction, copy-assignment,
-move-construction, and move-assignment. Drake's Doxygen is customized to
-render the declarations with appropriate comments.
-
-This is useful when paired with DRAKE_DEFINE_DEFAULT_COPY_AND_MOVE_AND_ASSIGN_T
-to work around https://gcc.gnu.org/bugzilla/show_bug.cgi?id=57728 whereby the
-declaration and definition must be split. Once Drake no longer supports GCC
-versions prior to 6.3, this macro could be removed.
-
-Invoke this this macro in the public section of the class declaration, e.g.:
-<pre>
-template <typename T>
-class Foo {
- public:
- DRAKE_DECLARE_COPY_AND_MOVE_AND_ASSIGN(Foo)
-
- // ...
-};
-DRAKE_DEFINE_DEFAULT_COPY_AND_MOVE_AND_ASSIGN_T(Foo)
-</pre>
-*/
-#define DRAKE_DECLARE_COPY_AND_MOVE_AND_ASSIGN(Classname) \
- Classname(const Classname&); \
- Classname& operator=(const Classname&); \
- Classname(Classname&&); \
- Classname& operator=(Classname&&); \
- /* Fails at compile-time if default-copy doesn't work. */ \
- static void DRAKE_COPYABLE_DEMAND_COPY_CAN_COMPILE() { \
- (void) static_cast<Classname& (Classname::*)( \
- const Classname&)>(&Classname::operator=); \
- }
-
-/** Helper for DRAKE_DECLARE_COPY_AND_MOVE_AND_ASSIGN. Provides defaulted
-definitions for the four special member functions of a templated class.
-*/
-#define DRAKE_DEFINE_DEFAULT_COPY_AND_MOVE_AND_ASSIGN_T(Classname) \
- template <typename T> \
- Classname<T>::Classname(const Classname<T>&) = default; \
- template <typename T> \
- Classname<T>& Classname<T>::operator=(const Classname<T>&) = default; \
- template <typename T> \
- Classname<T>::Classname(Classname<T>&&) = default; \
- template <typename T> \
- Classname<T>& Classname<T>::operator=(Classname<T>&&) = default;
+ /* Fails at compile-time if copy-assign doesn't compile. */ \
+ /* Note that we do not test the copy-ctor here, because */ \
+ /* it will not exist when Classname is abstract. */ \
+ static void DrakeDefaultCopyAndMoveAndAssign_DoAssign( \
+ Classname* a, const Classname& b) { *a = b; } \
+ static_assert( \
+ &DrakeDefaultCopyAndMoveAndAssign_DoAssign == \
+ &DrakeDefaultCopyAndMoveAndAssign_DoAssign, \
+ "This assertion is never false; its only purpose is to " \
+ "generate 'use of deleted function: operator=' errors " \
+ "when Classname is a template.");
diff --git a/wpimath/src/main/native/include/drake/common/drake_throw.h b/wpimath/src/main/native/include/drake/common/drake_throw.h
index ff42bb7..bb4bae8 100644
--- a/wpimath/src/main/native/include/drake/common/drake_throw.h
+++ b/wpimath/src/main/native/include/drake/common/drake_throw.h
@@ -12,19 +12,42 @@
namespace drake {
namespace internal {
// Throw an error message.
-[[noreturn]]
-void Throw(const char* condition, const char* func, const char* file, int line);
+[[noreturn]] void Throw(const char* condition, const char* func,
+ const char* file, int line);
+
+template <bool>
+constexpr void DrakeThrowUnlessWasUsedWithRawPointer() {}
+template<>
+[[deprecated("\nDRAKE DEPRECATED: When using DRAKE_THROW_UNLESS on a raw"
+" pointer, always write out DRAKE_THROW_UNLESS(foo != nullptr), do not write"
+" DRAKE_THROW_UNLESS(foo) and rely on implicit pointer-to-bool conversion."
+"\nThe deprecated code will be removed from Drake on or after 2021-12-01.")]]
+constexpr void DrakeThrowUnlessWasUsedWithRawPointer<true>() {}
+
} // namespace internal
} // namespace drake
/// Evaluates @p condition and iff the value is false will throw an exception
/// with a message showing at least the condition text, function name, file,
/// and line.
+///
+/// The condition must not be a pointer, where we'd implicitly rely on its
+/// nullness. Instead, always write out "!= nullptr" to be precise.
+///
+/// Correct: `DRAKE_THROW_UNLESS(foo != nullptr);`
+/// Incorrect: `DRAKE_THROW_UNLESS(foo);`
+///
+/// Because this macro is intended to provide a useful exception message to
+/// users, we should err on the side of extra detail about the failure. The
+/// meaning of "foo" isolated within error message text does not make it
+/// clear that a null pointer is the proximate cause of the problem.
#define DRAKE_THROW_UNLESS(condition) \
do { \
typedef ::drake::assert::ConditionTraits< \
- typename std::remove_cv<decltype(condition)>::type> Trait; \
+ typename std::remove_cv_t<decltype(condition)>> Trait; \
static_assert(Trait::is_valid, "Condition should be bool-convertible."); \
+ ::drake::internal::DrakeThrowUnlessWasUsedWithRawPointer< \
+ std::is_pointer_v<decltype(condition)>>(); \
if (!Trait::Evaluate(condition)) { \
::drake::internal::Throw(#condition, __func__, __FILE__, __LINE__); \
} \
diff --git a/wpimath/src/main/native/include/drake/common/never_destroyed.h b/wpimath/src/main/native/include/drake/common/never_destroyed.h
index 2033fd0..024b355 100644
--- a/wpimath/src/main/native/include/drake/common/never_destroyed.h
+++ b/wpimath/src/main/native/include/drake/common/never_destroyed.h
@@ -56,6 +56,25 @@
/// return string_to_enum.access().at(foo_string);
/// }
/// @endcode
+///
+/// In cases where computing the static data is more complicated than an
+/// initializer_list, you can use a temporary lambda to populate the value:
+/// @code
+/// const std::vector<double>& GetConstantMagicNumbers() {
+/// static const drake::never_destroyed<std::vector<double>> result{[]() {
+/// std::vector<double> prototype;
+/// std::mt19937 random_generator;
+/// for (int i = 0; i < 10; ++i) {
+/// double new_value = random_generator();
+/// prototype.push_back(new_value);
+/// }
+/// return prototype;
+/// }()};
+/// return result.access();
+/// }
+/// @endcode
+///
+/// Note in particular the `()` after the lambda. That causes it to be invoked.
//
// The above examples are repeated in the unit test; keep them in sync.
template <typename T>
diff --git a/wpimath/src/main/native/include/drake/math/discrete_algebraic_riccati_equation.h b/wpimath/src/main/native/include/drake/math/discrete_algebraic_riccati_equation.h
index e45cdc8..55b8442 100644
--- a/wpimath/src/main/native/include/drake/math/discrete_algebraic_riccati_equation.h
+++ b/wpimath/src/main/native/include/drake/math/discrete_algebraic_riccati_equation.h
@@ -4,29 +4,82 @@
#include <cstdlib>
#include <Eigen/Core>
+#include <wpi/SymbolExports.h>
namespace drake {
namespace math {
-/// Computes the unique stabilizing solution X to the discrete-time algebraic
-/// Riccati equation:
-///
-/// \f[
-/// A'XA - X - A'XB(B'XB+R)^{-1}B'XA + Q = 0
-/// \f]
-///
-/// @throws std::runtime_error if Q is not positive semi-definite.
-/// @throws std::runtime_error if R is not positive definite.
-///
-/// Based on the Schur Vector approach outlined in this paper:
-/// "On the Numerical Solution of the Discrete-Time Algebraic Riccati Equation"
-/// by Thrasyvoulos Pappas, Alan J. Laub, and Nils R. Sandell
-///
+/**
+Computes the unique stabilizing solution X to the discrete-time algebraic
+Riccati equation:
+
+AᵀXA − X − AᵀXB(BᵀXB + R)⁻¹BᵀXA + Q = 0
+
+@throws std::exception if Q is not positive semi-definite.
+@throws std::exception if R is not positive definite.
+
+Based on the Schur Vector approach outlined in this paper:
+"On the Numerical Solution of the Discrete-Time Algebraic Riccati Equation"
+by Thrasyvoulos Pappas, Alan J. Laub, and Nils R. Sandell
+*/
+WPILIB_DLLEXPORT
Eigen::MatrixXd DiscreteAlgebraicRiccatiEquation(
const Eigen::Ref<const Eigen::MatrixXd>& A,
const Eigen::Ref<const Eigen::MatrixXd>& B,
const Eigen::Ref<const Eigen::MatrixXd>& Q,
const Eigen::Ref<const Eigen::MatrixXd>& R);
+
+/**
+Computes the unique stabilizing solution X to the discrete-time algebraic
+Riccati equation:
+
+AᵀXA − X − (AᵀXB + N)(BᵀXB + R)⁻¹(BᵀXA + Nᵀ) + Q = 0
+
+This is equivalent to solving the original DARE:
+
+A₂ᵀXA₂ − X − A₂ᵀXB(BᵀXB + R)⁻¹BᵀXA₂ + Q₂ = 0
+
+where A₂ and Q₂ are a change of variables:
+
+A₂ = A − BR⁻¹Nᵀ and Q₂ = Q − NR⁻¹Nᵀ
+
+This overload of the DARE is useful for finding the control law uₖ that
+minimizes the following cost function subject to xₖ₊₁ = Axₖ + Buₖ.
+
+@verbatim
+ ∞ [xₖ]ᵀ[Q N][xₖ]
+J = Σ [uₖ] [Nᵀ R][uₖ] ΔT
+ k=0
+@endverbatim
+
+This is a more general form of the following. The linear-quadratic regulator
+is the feedback control law uₖ that minimizes the following cost function
+subject to xₖ₊₁ = Axₖ + Buₖ:
+
+@verbatim
+ ∞
+J = Σ (xₖᵀQxₖ + uₖᵀRuₖ) ΔT
+ k=0
+@endverbatim
+
+This can be refactored as:
+
+@verbatim
+ ∞ [xₖ]ᵀ[Q 0][xₖ]
+J = Σ [uₖ] [0 R][uₖ] ΔT
+ k=0
+@endverbatim
+
+@throws std::runtime_error if Q − NR⁻¹Nᵀ is not positive semi-definite.
+@throws std::runtime_error if R is not positive definite.
+*/
+WPILIB_DLLEXPORT
+Eigen::MatrixXd DiscreteAlgebraicRiccatiEquation(
+ const Eigen::Ref<const Eigen::MatrixXd>& A,
+ const Eigen::Ref<const Eigen::MatrixXd>& B,
+ const Eigen::Ref<const Eigen::MatrixXd>& Q,
+ const Eigen::Ref<const Eigen::MatrixXd>& R,
+ const Eigen::Ref<const Eigen::MatrixXd>& N);
} // namespace math
} // namespace drake
diff --git a/wpimath/src/main/native/include/frc/LinearFilter.h b/wpimath/src/main/native/include/frc/LinearFilter.h
deleted file mode 100644
index 1fe0edc..0000000
--- a/wpimath/src/main/native/include/frc/LinearFilter.h
+++ /dev/null
@@ -1,196 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2015-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-#pragma once
-
-#include <cassert>
-#include <cmath>
-#include <initializer_list>
-#include <vector>
-
-#include <wpi/ArrayRef.h>
-#include <wpi/circular_buffer.h>
-
-#include "units/time.h"
-#include "wpimath/MathShared.h"
-
-namespace frc {
-
-/**
- * This class implements a linear, digital filter. All types of FIR and IIR
- * filters are supported. Static factory methods are provided to create commonly
- * used types of filters.
- *
- * Filters are of the form:<br>
- * y[n] = (b0 * x[n] + b1 * x[n-1] + … + bP * x[n-P]) -
- * (a0 * y[n-1] + a2 * y[n-2] + … + aQ * y[n-Q])
- *
- * Where:<br>
- * y[n] is the output at time "n"<br>
- * x[n] is the input at time "n"<br>
- * y[n-1] is the output from the LAST time step ("n-1")<br>
- * x[n-1] is the input from the LAST time step ("n-1")<br>
- * b0 … bP are the "feedforward" (FIR) gains<br>
- * a0 … aQ are the "feedback" (IIR) gains<br>
- * IMPORTANT! Note the "-" sign in front of the feedback term! This is a common
- * convention in signal processing.
- *
- * What can linear filters do? Basically, they can filter, or diminish, the
- * effects of undesirable input frequencies. High frequencies, or rapid changes,
- * can be indicative of sensor noise or be otherwise undesirable. A "low pass"
- * filter smooths out the signal, reducing the impact of these high frequency
- * components. Likewise, a "high pass" filter gets rid of slow-moving signal
- * components, letting you detect large changes more easily.
- *
- * Example FRC applications of filters:
- * - Getting rid of noise from an analog sensor input (note: the roboRIO's FPGA
- * can do this faster in hardware)
- * - Smoothing out joystick input to prevent the wheels from slipping or the
- * robot from tipping
- * - Smoothing motor commands so that unnecessary strain isn't put on
- * electrical or mechanical components
- * - If you use clever gains, you can make a PID controller out of this class!
- *
- * For more on filters, we highly recommend the following articles:<br>
- * https://en.wikipedia.org/wiki/Linear_filter<br>
- * https://en.wikipedia.org/wiki/Iir_filter<br>
- * https://en.wikipedia.org/wiki/Fir_filter<br>
- *
- * Note 1: Calculate() should be called by the user on a known, regular period.
- * You can use a Notifier for this or do it "inline" with code in a
- * periodic function.
- *
- * Note 2: For ALL filters, gains are necessarily a function of frequency. If
- * you make a filter that works well for you at, say, 100Hz, you will most
- * definitely need to adjust the gains if you then want to run it at 200Hz!
- * Combining this with Note 1 - the impetus is on YOU as a developer to make
- * sure Calculate() gets called at the desired, constant frequency!
- */
-template <class T>
-class LinearFilter {
- public:
- /**
- * Create a linear FIR or IIR filter.
- *
- * @param ffGains The "feed forward" or FIR gains.
- * @param fbGains The "feed back" or IIR gains.
- */
- LinearFilter(wpi::ArrayRef<double> ffGains, wpi::ArrayRef<double> fbGains)
- : m_inputs(ffGains.size()),
- m_outputs(fbGains.size()),
- m_inputGains(ffGains),
- m_outputGains(fbGains) {
- static int instances = 0;
- instances++;
- wpi::math::MathSharedStore::ReportUsage(
- wpi::math::MathUsageId::kFilter_Linear, 1);
- }
-
- /**
- * Create a linear FIR or IIR filter.
- *
- * @param ffGains The "feed forward" or FIR gains.
- * @param fbGains The "feed back" or IIR gains.
- */
- LinearFilter(std::initializer_list<double> ffGains,
- std::initializer_list<double> fbGains)
- : LinearFilter(wpi::makeArrayRef(ffGains.begin(), ffGains.end()),
- wpi::makeArrayRef(fbGains.begin(), fbGains.end())) {}
-
- // Static methods to create commonly used filters
- /**
- * Creates a one-pole IIR low-pass filter of the form:<br>
- * y[n] = (1 - gain) * x[n] + gain * y[n-1]<br>
- * where gain = e<sup>-dt / T</sup>, T is the time constant in seconds
- *
- * This filter is stable for time constants greater than zero.
- *
- * @param timeConstant The discrete-time time constant in seconds.
- * @param period The period in seconds between samples taken by the
- * user.
- */
- static LinearFilter<T> SinglePoleIIR(double timeConstant,
- units::second_t period) {
- double gain = std::exp(-period.to<double>() / timeConstant);
- return LinearFilter(1.0 - gain, -gain);
- }
-
- /**
- * Creates a first-order high-pass filter of the form:<br>
- * y[n] = gain * x[n] + (-gain) * x[n-1] + gain * y[n-1]<br>
- * where gain = e<sup>-dt / T</sup>, T is the time constant in seconds
- *
- * This filter is stable for time constants greater than zero.
- *
- * @param timeConstant The discrete-time time constant in seconds.
- * @param period The period in seconds between samples taken by the
- * user.
- */
- static LinearFilter<T> HighPass(double timeConstant, units::second_t period) {
- double gain = std::exp(-period.to<double>() / timeConstant);
- return LinearFilter({gain, -gain}, {-gain});
- }
-
- /**
- * Creates a K-tap FIR moving average filter of the form:<br>
- * y[n] = 1/k * (x[k] + x[k-1] + … + x[0])
- *
- * This filter is always stable.
- *
- * @param taps The number of samples to average over. Higher = smoother but
- * slower
- */
- static LinearFilter<T> MovingAverage(int taps) {
- assert(taps > 0);
-
- std::vector<double> gains(taps, 1.0 / taps);
- return LinearFilter(gains, {});
- }
-
- /**
- * Reset the filter state.
- */
- void Reset() {
- m_inputs.reset();
- m_outputs.reset();
- }
-
- /**
- * Calculates the next value of the filter.
- *
- * @param input Current input value.
- *
- * @return The filtered value at this step
- */
- T Calculate(T input) {
- T retVal = T(0.0);
-
- // Rotate the inputs
- m_inputs.push_front(input);
-
- // Calculate the new value
- for (size_t i = 0; i < m_inputGains.size(); i++) {
- retVal += m_inputs[i] * m_inputGains[i];
- }
- for (size_t i = 0; i < m_outputGains.size(); i++) {
- retVal -= m_outputs[i] * m_outputGains[i];
- }
-
- // Rotate the outputs
- m_outputs.push_front(retVal);
-
- return retVal;
- }
-
- private:
- wpi::circular_buffer<T> m_inputs;
- wpi::circular_buffer<T> m_outputs;
- std::vector<double> m_inputGains;
- std::vector<double> m_outputGains;
-};
-
-} // namespace frc
diff --git a/wpimath/src/main/native/include/frc/MathUtil.h b/wpimath/src/main/native/include/frc/MathUtil.h
new file mode 100644
index 0000000..54a77af
--- /dev/null
+++ b/wpimath/src/main/native/include/frc/MathUtil.h
@@ -0,0 +1,58 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+#pragma once
+
+#include <wpi/SymbolExports.h>
+#include <wpi/numbers>
+
+#include "units/angle.h"
+
+namespace frc {
+
+/**
+ * Returns 0.0 if the given value is within the specified range around zero.
+ * The remaining range between the deadband and 1.0 is scaled from 0.0 to 1.0.
+ *
+ * @param value Value to clip.
+ * @param deadband Range around zero.
+ */
+WPILIB_DLLEXPORT
+double ApplyDeadband(double value, double deadband);
+
+/**
+ * Returns input wrapped into the range between minimumInput and maximumInput.
+ *
+ * @param input Input value to wrap.
+ * @param minimumInput The minimum value expected from the input.
+ * @param maximumInput The maximum value expected from the input.
+ */
+template <typename T>
+constexpr T InputModulus(T input, T minimumInput, T maximumInput) {
+ T modulus = maximumInput - minimumInput;
+
+ // Wrap input if it's above the maximum input
+ int numMax = (input - minimumInput) / modulus;
+ input -= numMax * modulus;
+
+ // Wrap input if it's below the minimum input
+ int numMin = (input - maximumInput) / modulus;
+ input -= numMin * modulus;
+
+ return input;
+}
+
+/**
+ * Wraps an angle to the range -pi to pi radians (-180 to 180 degrees).
+ *
+ * @param angle Angle to wrap.
+ */
+WPILIB_DLLEXPORT
+constexpr units::radian_t AngleModulus(units::radian_t angle) {
+ return InputModulus<units::radian_t>(angle,
+ units::radian_t{-wpi::numbers::pi},
+ units::radian_t{wpi::numbers::pi});
+}
+
+} // namespace frc
diff --git a/wpimath/src/main/native/include/frc/StateSpaceUtil.h b/wpimath/src/main/native/include/frc/StateSpaceUtil.h
index b461005..730c4b9 100644
--- a/wpimath/src/main/native/include/frc/StateSpaceUtil.h
+++ b/wpimath/src/main/native/include/frc/StateSpaceUtil.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
@@ -12,9 +9,12 @@
#include <random>
#include <type_traits>
+#include <wpi/SymbolExports.h>
+#include <wpi/deprecated.h>
+
#include "Eigen/Core"
+#include "Eigen/Eigenvalues"
#include "Eigen/QR"
-#include "Eigen/src/Eigenvalues/EigenSolver.h"
#include "frc/geometry/Pose2d.h"
namespace frc {
@@ -61,7 +61,7 @@
template <int States, int Inputs>
bool IsStabilizableImpl(const Eigen::Matrix<double, States, States>& A,
const Eigen::Matrix<double, States, Inputs>& B) {
- Eigen::EigenSolver<Eigen::Matrix<double, States, States>> es(A);
+ Eigen::EigenSolver<Eigen::Matrix<double, States, States>> es{A};
for (int i = 0; i < States; ++i) {
if (es.eigenvalues()[i].real() * es.eigenvalues()[i].real() +
@@ -78,7 +78,7 @@
Eigen::ColPivHouseholderQR<
Eigen::Matrix<std::complex<double>, States, States + Inputs>>
- qr(E);
+ qr{E};
if (qr.rank() < States) {
return false;
}
@@ -95,10 +95,13 @@
*
* @param elems An array of elements in the matrix.
* @return A matrix containing the given elements.
+ * @deprecated Use Eigen::Matrix or Eigen::Vector initializer list constructor.
*/
template <int Rows, int Cols, typename... Ts,
typename =
std::enable_if_t<std::conjunction_v<std::is_same<double, Ts>...>>>
+WPI_DEPRECATED(
+ "Use Eigen::Matrix or Eigen::Vector initializer list constructor")
Eigen::Matrix<double, Rows, Cols> MakeMatrix(Ts... elems) {
static_assert(
sizeof...(elems) == Rows * Cols,
@@ -213,12 +216,12 @@
* @return White noise vector.
*/
template <int N>
-Eigen::Matrix<double, N, 1> MakeWhiteNoiseVector(
+Eigen::Vector<double, N> MakeWhiteNoiseVector(
const std::array<double, N>& stdDevs) {
std::random_device rd;
std::mt19937 gen{rd()};
- Eigen::Matrix<double, N, 1> result;
+ Eigen::Vector<double, N> result;
for (int i = 0; i < N; ++i) {
// Passing a standard deviation of 0.0 to std::normal_distribution is
// undefined behavior
@@ -233,12 +236,34 @@
}
/**
+ * Converts a Pose2d into a vector of [x, y, theta].
+ *
+ * @param pose The pose that is being represented.
+ *
+ * @return The vector.
+ */
+WPILIB_DLLEXPORT
+Eigen::Vector<double, 3> PoseTo3dVector(const Pose2d& pose);
+
+/**
+ * Converts a Pose2d into a vector of [x, y, std::cos(theta), std::sin(theta)].
+ *
+ * @param pose The pose that is being represented.
+ *
+ * @return The vector.
+ */
+WPILIB_DLLEXPORT
+Eigen::Vector<double, 4> PoseTo4dVector(const Pose2d& pose);
+
+/**
* Returns true if (A, B) is a stabilizable pair.
*
- * (A,B) is stabilizable if and only if the uncontrollable eigenvalues of A, if
+ * (A, B) is stabilizable if and only if the uncontrollable eigenvalues of A, if
* any, have absolute values less than one, where an eigenvalue is
- * uncontrollable if rank(lambda * I - A, B) < n where n is number of states.
+ * uncontrollable if rank(λI - A, B) < n where n is the number of states.
*
+ * @tparam States The number of states.
+ * @tparam Inputs The number of inputs.
* @param A System matrix.
* @param B Input matrix.
*/
@@ -248,17 +273,36 @@
return detail::IsStabilizableImpl<States, Inputs>(A, B);
}
-// Template specializations are used here to make common state-input pairs
-// compile faster.
-template <>
-bool IsStabilizable<1, 1>(const Eigen::Matrix<double, 1, 1>& A,
- const Eigen::Matrix<double, 1, 1>& B);
+/**
+ * Returns true if (A, C) is a detectable pair.
+ *
+ * (A, C) is detectable if and only if the unobservable eigenvalues of A, if
+ * any, have absolute values less than one, where an eigenvalue is unobservable
+ * if rank(λI - A; C) < n where n is the number of states.
+ *
+ * @tparam States The number of states.
+ * @tparam Outputs The number of outputs.
+ * @param A System matrix.
+ * @param C Output matrix.
+ */
+template <int States, int Outputs>
+bool IsDetectable(const Eigen::Matrix<double, States, States>& A,
+ const Eigen::Matrix<double, Outputs, States>& C) {
+ return detail::IsStabilizableImpl<States, Outputs>(A.transpose(),
+ C.transpose());
+}
// Template specializations are used here to make common state-input pairs
// compile faster.
template <>
-bool IsStabilizable<2, 1>(const Eigen::Matrix<double, 2, 2>& A,
- const Eigen::Matrix<double, 2, 1>& B);
+WPILIB_DLLEXPORT bool IsStabilizable<1, 1>(
+ const Eigen::Matrix<double, 1, 1>& A, const Eigen::Matrix<double, 1, 1>& B);
+
+// Template specializations are used here to make common state-input pairs
+// compile faster.
+template <>
+WPILIB_DLLEXPORT bool IsStabilizable<2, 1>(
+ const Eigen::Matrix<double, 2, 2>& A, const Eigen::Matrix<double, 2, 1>& B);
/**
* Converts a Pose2d into a vector of [x, y, theta].
@@ -267,20 +311,24 @@
*
* @return The vector.
*/
-Eigen::Matrix<double, 3, 1> PoseToVector(const Pose2d& pose);
+WPILIB_DLLEXPORT
+Eigen::Vector<double, 3> PoseToVector(const Pose2d& pose);
/**
* Clamps input vector between system's minimum and maximum allowable input.
*
+ * @tparam Inputs The number of inputs.
* @param u Input vector to clamp.
+ * @param umin The minimum allowable value of each input.
+ * @param umax The maximum allowable value of each input.
* @return Clamped input vector.
*/
template <int Inputs>
-Eigen::Matrix<double, Inputs, 1> ClampInputMaxMagnitude(
- const Eigen::Matrix<double, Inputs, 1>& u,
- const Eigen::Matrix<double, Inputs, 1>& umin,
- const Eigen::Matrix<double, Inputs, 1>& umax) {
- Eigen::Matrix<double, Inputs, 1> result;
+Eigen::Vector<double, Inputs> ClampInputMaxMagnitude(
+ const Eigen::Vector<double, Inputs>& u,
+ const Eigen::Vector<double, Inputs>& umin,
+ const Eigen::Vector<double, Inputs>& umax) {
+ Eigen::Vector<double, Inputs> result;
for (int i = 0; i < Inputs; ++i) {
result(i) = std::clamp(u(i), umin(i), umax(i));
}
@@ -291,14 +339,14 @@
* Normalize all inputs if any excedes the maximum magnitude. Useful for systems
* such as differential drivetrains.
*
+ * @tparam Inputs The number of inputs.
* @param u The input vector.
* @param maxMagnitude The maximum magnitude any input can have.
- * @param <I> The number of inputs.
* @return The normalizedInput
*/
template <int Inputs>
-Eigen::Matrix<double, Inputs, 1> NormalizeInputVector(
- const Eigen::Matrix<double, Inputs, 1>& u, double maxMagnitude) {
+Eigen::Vector<double, Inputs> NormalizeInputVector(
+ const Eigen::Vector<double, Inputs>& u, double maxMagnitude) {
double maxValue = u.template lpNorm<Eigen::Infinity>();
if (maxValue > maxMagnitude) {
diff --git a/wpimath/src/main/native/include/frc/controller/ArmFeedforward.h b/wpimath/src/main/native/include/frc/controller/ArmFeedforward.h
index 14df187..eb7cb76 100644
--- a/wpimath/src/main/native/include/frc/controller/ArmFeedforward.h
+++ b/wpimath/src/main/native/include/frc/controller/ArmFeedforward.h
@@ -1,13 +1,11 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
#include <wpi/MathExtras.h>
+#include <wpi/SymbolExports.h>
#include "units/angle.h"
#include "units/angular_velocity.h"
@@ -19,7 +17,7 @@
* A helper class that computes feedforward outputs for a simple arm (modeled as
* a motor acting against the force of gravity on a beam suspended at an angle).
*/
-class ArmFeedforward {
+class WPILIB_DLLEXPORT ArmFeedforward {
public:
using Angle = units::radians;
using Velocity = units::radians_per_second;
diff --git a/wpimath/src/main/native/include/frc/controller/ControlAffinePlantInversionFeedforward.h b/wpimath/src/main/native/include/frc/controller/ControlAffinePlantInversionFeedforward.h
index 134ed97..656f767 100644
--- a/wpimath/src/main/native/include/frc/controller/ControlAffinePlantInversionFeedforward.h
+++ b/wpimath/src/main/native/include/frc/controller/ControlAffinePlantInversionFeedforward.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
@@ -35,6 +32,9 @@
*
* For more on the underlying math, read
* https://file.tavsys.net/control/controls-engineering-in-frc.pdf.
+ *
+ * @tparam States The number of states.
+ * @tparam Inputs The number of inputs.
*/
template <int States, int Inputs>
class ControlAffinePlantInversionFeedforward {
@@ -50,18 +50,17 @@
* @param dt The timestep between calls of calculate().
*/
ControlAffinePlantInversionFeedforward(
- std::function<Eigen::Matrix<double, States, 1>(
- const Eigen::Matrix<double, States, 1>&,
- const Eigen::Matrix<double, Inputs, 1>&)>
+ std::function<
+ Eigen::Vector<double, States>(const Eigen::Vector<double, States>&,
+ const Eigen::Vector<double, Inputs>&)>
f,
units::second_t dt)
: m_dt(dt), m_f(f) {
m_B = NumericalJacobianU<States, States, Inputs>(
- f, Eigen::Matrix<double, States, 1>::Zero(),
- Eigen::Matrix<double, Inputs, 1>::Zero());
+ f, Eigen::Vector<double, States>::Zero(),
+ Eigen::Vector<double, Inputs>::Zero());
- m_r.setZero();
- Reset(m_r);
+ Reset();
}
/**
@@ -74,17 +73,16 @@
* @param dt The timestep between calls of calculate().
*/
ControlAffinePlantInversionFeedforward(
- std::function<Eigen::Matrix<double, States, 1>(
- const Eigen::Matrix<double, States, 1>&)>
+ std::function<
+ Eigen::Vector<double, States>(const Eigen::Vector<double, States>&)>
f,
const Eigen::Matrix<double, States, Inputs>& B, units::second_t dt)
: m_B(B), m_dt(dt) {
- m_f = [=](const Eigen::Matrix<double, States, 1>& x,
- const Eigen::Matrix<double, Inputs, 1>& u)
- -> Eigen::Matrix<double, States, 1> { return f(x); };
+ m_f = [=](const Eigen::Vector<double, States>& x,
+ const Eigen::Vector<double, Inputs>& u)
+ -> Eigen::Vector<double, States> { return f(x); };
- m_r.setZero();
- Reset(m_r);
+ Reset();
}
ControlAffinePlantInversionFeedforward(
@@ -97,12 +95,12 @@
*
* @return The calculated feedforward.
*/
- const Eigen::Matrix<double, Inputs, 1>& Uff() const { return m_uff; }
+ const Eigen::Vector<double, Inputs>& Uff() const { return m_uff; }
/**
* Returns an element of the previously calculated feedforward.
*
- * @param row Row of uff.
+ * @param i Row of uff.
*
* @return The row of the calculated feedforward.
*/
@@ -113,7 +111,7 @@
*
* @return The current reference vector.
*/
- const Eigen::Matrix<double, States, 1>& R() const { return m_r; }
+ const Eigen::Vector<double, States>& R() const { return m_r; }
/**
* Returns an element of the reference vector r.
@@ -129,7 +127,7 @@
*
* @param initialState The initial state vector.
*/
- void Reset(const Eigen::Matrix<double, States, 1>& initialState) {
+ void Reset(const Eigen::Vector<double, States>& initialState) {
m_r = initialState;
m_uff.setZero();
}
@@ -147,16 +145,16 @@
* future reference. This uses the internally stored "current"
* reference.
*
- * If this method is used the initial state of the system is the one
- * set using Reset(const Eigen::Matrix<double, States, 1>&).
- * If the initial state is not set it defaults to a zero vector.
+ * If this method is used the initial state of the system is the one set using
+ * Reset(const Eigen::Vector<double, States>&). If the initial state is not
+ * set it defaults to a zero vector.
*
* @param nextR The reference state of the future timestep (k + dt).
*
* @return The calculated feedforward.
*/
- Eigen::Matrix<double, Inputs, 1> Calculate(
- const Eigen::Matrix<double, States, 1>& nextR) {
+ Eigen::Vector<double, Inputs> Calculate(
+ const Eigen::Vector<double, States>& nextR) {
return Calculate(m_r, nextR);
}
@@ -168,13 +166,13 @@
*
* @return The calculated feedforward.
*/
- Eigen::Matrix<double, Inputs, 1> Calculate(
- const Eigen::Matrix<double, States, 1>& r,
- const Eigen::Matrix<double, States, 1>& nextR) {
- Eigen::Matrix<double, States, 1> rDot = (nextR - r) / m_dt.to<double>();
+ Eigen::Vector<double, Inputs> Calculate(
+ const Eigen::Vector<double, States>& r,
+ const Eigen::Vector<double, States>& nextR) {
+ Eigen::Vector<double, States> rDot = (nextR - r) / m_dt.value();
m_uff = m_B.householderQr().solve(
- rDot - m_f(r, Eigen::Matrix<double, Inputs, 1>::Zero()));
+ rDot - m_f(r, Eigen::Vector<double, Inputs>::Zero()));
m_r = nextR;
return m_uff;
@@ -188,16 +186,16 @@
/**
* The model dynamics.
*/
- std::function<Eigen::Matrix<double, States, 1>(
- const Eigen::Matrix<double, States, 1>&,
- const Eigen::Matrix<double, Inputs, 1>&)>
+ std::function<Eigen::Vector<double, States>(
+ const Eigen::Vector<double, States>&,
+ const Eigen::Vector<double, Inputs>&)>
m_f;
// Current reference
- Eigen::Matrix<double, States, 1> m_r;
+ Eigen::Vector<double, States> m_r;
// Computed feedforward
- Eigen::Matrix<double, Inputs, 1> m_uff;
+ Eigen::Vector<double, Inputs> m_uff;
};
} // namespace frc
diff --git a/wpimath/src/main/native/include/frc/controller/ElevatorFeedforward.h b/wpimath/src/main/native/include/frc/controller/ElevatorFeedforward.h
index b82d960..269a7e6 100644
--- a/wpimath/src/main/native/include/frc/controller/ElevatorFeedforward.h
+++ b/wpimath/src/main/native/include/frc/controller/ElevatorFeedforward.h
@@ -1,14 +1,12 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
#include <wpi/MathExtras.h>
+#include "units/time.h"
#include "units/voltage.h"
namespace frc {
diff --git a/wpimath/src/main/native/include/frc/controller/HolonomicDriveController.h b/wpimath/src/main/native/include/frc/controller/HolonomicDriveController.h
new file mode 100644
index 0000000..398a872
--- /dev/null
+++ b/wpimath/src/main/native/include/frc/controller/HolonomicDriveController.h
@@ -0,0 +1,112 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+#pragma once
+
+#include <wpi/SymbolExports.h>
+
+#include "frc/controller/PIDController.h"
+#include "frc/controller/ProfiledPIDController.h"
+#include "frc/geometry/Pose2d.h"
+#include "frc/geometry/Rotation2d.h"
+#include "frc/kinematics/ChassisSpeeds.h"
+#include "frc/trajectory/Trajectory.h"
+#include "units/angle.h"
+#include "units/velocity.h"
+
+namespace frc {
+/**
+ * This holonomic drive controller can be used to follow trajectories using a
+ * holonomic drivetrain (i.e. swerve or mecanum). Holonomic trajectory following
+ * is a much simpler problem to solve compared to skid-steer style drivetrains
+ * because it is possible to individually control forward, sideways, and angular
+ * velocity.
+ *
+ * The holonomic drive controller takes in one PID controller for each
+ * direction, forward and sideways, and one profiled PID controller for the
+ * angular direction. Because the heading dynamics are decoupled from
+ * translations, users can specify a custom heading that the drivetrain should
+ * point toward. This heading reference is profiled for smoothness.
+ */
+class WPILIB_DLLEXPORT HolonomicDriveController {
+ public:
+ /**
+ * Constructs a holonomic drive controller.
+ *
+ * @param xController A PID Controller to respond to error in the
+ * field-relative x direction.
+ * @param yController A PID Controller to respond to error in the
+ * field-relative y direction.
+ * @param thetaController A profiled PID controller to respond to error in
+ * angle.
+ */
+ HolonomicDriveController(
+ frc2::PIDController xController, frc2::PIDController yController,
+ ProfiledPIDController<units::radian> thetaController);
+
+ /**
+ * Returns true if the pose error is within tolerance of the reference.
+ */
+ bool AtReference() const;
+
+ /**
+ * Sets the pose error which is considered tolerable for use with
+ * AtReference().
+ *
+ * @param tolerance Pose error which is tolerable.
+ */
+ void SetTolerance(const Pose2d& tolerance);
+
+ /**
+ * Returns the next output of the holonomic drive controller.
+ *
+ * The reference pose, linear velocity, and angular velocity should come from
+ * a drivetrain trajectory.
+ *
+ * @param currentPose The current pose.
+ * @param poseRef The desired pose.
+ * @param linearVelocityRef The desired linear velocity.
+ * @param angleRef The desired ending angle.
+ */
+ ChassisSpeeds Calculate(const Pose2d& currentPose, const Pose2d& poseRef,
+ units::meters_per_second_t linearVelocityRef,
+ const Rotation2d& angleRef);
+
+ /**
+ * Returns the next output of the holonomic drive controller.
+ *
+ * The reference pose, linear velocity, and angular velocity should come from
+ * a drivetrain trajectory.
+ *
+ * @param currentPose The current pose.
+ * @param desiredState The desired pose, linear velocity, and angular velocity
+ * from a trajectory.
+ * @param angleRef The desired ending angle.
+ */
+ ChassisSpeeds Calculate(const Pose2d& currentPose,
+ const Trajectory::State& desiredState,
+ const Rotation2d& angleRef);
+
+ /**
+ * Enables and disables the controller for troubleshooting purposes. When
+ * Calculate() is called on a disabled controller, only feedforward values
+ * are returned.
+ *
+ * @param enabled If the controller is enabled or not.
+ */
+ void SetEnabled(bool enabled);
+
+ private:
+ Pose2d m_poseError;
+ Rotation2d m_rotationError;
+ Pose2d m_poseTolerance;
+ bool m_enabled = true;
+
+ frc2::PIDController m_xController;
+ frc2::PIDController m_yController;
+ ProfiledPIDController<units::radian> m_thetaController;
+
+ bool m_firstRun = true;
+};
+} // namespace frc
diff --git a/wpimath/src/main/native/include/frc/controller/LinearPlantInversionFeedforward.h b/wpimath/src/main/native/include/frc/controller/LinearPlantInversionFeedforward.h
index ea86d90..519368d 100644
--- a/wpimath/src/main/native/include/frc/controller/LinearPlantInversionFeedforward.h
+++ b/wpimath/src/main/native/include/frc/controller/LinearPlantInversionFeedforward.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
@@ -27,6 +24,9 @@
*
* For more on the underlying math, read
* https://file.tavsys.net/control/controls-engineering-in-frc.pdf.
+ *
+ * @tparam States The number of states.
+ * @tparam Inputs The number of inputs.
*/
template <int States, int Inputs>
class LinearPlantInversionFeedforward {
@@ -34,8 +34,9 @@
/**
* Constructs a feedforward with the given plant.
*
- * @param plant The plant being controlled.
- * @param dtSeconds Discretization timestep.
+ * @tparam Outputs The number of outputs.
+ * @param plant The plant being controlled.
+ * @param dt Discretization timestep.
*/
template <int Outputs>
LinearPlantInversionFeedforward(
@@ -45,18 +46,16 @@
/**
* Constructs a feedforward with the given coefficients.
*
- * @param A Continuous system matrix of the plant being controlled.
- * @param B Continuous input matrix of the plant being controlled.
- * @param dtSeconds Discretization timestep.
+ * @param A Continuous system matrix of the plant being controlled.
+ * @param B Continuous input matrix of the plant being controlled.
+ * @param dt Discretization timestep.
*/
LinearPlantInversionFeedforward(
const Eigen::Matrix<double, States, States>& A,
const Eigen::Matrix<double, States, Inputs>& B, units::second_t dt)
: m_dt(dt) {
DiscretizeAB<States, Inputs>(A, B, dt, &m_A, &m_B);
-
- m_r.setZero();
- Reset(m_r);
+ Reset();
}
/**
@@ -64,12 +63,12 @@
*
* @return The calculated feedforward.
*/
- const Eigen::Matrix<double, Inputs, 1>& Uff() const { return m_uff; }
+ const Eigen::Vector<double, Inputs>& Uff() const { return m_uff; }
/**
* Returns an element of the previously calculated feedforward.
*
- * @param row Row of uff.
+ * @param i Row of uff.
*
* @return The row of the calculated feedforward.
*/
@@ -80,7 +79,7 @@
*
* @return The current reference vector.
*/
- const Eigen::Matrix<double, States, 1>& R() const { return m_r; }
+ const Eigen::Vector<double, States>& R() const { return m_r; }
/**
* Returns an element of the reference vector r.
@@ -96,7 +95,7 @@
*
* @param initialState The initial state vector.
*/
- void Reset(const Eigen::Matrix<double, States, 1>& initialState) {
+ void Reset(const Eigen::Vector<double, States>& initialState) {
m_r = initialState;
m_uff.setZero();
}
@@ -114,17 +113,17 @@
* future reference. This uses the internally stored "current"
* reference.
*
- * If this method is used the initial state of the system is the one
- * set using Reset(const Eigen::Matrix<double, States, 1>&).
- * If the initial state is not set it defaults to a zero vector.
+ * If this method is used the initial state of the system is the one set using
+ * Reset(const Eigen::Vector<double, States>&). If the initial state is not
+ * set it defaults to a zero vector.
*
* @param nextR The reference state of the future timestep (k + dt).
*
* @return The calculated feedforward.
*/
- Eigen::Matrix<double, Inputs, 1> Calculate(
- const Eigen::Matrix<double, States, 1>& nextR) {
- return Calculate(m_r, nextR);
+ Eigen::Vector<double, Inputs> Calculate(
+ const Eigen::Vector<double, States>& nextR) {
+ return Calculate(m_r, nextR); // NOLINT
}
/**
@@ -135,9 +134,9 @@
*
* @return The calculated feedforward.
*/
- Eigen::Matrix<double, Inputs, 1> Calculate(
- const Eigen::Matrix<double, States, 1>& r,
- const Eigen::Matrix<double, States, 1>& nextR) {
+ Eigen::Vector<double, Inputs> Calculate(
+ const Eigen::Vector<double, States>& r,
+ const Eigen::Vector<double, States>& nextR) {
m_uff = m_B.householderQr().solve(nextR - (m_A * r));
m_r = nextR;
return m_uff;
@@ -150,10 +149,10 @@
units::second_t m_dt;
// Current reference
- Eigen::Matrix<double, States, 1> m_r;
+ Eigen::Vector<double, States> m_r;
// Computed feedforward
- Eigen::Matrix<double, Inputs, 1> m_uff;
+ Eigen::Vector<double, Inputs> m_uff;
};
} // namespace frc
diff --git a/wpimath/src/main/native/include/frc/controller/LinearQuadraticRegulator.h b/wpimath/src/main/native/include/frc/controller/LinearQuadraticRegulator.h
index f448957..c934e0a 100644
--- a/wpimath/src/main/native/include/frc/controller/LinearQuadraticRegulator.h
+++ b/wpimath/src/main/native/include/frc/controller/LinearQuadraticRegulator.h
@@ -1,21 +1,26 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
-#include <array>
+#include <frc/fmt/Eigen.h>
+#include <string>
+
+#include <wpi/SymbolExports.h>
+#include <wpi/array.h>
+
+#include "Eigen/Cholesky"
#include "Eigen/Core"
-#include "Eigen/src/Cholesky/LLT.h"
+#include "Eigen/Eigenvalues"
#include "drake/math/discrete_algebraic_riccati_equation.h"
#include "frc/StateSpaceUtil.h"
#include "frc/system/Discretization.h"
#include "frc/system/LinearSystem.h"
#include "units/time.h"
+#include "unsupported/Eigen/MatrixFunctions"
+#include "wpimath/MathShared.h"
namespace frc {
namespace detail {
@@ -27,6 +32,9 @@
*
* For more on the underlying math, read
* https://file.tavsys.net/control/controls-engineering-in-frc.pdf.
+ *
+ * @tparam States Number of states.
+ * @tparam Inputs Number of inputs.
*/
template <int States, int Inputs>
class LinearQuadraticRegulatorImpl {
@@ -42,8 +50,8 @@
template <int Outputs>
LinearQuadraticRegulatorImpl(
const LinearSystem<States, Inputs, Outputs>& plant,
- const std::array<double, States>& Qelems,
- const std::array<double, Inputs>& Relems, units::second_t dt)
+ const wpi::array<double, States>& Qelems,
+ const wpi::array<double, Inputs>& Relems, units::second_t dt)
: LinearQuadraticRegulatorImpl(plant.A(), plant.B(), Qelems, Relems, dt) {
}
@@ -58,8 +66,8 @@
*/
LinearQuadraticRegulatorImpl(const Eigen::Matrix<double, States, States>& A,
const Eigen::Matrix<double, States, Inputs>& B,
- const std::array<double, States>& Qelems,
- const std::array<double, Inputs>& Relems,
+ const wpi::array<double, States>& Qelems,
+ const wpi::array<double, Inputs>& Relems,
units::second_t dt)
: LinearQuadraticRegulatorImpl(A, B, MakeCostMatrix(Qelems),
MakeCostMatrix(Relems), dt) {}
@@ -67,11 +75,11 @@
/**
* Constructs a controller with the given coefficients and plant.
*
- * @param A Continuous system matrix of the plant being controlled.
- * @param B Continuous input matrix of the plant being controlled.
- * @param Q The state cost matrix.
- * @param R The input cost matrix.
- * @param dt Discretization timestep.
+ * @param A Continuous system matrix of the plant being controlled.
+ * @param B Continuous input matrix of the plant being controlled.
+ * @param Q The state cost matrix.
+ * @param R The input cost matrix.
+ * @param dt Discretization timestep.
*/
LinearQuadraticRegulatorImpl(const Eigen::Matrix<double, States, States>& A,
const Eigen::Matrix<double, States, Inputs>& B,
@@ -82,11 +90,54 @@
Eigen::Matrix<double, States, Inputs> discB;
DiscretizeAB<States, Inputs>(A, B, dt, &discA, &discB);
+ if (!IsStabilizable<States, Inputs>(discA, discB)) {
+ std::string msg = fmt::format(
+ "The system passed to the LQR is uncontrollable!\n\nA =\n{}\nB "
+ "=\n{}\n",
+ discA, discB);
+
+ wpi::math::MathSharedStore::ReportError(msg);
+ throw std::invalid_argument(msg);
+ }
+
Eigen::Matrix<double, States, States> S =
drake::math::DiscreteAlgebraicRiccatiEquation(discA, discB, Q, R);
- Eigen::Matrix<double, Inputs, Inputs> tmp =
- discB.transpose() * S * discB + R;
- m_K = tmp.llt().solve(discB.transpose() * S * discA);
+
+ // K = (BᵀSB + R)⁻¹BᵀSA
+ m_K = (discB.transpose() * S * discB + R)
+ .llt()
+ .solve(discB.transpose() * S * discA);
+
+ Reset();
+ }
+
+ /**
+ * Constructs a controller with the given coefficients and plant.
+ *
+ * @param A Continuous system matrix of the plant being controlled.
+ * @param B Continuous input matrix of the plant being controlled.
+ * @param Q The state cost matrix.
+ * @param R The input cost matrix.
+ * @param N The state-input cross-term cost matrix.
+ * @param dt Discretization timestep.
+ */
+ LinearQuadraticRegulatorImpl(const Eigen::Matrix<double, States, States>& A,
+ const Eigen::Matrix<double, States, Inputs>& B,
+ const Eigen::Matrix<double, States, States>& Q,
+ const Eigen::Matrix<double, Inputs, Inputs>& R,
+ const Eigen::Matrix<double, States, Inputs>& N,
+ units::second_t dt) {
+ Eigen::Matrix<double, States, States> discA;
+ Eigen::Matrix<double, States, Inputs> discB;
+ DiscretizeAB<States, Inputs>(A, B, dt, &discA, &discB);
+
+ Eigen::Matrix<double, States, States> S =
+ drake::math::DiscreteAlgebraicRiccatiEquation(discA, discB, Q, R, N);
+
+ // K = (BᵀSB + R)⁻¹(BᵀSA + Nᵀ), using the discretized A and B throughout
+ m_K = (discB.transpose() * S * discB + R)
+ .llt()
+ .solve(discB.transpose() * S * discA + N.transpose());
Reset();
}
@@ -113,7 +164,7 @@
*
* @return The reference vector.
*/
- const Eigen::Matrix<double, States, 1>& R() const { return m_r; }
+ const Eigen::Vector<double, States>& R() const { return m_r; }
/**
* Returns an element of the reference vector r.
@@ -129,7 +180,7 @@
*
* @return The control input.
*/
- const Eigen::Matrix<double, Inputs, 1>& U() const { return m_u; }
+ const Eigen::Vector<double, Inputs>& U() const { return m_u; }
/**
* Returns an element of the control input vector u.
@@ -153,8 +204,8 @@
*
* @param x The current state x.
*/
- Eigen::Matrix<double, Inputs, 1> Calculate(
- const Eigen::Matrix<double, States, 1>& x) {
+ Eigen::Vector<double, Inputs> Calculate(
+ const Eigen::Vector<double, States>& x) {
m_u = m_K * (m_r - x);
return m_u;
}
@@ -165,19 +216,45 @@
* @param x The current state x.
* @param nextR The next reference vector r.
*/
- Eigen::Matrix<double, Inputs, 1> Calculate(
- const Eigen::Matrix<double, States, 1>& x,
- const Eigen::Matrix<double, States, 1>& nextR) {
+ Eigen::Vector<double, Inputs> Calculate(
+ const Eigen::Vector<double, States>& x,
+ const Eigen::Vector<double, States>& nextR) {
m_r = nextR;
return Calculate(x);
}
+ /**
+ * Adjusts LQR controller gain to compensate for a pure time delay in the
+ * input.
+ *
+ * Linear-Quadratic regulator controller gains tend to be aggressive. If
+ * sensor measurements are time-delayed too long, the LQR may be unstable.
+ * However, if we know the amount of delay, we can compute the control based
+ * on where the system will be after the time delay.
+ *
+ * See https://file.tavsys.net/control/controls-engineering-in-frc.pdf
+ * appendix C.4 for a derivation.
+ *
+ * @param plant The plant being controlled.
+ * @param dt Discretization timestep.
+ * @param inputDelay Input time delay.
+ */
+ template <int Outputs>
+ void LatencyCompensate(const LinearSystem<States, Inputs, Outputs>& plant,
+ units::second_t dt, units::second_t inputDelay) {
+ Eigen::Matrix<double, States, States> discA;
+ Eigen::Matrix<double, States, Inputs> discB;
+ DiscretizeAB<States, Inputs>(plant.A(), plant.B(), dt, &discA, &discB);
+
+ m_K = m_K * (discA - discB * m_K).pow(inputDelay / dt);
+ }
+
private:
// Current reference
- Eigen::Matrix<double, States, 1> m_r;
+ Eigen::Vector<double, States> m_r;
// Computed controller output
- Eigen::Matrix<double, Inputs, 1> m_u;
+ Eigen::Vector<double, Inputs> m_u;
// Controller gain
Eigen::Matrix<double, Inputs, States> m_K;
@@ -192,15 +269,16 @@
/**
* Constructs a controller with the given coefficients and plant.
*
- * @param system The plant being controlled.
+ * @tparam Outputs The number of outputs.
+ * @param plant The plant being controlled.
* @param Qelems The maximum desired error tolerance for each state.
* @param Relems The maximum desired control effort for each input.
* @param dt Discretization timestep.
*/
template <int Outputs>
LinearQuadraticRegulator(const LinearSystem<States, Inputs, Outputs>& plant,
- const std::array<double, States>& Qelems,
- const std::array<double, Inputs>& Relems,
+ const wpi::array<double, States>& Qelems,
+ const wpi::array<double, Inputs>& Relems,
units::second_t dt)
: LinearQuadraticRegulator(plant.A(), plant.B(), Qelems, Relems, dt) {}
@@ -215,8 +293,8 @@
*/
LinearQuadraticRegulator(const Eigen::Matrix<double, States, States>& A,
const Eigen::Matrix<double, States, Inputs>& B,
- const std::array<double, States>& Qelems,
- const std::array<double, Inputs>& Relems,
+ const wpi::array<double, States>& Qelems,
+ const wpi::array<double, Inputs>& Relems,
units::second_t dt)
: LinearQuadraticRegulator(A, B, MakeCostMatrix(Qelems),
MakeCostMatrix(Relems), dt) {}
@@ -224,11 +302,11 @@
/**
* Constructs a controller with the given coefficients and plant.
*
- * @param A Continuous system matrix of the plant being controlled.
- * @param B Continuous input matrix of the plant being controlled.
- * @param Q The state cost matrix.
- * @param R The input cost matrix.
- * @param dt Discretization timestep.
+ * @param A Continuous system matrix of the plant being controlled.
+ * @param B Continuous input matrix of the plant being controlled.
+ * @param Q The state cost matrix.
+ * @param R The input cost matrix.
+ * @param dt Discretization timestep.
*/
LinearQuadraticRegulator(const Eigen::Matrix<double, States, States>& A,
const Eigen::Matrix<double, States, Inputs>& B,
@@ -237,6 +315,25 @@
units::second_t dt)
: detail::LinearQuadraticRegulatorImpl<States, Inputs>{A, B, Q, R, dt} {}
+ /**
+ * Constructs a controller with the given coefficients and plant.
+ *
+ * @param A Continuous system matrix of the plant being controlled.
+ * @param B Continuous input matrix of the plant being controlled.
+ * @param Q The state cost matrix.
+ * @param R The input cost matrix.
+ * @param N The state-input cross-term cost matrix.
+ * @param dt Discretization timestep.
+ */
+ LinearQuadraticRegulator(const Eigen::Matrix<double, States, States>& A,
+ const Eigen::Matrix<double, States, Inputs>& B,
+ const Eigen::Matrix<double, States, States>& Q,
+ const Eigen::Matrix<double, Inputs, Inputs>& R,
+ const Eigen::Matrix<double, States, Inputs>& N,
+ units::second_t dt)
+ : detail::LinearQuadraticRegulatorImpl<States, Inputs>{A, B, Q,
+ R, N, dt} {}
+
LinearQuadraticRegulator(LinearQuadraticRegulator&&) = default;
LinearQuadraticRegulator& operator=(LinearQuadraticRegulator&&) = default;
};
@@ -244,20 +341,20 @@
// Template specializations are used here to make common state-input pairs
// compile faster.
template <>
-class LinearQuadraticRegulator<1, 1>
+class WPILIB_DLLEXPORT LinearQuadraticRegulator<1, 1>
: public detail::LinearQuadraticRegulatorImpl<1, 1> {
public:
template <int Outputs>
LinearQuadraticRegulator(const LinearSystem<1, 1, Outputs>& plant,
- const std::array<double, 1>& Qelems,
- const std::array<double, 1>& Relems,
+ const wpi::array<double, 1>& Qelems,
+ const wpi::array<double, 1>& Relems,
units::second_t dt)
: LinearQuadraticRegulator(plant.A(), plant.B(), Qelems, Relems, dt) {}
LinearQuadraticRegulator(const Eigen::Matrix<double, 1, 1>& A,
const Eigen::Matrix<double, 1, 1>& B,
- const std::array<double, 1>& Qelems,
- const std::array<double, 1>& Relems,
+ const wpi::array<double, 1>& Qelems,
+ const wpi::array<double, 1>& Relems,
units::second_t dt);
LinearQuadraticRegulator(const Eigen::Matrix<double, 1, 1>& A,
@@ -266,6 +363,13 @@
const Eigen::Matrix<double, 1, 1>& R,
units::second_t dt);
+ LinearQuadraticRegulator(const Eigen::Matrix<double, 1, 1>& A,
+ const Eigen::Matrix<double, 1, 1>& B,
+ const Eigen::Matrix<double, 1, 1>& Q,
+ const Eigen::Matrix<double, 1, 1>& R,
+ const Eigen::Matrix<double, 1, 1>& N,
+ units::second_t dt);
+
LinearQuadraticRegulator(LinearQuadraticRegulator&&) = default;
LinearQuadraticRegulator& operator=(LinearQuadraticRegulator&&) = default;
};
@@ -273,20 +377,20 @@
// Template specializations are used here to make common state-input pairs
// compile faster.
template <>
-class LinearQuadraticRegulator<2, 1>
+class WPILIB_DLLEXPORT LinearQuadraticRegulator<2, 1>
: public detail::LinearQuadraticRegulatorImpl<2, 1> {
public:
template <int Outputs>
LinearQuadraticRegulator(const LinearSystem<2, 1, Outputs>& plant,
- const std::array<double, 2>& Qelems,
- const std::array<double, 1>& Relems,
+ const wpi::array<double, 2>& Qelems,
+ const wpi::array<double, 1>& Relems,
units::second_t dt)
: LinearQuadraticRegulator(plant.A(), plant.B(), Qelems, Relems, dt) {}
LinearQuadraticRegulator(const Eigen::Matrix<double, 2, 2>& A,
const Eigen::Matrix<double, 2, 1>& B,
- const std::array<double, 2>& Qelems,
- const std::array<double, 1>& Relems,
+ const wpi::array<double, 2>& Qelems,
+ const wpi::array<double, 1>& Relems,
units::second_t dt);
LinearQuadraticRegulator(const Eigen::Matrix<double, 2, 2>& A,
@@ -295,6 +399,49 @@
const Eigen::Matrix<double, 1, 1>& R,
units::second_t dt);
+ LinearQuadraticRegulator(const Eigen::Matrix<double, 2, 2>& A,
+ const Eigen::Matrix<double, 2, 1>& B,
+ const Eigen::Matrix<double, 2, 2>& Q,
+ const Eigen::Matrix<double, 1, 1>& R,
+ const Eigen::Matrix<double, 2, 1>& N,
+ units::second_t dt);
+
+ LinearQuadraticRegulator(LinearQuadraticRegulator&&) = default;
+ LinearQuadraticRegulator& operator=(LinearQuadraticRegulator&&) = default;
+};
+
+// Explicit specialization for 2 states and 2 inputs; like the other
+// specializations here, it exists so this common pair compiles faster.
+template <>
+class WPILIB_DLLEXPORT LinearQuadraticRegulator<2, 2>
+    : public detail::LinearQuadraticRegulatorImpl<2, 2> {
+ public:
+  template <int Outputs>
+  LinearQuadraticRegulator(const LinearSystem<2, 2, Outputs>& plant,
+                           const wpi::array<double, 2>& Qelems,
+                           const wpi::array<double, 2>& Relems,
+                           units::second_t dt)
+      : LinearQuadraticRegulator(plant.A(), plant.B(), Qelems, Relems, dt) {}
+  // Overload taking Qelems/Relems arrays rather than full Q and R matrices.
+  LinearQuadraticRegulator(const Eigen::Matrix<double, 2, 2>& A,
+                           const Eigen::Matrix<double, 2, 2>& B,
+                           const wpi::array<double, 2>& Qelems,
+                           const wpi::array<double, 2>& Relems,
+                           units::second_t dt);
+  // Overload taking the Q and R cost matrices directly.
+  LinearQuadraticRegulator(const Eigen::Matrix<double, 2, 2>& A,
+                           const Eigen::Matrix<double, 2, 2>& B,
+                           const Eigen::Matrix<double, 2, 2>& Q,
+                           const Eigen::Matrix<double, 2, 2>& R,
+                           units::second_t dt);
+  // As above, plus N (presumably the state-input cross-term cost; verify).
+  LinearQuadraticRegulator(const Eigen::Matrix<double, 2, 2>& A,
+                           const Eigen::Matrix<double, 2, 2>& B,
+                           const Eigen::Matrix<double, 2, 2>& Q,
+                           const Eigen::Matrix<double, 2, 2>& R,
+                           const Eigen::Matrix<double, 2, 2>& N,
+                           units::second_t dt);
+
   LinearQuadraticRegulator(LinearQuadraticRegulator&&) = default;
   LinearQuadraticRegulator& operator=(LinearQuadraticRegulator&&) = default;
 };
diff --git a/wpimath/src/main/native/include/frc/controller/PIDController.h b/wpimath/src/main/native/include/frc/controller/PIDController.h
new file mode 100644
index 0000000..98625f8
--- /dev/null
+++ b/wpimath/src/main/native/include/frc/controller/PIDController.h
@@ -0,0 +1,249 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+#pragma once
+
+#include <functional>
+#include <limits>
+
+#include <wpi/SymbolExports.h>
+#include <wpi/sendable/Sendable.h>
+#include <wpi/sendable/SendableHelper.h>
+
+#include "units/time.h"
+
+namespace frc2 {
+
+/**
+ * Implements a PID control loop.
+ */
+class WPILIB_DLLEXPORT PIDController
+    : public wpi::Sendable,
+      public wpi::SendableHelper<PIDController> {
+ public:
+  /**
+   * Allocates a PIDController with the given constants for Kp, Ki, and Kd.
+   *
+   * @param Kp The proportional coefficient.
+   * @param Ki The integral coefficient.
+   * @param Kd The derivative coefficient.
+   * @param period The period between controller updates in seconds. The
+   *               default is 20 milliseconds. Must be non-zero and positive.
+   */
+  PIDController(double Kp, double Ki, double Kd,
+                units::second_t period = 20_ms);
+
+  ~PIDController() override = default;
+
+  PIDController(const PIDController&) = default;
+  PIDController& operator=(const PIDController&) = default;
+  PIDController(PIDController&&) = default;
+  PIDController& operator=(PIDController&&) = default;
+
+  /**
+   * Sets the PID Controller gain parameters.
+   *
+   * Sets the proportional, integral, and derivative coefficients.
+   *
+   * @param Kp Proportional coefficient
+   * @param Ki Integral coefficient
+   * @param Kd Derivative coefficient
+   */
+  void SetPID(double Kp, double Ki, double Kd);
+
+  /**
+   * Sets the proportional coefficient of the PID controller gain.
+   *
+   * @param Kp proportional coefficient
+   */
+  void SetP(double Kp);
+
+  /**
+   * Sets the integral coefficient of the PID controller gain.
+   *
+   * @param Ki integral coefficient
+   */
+  void SetI(double Ki);
+
+  /**
+   * Sets the derivative coefficient of the PID controller gain.
+   *
+   * @param Kd derivative coefficient
+   */
+  void SetD(double Kd);
+
+  /**
+   * Gets the proportional coefficient.
+   *
+   * @return proportional coefficient
+   */
+  double GetP() const;
+
+  /**
+   * Gets the integral coefficient.
+   *
+   * @return integral coefficient
+   */
+  double GetI() const;
+
+  /**
+   * Gets the derivative coefficient.
+   *
+   * @return derivative coefficient
+   */
+  double GetD() const;
+
+  /**
+   * Gets the period of this controller.
+   *
+   * @return The period of the controller.
+   */
+  units::second_t GetPeriod() const;
+
+  /**
+   * Sets the setpoint for the PIDController.
+   *
+   * @param setpoint The desired setpoint.
+   */
+  void SetSetpoint(double setpoint);
+
+  /**
+   * Returns the current setpoint of the PIDController.
+   *
+   * @return The current setpoint.
+   */
+  double GetSetpoint() const;
+
+  /**
+   * Returns true if the error is within the tolerance of the setpoint.
+   * NOTE(review): intended to be false until an input has been computed, but
+   * no member here tracks that and the errors start at 0; verify in the .cpp.
+   */
+  bool AtSetpoint() const;
+
+  /**
+   * Enables continuous input.
+   *
+   * Rather than using the max and min input range as constraints, it considers
+   * them to be the same point and automatically calculates the shortest route
+   * to the setpoint.
+   *
+   * @param minimumInput The minimum value expected from the input.
+   * @param maximumInput The maximum value expected from the input.
+   */
+  void EnableContinuousInput(double minimumInput, double maximumInput);
+
+  /**
+   * Disables continuous input.
+   */
+  void DisableContinuousInput();
+
+  /**
+   * Returns true if continuous input is enabled.
+   */
+  bool IsContinuousInputEnabled() const;
+
+  /**
+   * Sets the minimum and maximum values for the integrator.
+   *
+   * When the cap is reached, the integrator value is added to the controller
+   * output rather than the integrator value times the integral gain.
+   *
+   * @param minimumIntegral The minimum value of the integrator.
+   * @param maximumIntegral The maximum value of the integrator.
+   */
+  void SetIntegratorRange(double minimumIntegral, double maximumIntegral);
+
+  /**
+   * Sets the error which is considered tolerable for use with AtSetpoint().
+   *
+   * @param positionTolerance Position error which is tolerable.
+   * @param velocityTolerance Velocity error which is tolerable.
+   */
+  void SetTolerance(
+      double positionTolerance,
+      double velocityTolerance = std::numeric_limits<double>::infinity());
+
+  /**
+   * Returns the difference between the setpoint and the measurement.
+   */
+  double GetPositionError() const;
+
+  /**
+   * Returns the velocity error.
+   */
+  double GetVelocityError() const;
+
+  /**
+   * Returns the next output of the PID controller.
+   *
+   * @param measurement The current measurement of the process variable.
+   */
+  double Calculate(double measurement);
+
+  /**
+   * Returns the next output of the PID controller.
+   *
+   * @param measurement The current measurement of the process variable.
+   * @param setpoint The new setpoint of the controller.
+   */
+  double Calculate(double measurement, double setpoint);
+
+  /**
+   * Reset the previous error and the integral term.
+   */
+  void Reset();
+
+  void InitSendable(wpi::SendableBuilder& builder) override;
+
+ private:
+  // Factor for "proportional" control
+  double m_Kp;
+
+  // Factor for "integral" control
+  double m_Ki;
+
+  // Factor for "derivative" control
+  double m_Kd;
+
+  // The period (in seconds) of the control loop running this controller
+  units::second_t m_period;
+
+  double m_maximumIntegral = 1.0;
+
+  double m_minimumIntegral = -1.0;
+
+  double m_maximumInput = 0;
+
+  double m_minimumInput = 0;
+
+  // Do the endpoints wrap around? e.g. Absolute encoder
+  bool m_continuous = false;
+
+  // The error at the time of the most recent call to Calculate()
+  double m_positionError = 0;
+  double m_velocityError = 0;
+
+  // The error at the time of the second-most-recent call to Calculate() (used
+  // to compute velocity)
+  double m_prevError = 0;
+
+  // The sum of the errors for use in the integral calc
+  double m_totalError = 0;
+
+  // The error that is considered at setpoint.
+  double m_positionTolerance = 0.05;
+  double m_velocityTolerance = std::numeric_limits<double>::infinity();
+
+  double m_setpoint = 0;
+  double m_measurement = 0;
+};
+
+} // namespace frc2
+
+namespace frc {
+
+using frc2::PIDController;  // Also visible as frc::PIDController.
+
+}  // namespace frc
diff --git a/wpimath/src/main/native/include/frc/controller/ProfiledPIDController.h b/wpimath/src/main/native/include/frc/controller/ProfiledPIDController.h
new file mode 100644
index 0000000..e0b10c7
--- /dev/null
+++ b/wpimath/src/main/native/include/frc/controller/ProfiledPIDController.h
@@ -0,0 +1,367 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+#pragma once
+
+#include <algorithm>
+#include <cmath>
+#include <functional>
+#include <limits>
+
+#include <wpi/SymbolExports.h>
+#include <wpi/sendable/Sendable.h>
+#include <wpi/sendable/SendableBuilder.h>
+#include <wpi/sendable/SendableHelper.h>
+
+#include "frc/MathUtil.h"
+#include "frc/controller/PIDController.h"
+#include "frc/trajectory/TrapezoidProfile.h"
+#include "units/time.h"
+
+namespace frc {
+namespace detail {
+WPILIB_DLLEXPORT
+void ReportProfiledPIDController();  // Invoked from the constructor below.
+}  // namespace detail
+
+/**
+ * Implements a PID control loop whose setpoint is constrained by a trapezoid
+ * profile.
+ */
+template <class Distance>
+class ProfiledPIDController
+    : public wpi::Sendable,
+      public wpi::SendableHelper<ProfiledPIDController<Distance>> {
+ public:
+  using Distance_t = units::unit_t<Distance>;
+  using Velocity =
+      units::compound_unit<Distance, units::inverse<units::seconds>>;
+  using Velocity_t = units::unit_t<Velocity>;
+  using Acceleration =
+      units::compound_unit<Velocity, units::inverse<units::seconds>>;
+  using Acceleration_t = units::unit_t<Acceleration>;
+  using State = typename TrapezoidProfile<Distance>::State;
+  using Constraints = typename TrapezoidProfile<Distance>::Constraints;
+
+  /**
+   * Allocates a ProfiledPIDController with the given constants for Kp, Ki, and
+   * Kd. Users should call Reset() when they first start running the controller
+   * to avoid unwanted behavior.
+   *
+   * @param Kp The proportional coefficient.
+   * @param Ki The integral coefficient.
+   * @param Kd The derivative coefficient.
+   * @param constraints Velocity and acceleration constraints for goal.
+   * @param period The period between controller updates in seconds. The
+   *               default is 20 milliseconds.
+   */
+  ProfiledPIDController(double Kp, double Ki, double Kd,
+                        Constraints constraints, units::second_t period = 20_ms)
+      : m_controller(Kp, Ki, Kd, period), m_constraints(constraints) {
+    detail::ReportProfiledPIDController();
+  }
+
+  ~ProfiledPIDController() override = default;
+
+  ProfiledPIDController(const ProfiledPIDController&) = default;
+  ProfiledPIDController& operator=(const ProfiledPIDController&) = default;
+  ProfiledPIDController(ProfiledPIDController&&) = default;
+  ProfiledPIDController& operator=(ProfiledPIDController&&) = default;
+
+  /**
+   * Sets the PID Controller gain parameters.
+   *
+   * Sets the proportional, integral, and derivative coefficients.
+   *
+   * @param Kp Proportional coefficient
+   * @param Ki Integral coefficient
+   * @param Kd Derivative coefficient
+   */
+  void SetPID(double Kp, double Ki, double Kd) {
+    m_controller.SetPID(Kp, Ki, Kd);
+  }
+
+  /**
+   * Sets the proportional coefficient of the PID controller gain.
+   *
+   * @param Kp proportional coefficient
+   */
+  void SetP(double Kp) { m_controller.SetP(Kp); }
+
+  /**
+   * Sets the integral coefficient of the PID controller gain.
+   *
+   * @param Ki integral coefficient
+   */
+  void SetI(double Ki) { m_controller.SetI(Ki); }
+
+  /**
+   * Sets the derivative coefficient of the PID controller gain.
+   *
+   * @param Kd derivative coefficient
+   */
+  void SetD(double Kd) { m_controller.SetD(Kd); }
+
+  /**
+   * Gets the proportional coefficient.
+   *
+   * @return proportional coefficient
+   */
+  double GetP() const { return m_controller.GetP(); }
+
+  /**
+   * Gets the integral coefficient.
+   *
+   * @return integral coefficient
+   */
+  double GetI() const { return m_controller.GetI(); }
+
+  /**
+   * Gets the derivative coefficient.
+   *
+   * @return derivative coefficient
+   */
+  double GetD() const { return m_controller.GetD(); }
+
+  /**
+   * Gets the period of this controller.
+   *
+   * @return The period of the controller.
+   */
+  units::second_t GetPeriod() const { return m_controller.GetPeriod(); }
+
+  /**
+   * Sets the goal for the ProfiledPIDController.
+   *
+   * @param goal The desired unprofiled setpoint.
+   */
+  void SetGoal(State goal) { m_goal = goal; }
+
+  /**
+   * Sets the goal for the ProfiledPIDController.
+   *
+   * @param goal The desired unprofiled setpoint.
+   */
+  void SetGoal(Distance_t goal) { m_goal = {goal, Velocity_t(0)}; }
+
+  /**
+   * Gets the goal for the ProfiledPIDController.
+   */
+  State GetGoal() const { return m_goal; }
+
+  /**
+   * Returns true if the error is within tolerance (see AtSetpoint()) and the
+   * profiled setpoint has reached the goal (m_setpoint == m_goal).
+   * This will return false until at least one input value has been computed.
+   */
+  bool AtGoal() const { return AtSetpoint() && m_goal == m_setpoint; }
+
+  /**
+   * Set velocity and acceleration constraints for goal.
+   *
+   * @param constraints Velocity and acceleration constraints for goal.
+   */
+  void SetConstraints(Constraints constraints) { m_constraints = constraints; }
+
+  /**
+   * Returns the current setpoint of the ProfiledPIDController.
+   *
+   * @return The current setpoint.
+   */
+  State GetSetpoint() const { return m_setpoint; }
+
+  /**
+   * Returns true if the error is within the tolerance of the setpoint.
+   *
+   * Currently this just reports on target as the actual value passes through
+   * the setpoint. Ideally it should be based on being within the tolerance for
+   * some period of time.
+   *
+   * This will return false until at least one input value has been computed.
+   */
+  bool AtSetpoint() const { return m_controller.AtSetpoint(); }
+
+  /**
+   * Enables continuous input.
+   *
+   * Rather than using the max and min input range as constraints, it considers
+   * them to be the same point and automatically calculates the shortest route
+   * to the setpoint.
+   *
+   * @param minimumInput The minimum value expected from the input.
+   * @param maximumInput The maximum value expected from the input.
+   */
+  void EnableContinuousInput(Distance_t minimumInput, Distance_t maximumInput) {
+    m_controller.EnableContinuousInput(minimumInput.value(),
+                                       maximumInput.value());
+    m_minimumInput = minimumInput;
+    m_maximumInput = maximumInput;
+  }
+
+  /**
+   * Disables continuous input.
+   */
+  void DisableContinuousInput() { m_controller.DisableContinuousInput(); }
+
+  /**
+   * Sets the minimum and maximum values for the integrator.
+   *
+   * When the cap is reached, the integrator value is added to the controller
+   * output rather than the integrator value times the integral gain.
+   *
+   * @param minimumIntegral The minimum value of the integrator.
+   * @param maximumIntegral The maximum value of the integrator.
+   */
+  void SetIntegratorRange(double minimumIntegral, double maximumIntegral) {
+    m_controller.SetIntegratorRange(minimumIntegral, maximumIntegral);
+  }
+
+  /**
+   * Sets the error which is considered tolerable for use with AtSetpoint().
+   *
+   * @param positionTolerance Position error which is tolerable.
+   * @param velocityTolerance Velocity error which is tolerable (default: inf).
+   */
+  // Velocity_t can't be copy-initialized from a double, so brace-init it.
+  void SetTolerance(Distance_t positionTolerance,
+                    Velocity_t velocityTolerance =
+                        Velocity_t{std::numeric_limits<double>::infinity()}) {
+    m_controller.SetTolerance(positionTolerance.value(),
+                              velocityTolerance.value());
+  }
+
+  /**
+   * Returns the difference between the setpoint and the measurement.
+   *
+   * @return The error.
+   */
+  Distance_t GetPositionError() const {
+    return Distance_t(m_controller.GetPositionError());
+  }
+
+  /**
+   * Returns the change in error per second.
+   */
+  Velocity_t GetVelocityError() const {
+    return Velocity_t(m_controller.GetVelocityError());
+  }
+
+  /**
+   * Returns the next output of the PID controller.
+   *
+   * @param measurement The current measurement of the process variable.
+   */
+  double Calculate(Distance_t measurement) {
+    if (m_controller.IsContinuousInputEnabled()) {
+      // Get error which is smallest distance between goal and measurement
+      auto errorBound = (m_maximumInput - m_minimumInput) / 2.0;
+      auto goalMinDistance = frc::InputModulus<Distance_t>(
+          m_goal.position - measurement, -errorBound, errorBound);
+      auto setpointMinDistance = frc::InputModulus<Distance_t>(
+          m_setpoint.position - measurement, -errorBound, errorBound);
+
+      // Recompute the profile goal with the smallest error, thus giving the
+      // shortest path. The goal may be outside the input range after this
+      // operation, but that's OK because the controller will still go there and
+      // report an error of zero. In other words, the setpoint only needs to be
+      // offset from the measurement by the input range modulus; they don't need
+      // to be equal.
+      m_goal.position = goalMinDistance + measurement;
+      m_setpoint.position = setpointMinDistance + measurement;
+    }
+
+    frc::TrapezoidProfile<Distance> profile{m_constraints, m_goal, m_setpoint};
+    m_setpoint = profile.Calculate(GetPeriod());
+    return m_controller.Calculate(measurement.value(),
+                                  m_setpoint.position.value());
+  }
+
+  /**
+   * Returns the next output of the PID controller.
+   *
+   * @param measurement The current measurement of the process variable.
+   * @param goal The new goal of the controller.
+   */
+  double Calculate(Distance_t measurement, State goal) {
+    SetGoal(goal);
+    return Calculate(measurement);
+  }
+  /**
+   * Returns the next output of the PID controller.
+   *
+   * @param measurement The current measurement of the process variable.
+   * @param goal The new goal of the controller.
+   */
+  double Calculate(Distance_t measurement, Distance_t goal) {
+    SetGoal(goal);
+    return Calculate(measurement);
+  }
+
+  /**
+   * Returns the next output of the PID controller.
+   *
+   * @param measurement The current measurement of the process variable.
+   * @param goal The new goal of the controller.
+   * @param constraints Velocity and acceleration constraints for goal.
+   */
+  double Calculate(Distance_t measurement,
+                   Distance_t goal,
+                   Constraints constraints) {
+    SetConstraints(constraints);
+    return Calculate(measurement, goal);
+  }
+
+  /**
+   * Reset the previous error and the integral term, and set the profiled
+   * setpoint to the given measured state.
+   * @param measurement The current measured State of the system.
+   */
+  void Reset(const State& measurement) {
+    m_controller.Reset();
+    m_setpoint = measurement;
+  }
+
+  /**
+   * Reset the previous error and the integral term.
+   *
+   * @param measuredPosition The current measured position of the system.
+   * @param measuredVelocity The current measured velocity of the system.
+   */
+  void Reset(Distance_t measuredPosition, Velocity_t measuredVelocity) {
+    Reset(State{measuredPosition, measuredVelocity});
+  }
+
+  /**
+   * Reset the previous error and the integral term.
+   *
+   * @param measuredPosition The current measured position of the system. The
+   * velocity is assumed to be zero.
+   */
+  void Reset(Distance_t measuredPosition) {
+    Reset(measuredPosition, Velocity_t(0));
+  }
+
+  void InitSendable(wpi::SendableBuilder& builder) override {
+    builder.SetSmartDashboardType("ProfiledPIDController");
+    builder.AddDoubleProperty(
+        "p", [this] { return GetP(); }, [this](double value) { SetP(value); });
+    builder.AddDoubleProperty(
+        "i", [this] { return GetI(); }, [this](double value) { SetI(value); });
+    builder.AddDoubleProperty(
+        "d", [this] { return GetD(); }, [this](double value) { SetD(value); });
+    builder.AddDoubleProperty(
+        "goal", [this] { return GetGoal().position.value(); },
+        [this](double value) { SetGoal(Distance_t{value}); });
+  }
+
+ private:
+  frc2::PIDController m_controller;
+  Distance_t m_minimumInput{0};
+  Distance_t m_maximumInput{0};
+  State m_goal;
+  State m_setpoint;
+  Constraints m_constraints;
+};
+
+} // namespace frc
diff --git a/wpimath/src/main/native/include/frc/controller/RamseteController.h b/wpimath/src/main/native/include/frc/controller/RamseteController.h
new file mode 100644
index 0000000..022fff9
--- /dev/null
+++ b/wpimath/src/main/native/include/frc/controller/RamseteController.h
@@ -0,0 +1,120 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+#pragma once
+
+#include <wpi/SymbolExports.h>
+
+#include "frc/geometry/Pose2d.h"
+#include "frc/kinematics/ChassisSpeeds.h"
+#include "frc/trajectory/Trajectory.h"
+#include "units/angular_velocity.h"
+#include "units/velocity.h"
+
+namespace frc {
+
+/**
+ * Ramsete is a nonlinear time-varying feedback controller for unicycle models
+ * that drives the model to a desired pose along a two-dimensional trajectory.
+ * Why use a nonlinear control law in addition to linear ones such as PID? With
+ * PID controllers on the left and right position and velocity states, the
+ * controllers only deal with the local pose. If the robot deviates from the
+ * path, those controllers have no way to correct, and the robot may not reach
+ * the desired global pose. This happens because many different endpoints for
+ * the robot share the same encoder path arc lengths, so encoder feedback alone
+ * cannot tell them apart.
+ *
+ * Instead of using wheel path arc lengths (which are in the robot's local
+ * coordinate frame), nonlinear controllers like pure pursuit and Ramsete use
+ * global pose. The controller uses this extra information to guide a linear
+ * reference tracker like the PID controllers back in by adjusting the
+ * references of the PID controllers.
+ *
+ * The paper "Control of Wheeled Mobile Robots: An Experimental Overview"
+ * describes a nonlinear controller for a wheeled vehicle with unicycle-like
+ * kinematics; a global pose consisting of x, y, and theta; and a desired pose
+ * consisting of x_d, y_d, and theta_d. We call it Ramsete because that's the
+ * acronym for the title of the book it came from in Italian ("Robotica
+ * Articolata e Mobile per i SErvizi e le TEcnologie").
+ *
+ * See <https://file.tavsys.net/control/controls-engineering-in-frc.pdf> section
+ * on Ramsete unicycle controller for a derivation and analysis.
+ */
+class WPILIB_DLLEXPORT RamseteController {
+ public:
+  /**
+   * Construct a Ramsete unicycle controller.
+   *
+   * @param b    Tuning parameter (b > 0) for which larger values make
+   *             convergence more aggressive like a proportional term.
+   * @param zeta Tuning parameter (0 < zeta < 1) for which larger values provide
+   *             more damping in response.
+   */
+  RamseteController(double b, double zeta);
+
+  /**
+   * Construct a Ramsete unicycle controller with b = 2.0 and zeta = 0.7, values
+   * that have been well-tested to produce desirable results. This delegates to
+   * the two-argument constructor.
+   */
+  RamseteController() : RamseteController(2.0, 0.7) {}
+
+  /**
+   * Returns true if the pose error is within tolerance of the reference.
+   */
+  bool AtReference() const;
+
+  /**
+   * Sets the pose error which is considered tolerable for use with
+   * AtReference().
+   *
+   * @param poseTolerance Pose error which is tolerable.
+   */
+  void SetTolerance(const Pose2d& poseTolerance);
+
+  /**
+   * Returns the next output of the Ramsete controller.
+   *
+   * The reference pose, linear velocity, and angular velocity should come from
+   * a drivetrain trajectory.
+   *
+   * @param currentPose        The current pose.
+   * @param poseRef            The desired pose.
+   * @param linearVelocityRef  The desired linear velocity.
+   * @param angularVelocityRef The desired angular velocity.
+   */
+  ChassisSpeeds Calculate(const Pose2d& currentPose, const Pose2d& poseRef,
+                          units::meters_per_second_t linearVelocityRef,
+                          units::radians_per_second_t angularVelocityRef);
+
+  /**
+   * Returns the next output of the Ramsete controller.
+   *
+   * The reference pose, linear velocity, and angular velocity should come from
+   * a drivetrain trajectory.
+   *
+   * @param currentPose  The current pose.
+   * @param desiredState The desired pose, linear velocity, and angular velocity
+   *                     from a trajectory.
+   */
+  ChassisSpeeds Calculate(const Pose2d& currentPose,
+                          const Trajectory::State& desiredState);
+
+  /**
+   * Enables and disables the controller for troubleshooting purposes.
+   *
+   * @param enabled If the controller is enabled or not.
+   */
+  void SetEnabled(bool enabled);
+
+ private:
+  double m_b;
+  double m_zeta;
+
+  Pose2d m_poseError;
+  Pose2d m_poseTolerance;
+  bool m_enabled = true;
+};
+
+} // namespace frc
diff --git a/wpimath/src/main/native/include/frc/controller/SimpleMotorFeedforward.h b/wpimath/src/main/native/include/frc/controller/SimpleMotorFeedforward.h
index 1afab40..df1a52c 100644
--- a/wpimath/src/main/native/include/frc/controller/SimpleMotorFeedforward.h
+++ b/wpimath/src/main/native/include/frc/controller/SimpleMotorFeedforward.h
@@ -1,14 +1,14 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
#include <wpi/MathExtras.h>
+#include "Eigen/Core"
+#include "frc/controller/LinearPlantInversionFeedforward.h"
+#include "frc/system/plant/LinearSystemId.h"
#include "units/time.h"
#include "units/voltage.h"
@@ -55,6 +55,28 @@
return kS * wpi::sgn(velocity) + kV * velocity + kA * acceleration;
}
+  /**
+   * Computes the feedforward for moving from the current velocity setpoint to
+   * the next over dt: kS * sgn(currentVelocity) plus kV/kA plant inversion.
+   *
+   * @param currentVelocity The current velocity setpoint, in distance per
+   *                        second.
+   * @param nextVelocity    The next velocity setpoint, in distance per second.
+   * @param dt              Time between velocity setpoints in seconds.
+   * @return The computed feedforward, in volts.
+   */
+  units::volt_t Calculate(units::unit_t<Velocity> currentVelocity,
+                          units::unit_t<Velocity> nextVelocity,
+                          units::second_t dt) const {
+    // Plant-inversion feedforward on the velocity model identified from kV/kA.
+    auto model = LinearSystemId::IdentifyVelocitySystem<Distance>(kV, kA);
+    LinearPlantInversionFeedforward<1, 1> inversion{model, dt};
+    Eigen::Vector<double, 1> current{currentVelocity.value()};
+    Eigen::Vector<double, 1> next{nextVelocity.value()};
+    units::volt_t modelVoltage{inversion.Calculate(current, next)(0)};
+    return kS * wpi::sgn(currentVelocity.value()) + modelVoltage;
+  }
+
// Rearranging the main equation from the calculate() method yields the
// formulas for the methods below:
diff --git a/wpimath/src/main/native/include/frc/estimator/AngleStatistics.h b/wpimath/src/main/native/include/frc/estimator/AngleStatistics.h
new file mode 100644
index 0000000..3ddabc1
--- /dev/null
+++ b/wpimath/src/main/native/include/frc/estimator/AngleStatistics.h
@@ -0,0 +1,128 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+#pragma once
+
+#include <wpi/numbers>
+
+#include "Eigen/Core"
+#include "frc/MathUtil.h"
+
+namespace frc {
+
+/**
+ * Computes a - b and then wraps the entry at angleStateIdx with AngleModulus()
+ * so that it is treated as an angle difference.
+ *
+ * @tparam States The number of states.
+ * @param a A vector to subtract from.
+ * @param b A vector to subtract with.
+ * @param angleStateIdx The row containing angles to be normalized.
+ */
+template <int States>
+Eigen::Vector<double, States> AngleResidual(
+    const Eigen::Vector<double, States>& a,
+    const Eigen::Vector<double, States>& b, int angleStateIdx) {
+  Eigen::Vector<double, States> residual = a - b;
+  double& angle = residual(angleStateIdx);
+  angle = AngleModulus(units::radian_t{angle}).value();
+  return residual;
+}
+
+/**
+ * Returns a std::function that forwards to AngleResidual(a, b, angleStateIdx)
+ * with the given angle row bound in.
+ *
+ * @tparam States The number of states.
+ * @param angleStateIdx The row containing angles to be normalized.
+ */
+template <int States>
+std::function<Eigen::Vector<double, States>(
+    const Eigen::Vector<double, States>&, const Eigen::Vector<double, States>&)>
+AngleResidual(int angleStateIdx) {
+  return [angleStateIdx](const auto& a, const auto& b) {
+    return AngleResidual<States>(a, b, angleStateIdx);
+  };
+}
+
+/**
+ * Computes a + b and then wraps the entry at angleStateIdx into the range
+ * bounded by -pi and pi using InputModulus().
+ *
+ * @tparam States The number of states.
+ * @param a A vector to add with.
+ * @param b A vector to add with.
+ * @param angleStateIdx The row containing angles to be normalized.
+ */
+template <int States>
+Eigen::Vector<double, States> AngleAdd(const Eigen::Vector<double, States>& a,
+                                       const Eigen::Vector<double, States>& b,
+                                       int angleStateIdx) {
+  Eigen::Vector<double, States> sum = a + b;
+  double& angle = sum(angleStateIdx);
+  angle = InputModulus(angle, -wpi::numbers::pi, wpi::numbers::pi);
+  return sum;
+}
+
+/**
+ * Returns a std::function that forwards to AngleAdd(a, b, angleStateIdx).
+ * @tparam States The number of states.
+ * @param angleStateIdx The row containing angles to be normalized.
+ */
+template <int States>
+std::function<Eigen::Vector<double, States>(
+    const Eigen::Vector<double, States>&, const Eigen::Vector<double, States>&)>
+AngleAdd(int angleStateIdx) {
+  return [angleStateIdx](const auto& a, const auto& b) {
+    return AngleAdd<States>(a, b, angleStateIdx);
+  };
+}
+
+/**
+ * Computes the mean of sigmas with the weights Wm, using a Wm-weighted
+ * circular mean (atan2 of the weighted sine and cosine sums) for one row.
+ *
+ * @tparam CovDim Dimension of covariance of sigma points after passing through
+ * the transform.
+ * @tparam States The number of states.
+ * @param sigmas Sigma points.
+ * @param Wm Weights for the mean.
+ * @param angleStatesIdx The row containing the angles.
+ */
+template <int CovDim, int States>
+Eigen::Vector<double, CovDim> AngleMean(
+    const Eigen::Matrix<double, CovDim, 2 * States + 1>& sigmas,
+    const Eigen::Vector<double, 2 * States + 1>& Wm, int angleStatesIdx) {
+  // Weight each sigma point's sine and cosine by Wm, just as the other rows
+  // are combined by the weighted arithmetic mean (sigmas * Wm) below.
+  double sumSin = (sigmas.row(angleStatesIdx).unaryExpr(
+                       [](auto it) { return std::sin(it); }) * Wm).sum();
+  double sumCos = (sigmas.row(angleStatesIdx).unaryExpr(
+                       [](auto it) { return std::cos(it); }) * Wm).sum();
+
+  Eigen::Vector<double, CovDim> ret = sigmas * Wm;
+  ret[angleStatesIdx] = std::atan2(sumSin, sumCos);
+  return ret;
+}
+
+/**
+ * Returns a std::function that forwards to
+ * AngleMean<CovDim, States>(sigmas, Wm, angleStateIdx).
+ *
+ * @tparam CovDim Dimension of covariance of sigma points after passing through
+ * the transform.
+ * @tparam States The number of states.
+ * @param angleStateIdx The row containing the angles.
+ */
+template <int CovDim, int States>
+std::function<Eigen::Vector<double, CovDim>(
+    const Eigen::Matrix<double, CovDim, 2 * States + 1>&,
+    const Eigen::Vector<double, 2 * States + 1>&)>
+AngleMean(int angleStateIdx) {
+  return [angleStateIdx](const auto& sigmas, const auto& Wm) {
+    return AngleMean<CovDim, States>(sigmas, Wm, angleStateIdx);
+  };
+}
+
+} // namespace frc
diff --git a/wpimath/src/main/native/include/frc/estimator/DifferentialDrivePoseEstimator.h b/wpimath/src/main/native/include/frc/estimator/DifferentialDrivePoseEstimator.h
new file mode 100644
index 0000000..b957c1e
--- /dev/null
+++ b/wpimath/src/main/native/include/frc/estimator/DifferentialDrivePoseEstimator.h
@@ -0,0 +1,242 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+#pragma once
+
+#include <wpi/SymbolExports.h>
+#include <wpi/array.h>
+
+#include "Eigen/Core"
+#include "frc/estimator/KalmanFilterLatencyCompensator.h"
+#include "frc/estimator/UnscentedKalmanFilter.h"
+#include "frc/geometry/Pose2d.h"
+#include "frc/geometry/Rotation2d.h"
+#include "frc/kinematics/DifferentialDriveWheelSpeeds.h"
+#include "units/time.h"
+
+namespace frc {
+/**
+ * This class wraps an Unscented Kalman Filter to fuse latency-compensated
+ * vision measurements with differential drive encoder measurements. It will
+ * correct for noisy vision measurements and encoder drift. It is intended to be
+ * an easy drop-in for DifferentialDriveOdometry. In fact, if you never call
+ * AddVisionMeasurement(), and only call Update(), this will behave exactly the
+ * same as DifferentialDriveOdometry.
+ *
+ * Update() should be called every robot loop (if your robot loops are faster or
+ * slower than the default, then you should change the nominal delta time via
+ * the constructor).
+ *
+ * AddVisionMeasurement() can be called as infrequently as you want; if you
+ * never call it, then this class will behave like regular encoder odometry.
+ *
+ * The state-space system used internally has the following states (x), inputs
+ * (u), and outputs (y):
+ *
+ * <strong> x = [x, y, theta, dist_l, dist_r]ᵀ </strong> in the field coordinate
+ * system containing x position, y position, heading, left encoder distance,
+ * and right encoder distance.
+ *
+ * <strong> u = [v_l, v_r, dtheta]ᵀ </strong> containing left wheel velocity,
+ * right wheel velocity, and change in gyro heading.
+ *
+ * NB: Using velocities makes things considerably easier, because it means that
+ * teams don't have to worry about getting an accurate model. Basically, we
+ * suspect that it's easier for teams to get good encoder data than it is for
+ * them to perform system identification well enough to get a good model.
+ *
+ * <strong> y = [x, y, theta]ᵀ </strong> from vision containing x position, y
+ * position, and heading; or <strong> y = [dist_l, dist_r, theta]ᵀ </strong>
+ * containing left encoder position, right encoder position, and gyro heading.
+ */
+class WPILIB_DLLEXPORT DifferentialDrivePoseEstimator {
+ public:
+  /**
+   * Constructs a DifferentialDrivePoseEstimator.
+   *
+   * @param gyroAngle                The gyro angle of the robot.
+   * @param initialPose              The estimated initial pose.
+   * @param stateStdDevs             Standard deviations of model states.
+   *                                 Increase these numbers to trust your
+   *                                 model's state estimates less. This matrix
+   *                                 is in the form
+   *                                 [x, y, theta, dist_l, dist_r]ᵀ,
+   *                                 with units in meters and radians.
+   * @param localMeasurementStdDevs  Standard deviations of the encoder and gyro
+   *                                 measurements. Increase these numbers to
+   *                                 trust sensor readings from
+   *                                 encoders and gyros less.
+   *                                 This matrix is in the form
+   *                                 [dist_l, dist_r, theta]ᵀ, with units in
+   *                                 meters and radians.
+   * @param visionMeasurementStdDevs Standard deviations of the vision
+   *                                 measurements. Increase these numbers to
+   *                                 trust global measurements from
+   *                                 vision less. This matrix is in the form
+   *                                 [x, y, theta]ᵀ, with units in meters and
+   *                                 radians.
+   * @param nominalDt                The period of the loop calling Update().
+   */
+  DifferentialDrivePoseEstimator(
+      const Rotation2d& gyroAngle, const Pose2d& initialPose,
+      const wpi::array<double, 5>& stateStdDevs,
+      const wpi::array<double, 3>& localMeasurementStdDevs,
+      const wpi::array<double, 3>& visionMeasurementStdDevs,
+      units::second_t nominalDt = 0.02_s);
+
+  /**
+   * Sets the pose estimator's trust of global measurements. This might be used
+   * to change trust in vision measurements after the autonomous period, or to
+   * change trust as distance to a vision target increases.
+   *
+   * @param visionMeasurementStdDevs Standard deviations of the vision
+   *                                 measurements. Increase these numbers to
+   *                                 trust global measurements from vision
+   *                                 less. This matrix is in the form
+   *                                 [x, y, theta]ᵀ, with units in meters and
+   *                                 radians.
+   */
+  void SetVisionMeasurementStdDevs(
+      const wpi::array<double, 3>& visionMeasurementStdDevs);
+
+  /**
+   * Resets the robot's position on the field.
+   *
+   * You NEED to reset your encoders to zero when calling this method. The
+   * gyroscope angle does not need to be reset here on the user's robot code.
+   * The library automatically takes care of offsetting the gyro angle.
+   *
+   * @param pose The estimated pose of the robot on the field.
+   * @param gyroAngle The current gyro angle.
+   */
+  void ResetPosition(const Pose2d& pose, const Rotation2d& gyroAngle);
+
+  /**
+   * Returns the pose of the robot at the current time as estimated by the
+   * Unscented Kalman Filter.
+   *
+   * @return The estimated robot pose.
+   */
+  Pose2d GetEstimatedPosition() const;
+
+  /**
+   * Adds a vision measurement to the Unscented Kalman Filter. This will correct
+   * the odometry pose estimate while still accounting for measurement noise.
+   *
+   * This method can be called as infrequently as you want, as long as you are
+   * calling Update() every loop.
+   *
+   * @param visionRobotPose The pose of the robot as measured by the vision
+   *                        camera.
+   * @param timestamp       The timestamp of the vision measurement in seconds.
+   *                        Note that if you don't use your own time source by
+   *                        calling UpdateWithTime(), then you must use a
+   *                        timestamp with an epoch since FPGA startup (i.e. the
+   *                        epoch of this timestamp is the same epoch as
+   *                        frc::Timer::GetFPGATimestamp()). This means that
+   *                        you should use frc::Timer::GetFPGATimestamp() as
+   *                        your time source in this case.
+   */
+  void AddVisionMeasurement(const Pose2d& visionRobotPose,
+                            units::second_t timestamp);
+
+  /**
+   * Adds a vision measurement to the Unscented Kalman Filter. This will correct
+   * the odometry pose estimate while still accounting for measurement noise.
+   *
+   * This method can be called as infrequently as you want, as long as you are
+   * calling Update() every loop.
+   *
+   * Note that the vision measurement standard deviations passed into this
+   * method will continue to apply to future measurements until a subsequent
+   * call to SetVisionMeasurementStdDevs() or this method.
+   *
+   * @param visionRobotPose          The pose of the robot as measured by the
+   *                                 vision camera.
+   * @param timestamp                The timestamp of the vision measurement in
+   *                                 seconds. Note that if you don't use your
+   *                                 own time source by calling
+   *                                 UpdateWithTime(), then you must use a
+   *                                 timestamp with an epoch since FPGA startup
+   *                                 (i.e. the epoch of this timestamp is the
+   *                                 same epoch as
+   *                                 frc::Timer::GetFPGATimestamp()). This means
+   *                                 that you should use
+   *                                 frc::Timer::GetFPGATimestamp() as your
+   *                                 time source in this case.
+   * @param visionMeasurementStdDevs Standard deviations of the vision
+   *                                 measurements. Increase these numbers to
+   *                                 trust global measurements from vision
+   *                                 less. This matrix is in the form
+   *                                 [x, y, theta]ᵀ, with units in meters and
+   *                                 radians.
+   */
+  void AddVisionMeasurement(
+      const Pose2d& visionRobotPose, units::second_t timestamp,
+      const wpi::array<double, 3>& visionMeasurementStdDevs) {
+    SetVisionMeasurementStdDevs(visionMeasurementStdDevs);
+    AddVisionMeasurement(visionRobotPose, timestamp);
+  }
+
+  /**
+   * Updates the Unscented Kalman Filter using only wheel encoder information.
+   * Note that this should be called every loop iteration.
+   *
+   * @param gyroAngle     The current gyro angle.
+   * @param wheelSpeeds   The velocities of the wheels in meters per second.
+   * @param leftDistance  The distance traveled by the left encoder.
+   * @param rightDistance The distance traveled by the right encoder.
+   *
+   * @return The estimated pose of the robot.
+   */
+  Pose2d Update(const Rotation2d& gyroAngle,
+                const DifferentialDriveWheelSpeeds& wheelSpeeds,
+                units::meter_t leftDistance, units::meter_t rightDistance);
+
+  /**
+   * Updates the Unscented Kalman Filter using only wheel encoder information.
+   * Note that this should be called every loop iteration.
+   *
+   * @param currentTime   The time at which this method was called.
+   * @param gyroAngle     The current gyro angle.
+   * @param wheelSpeeds   The velocities of the wheels in meters per second.
+   * @param leftDistance  The distance traveled by the left encoder.
+   * @param rightDistance The distance traveled by the right encoder.
+   *
+   * @return The estimated pose of the robot.
+   */
+  Pose2d UpdateWithTime(units::second_t currentTime,
+                        const Rotation2d& gyroAngle,
+                        const DifferentialDriveWheelSpeeds& wheelSpeeds,
+                        units::meter_t leftDistance,
+                        units::meter_t rightDistance);
+
+ private:
+  UnscentedKalmanFilter<5, 3, 3> m_observer;
+  KalmanFilterLatencyCompensator<5, 3, 3, UnscentedKalmanFilter<5, 3, 3>>
+      m_latencyCompensator;
+  std::function<void(const Eigen::Vector<double, 3>& u,
+                     const Eigen::Vector<double, 3>& y)>
+      m_visionCorrect;
+
+  Eigen::Matrix<double, 3, 3> m_visionContR;
+
+  units::second_t m_nominalDt;
+  units::second_t m_prevTime = -1_s;
+
+  Rotation2d m_gyroOffset;
+  Rotation2d m_previousAngle;
+
+  template <int Dim>
+  static wpi::array<double, Dim> StdDevMatrixToArray(
+      const Eigen::Vector<double, Dim>& stdDevs);
+
+  static Eigen::Vector<double, 5> F(const Eigen::Vector<double, 5>& x,
+                                    const Eigen::Vector<double, 3>& u);
+  static Eigen::Vector<double, 5> FillStateVector(const Pose2d& pose,
+                                                  units::meter_t leftDistance,
+                                                  units::meter_t rightDistance);
+};
+
+} // namespace frc
diff --git a/wpimath/src/main/native/include/frc/estimator/ExtendedKalmanFilter.h b/wpimath/src/main/native/include/frc/estimator/ExtendedKalmanFilter.h
index 6f0fc85..3e5edb8 100644
--- a/wpimath/src/main/native/include/frc/estimator/ExtendedKalmanFilter.h
+++ b/wpimath/src/main/native/include/frc/estimator/ExtendedKalmanFilter.h
@@ -1,26 +1,48 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
-#include <array>
#include <functional>
+#include <wpi/array.h>
+
+#include "Eigen/Cholesky"
#include "Eigen/Core"
-#include "Eigen/src/Cholesky/LDLT.h"
#include "drake/math/discrete_algebraic_riccati_equation.h"
#include "frc/StateSpaceUtil.h"
#include "frc/system/Discretization.h"
+#include "frc/system/NumericalIntegration.h"
#include "frc/system/NumericalJacobian.h"
-#include "frc/system/RungeKutta.h"
#include "units/time.h"
namespace frc {
+/**
+ * A Kalman filter combines predictions from a model and measurements to give an
+ * estimate of the true system state. This is useful because many states cannot
+ * be measured directly as a result of sensor noise, or because the state is
+ * "hidden".
+ *
+ * Kalman filters use a K gain matrix to determine whether to trust the model or
+ * measurements more. Kalman filter theory uses statistics to compute an optimal
+ * K gain which minimizes the sum of squares error in the state estimate. This K
+ * gain is used to correct the state estimate by some amount of the difference
+ * between the actual measurements and the measurements predicted by the model.
+ *
+ * An extended Kalman filter supports nonlinear state and measurement models. It
+ * propagates the error covariance by linearizing the models around the state
+ * estimate, then applying the linear Kalman filter equations.
+ *
+ * For more on the underlying math, read
+ * https://file.tavsys.net/control/controls-engineering-in-frc.pdf chapter 9
+ * "Stochastic control theory".
+ *
+ * @tparam States The number of states.
+ * @tparam Inputs The number of inputs.
+ * @tparam Outputs The number of outputs.
+ */
template <int States, int Inputs, int Outputs>
class ExtendedKalmanFilter {
public:
@@ -35,30 +57,36 @@
* @param measurementStdDevs Standard deviations of measurements.
* @param dt Nominal discretization timestep.
*/
- ExtendedKalmanFilter(std::function<Eigen::Matrix<double, States, 1>(
- const Eigen::Matrix<double, States, 1>&,
- const Eigen::Matrix<double, Inputs, 1>&)>
+ ExtendedKalmanFilter(std::function<Eigen::Vector<double, States>(
+ const Eigen::Vector<double, States>&,
+ const Eigen::Vector<double, Inputs>&)>
f,
- std::function<Eigen::Matrix<double, Outputs, 1>(
- const Eigen::Matrix<double, States, 1>&,
- const Eigen::Matrix<double, Inputs, 1>&)>
+ std::function<Eigen::Vector<double, Outputs>(
+ const Eigen::Vector<double, States>&,
+ const Eigen::Vector<double, Inputs>&)>
h,
- const std::array<double, States>& stateStdDevs,
- const std::array<double, Outputs>& measurementStdDevs,
+ const wpi::array<double, States>& stateStdDevs,
+ const wpi::array<double, Outputs>& measurementStdDevs,
units::second_t dt)
: m_f(f), m_h(h) {
m_contQ = MakeCovMatrix(stateStdDevs);
m_contR = MakeCovMatrix(measurementStdDevs);
+ m_residualFuncY = [](auto a, auto b) -> Eigen::Vector<double, Outputs> {
+ return a - b;
+ };
+ m_addFuncX = [](auto a, auto b) -> Eigen::Vector<double, States> {
+ return a + b;
+ };
m_dt = dt;
Reset();
Eigen::Matrix<double, States, States> contA =
NumericalJacobianX<States, States, Inputs>(
- m_f, m_xHat, Eigen::Matrix<double, Inputs, 1>::Zero());
+ m_f, m_xHat, Eigen::Vector<double, Inputs>::Zero());
Eigen::Matrix<double, Outputs, States> C =
NumericalJacobianX<Outputs, States, Inputs>(
- m_h, m_xHat, Eigen::Matrix<double, Inputs, 1>::Zero());
+ m_h, m_xHat, Eigen::Vector<double, Inputs>::Zero());
Eigen::Matrix<double, States, States> discA;
Eigen::Matrix<double, States, States> discQ;
@@ -67,10 +95,70 @@
Eigen::Matrix<double, Outputs, Outputs> discR =
DiscretizeR<Outputs>(m_contR, dt);
- // IsStabilizable(A^T, C^T) will tell us if the system is observable.
- bool isObservable =
- IsStabilizable<States, Outputs>(discA.transpose(), C.transpose());
- if (isObservable && Outputs <= States) {
+ if (IsDetectable<States, Outputs>(discA, C) && Outputs <= States) {
+ m_initP = drake::math::DiscreteAlgebraicRiccatiEquation(
+ discA.transpose(), C.transpose(), discQ, discR);
+ } else {
+ m_initP = Eigen::Matrix<double, States, States>::Zero();
+ }
+ m_P = m_initP;
+ }
+
+ /**
+ * Constructs an extended Kalman filter.
+ *
+ * @param f A vector-valued function of x and u that returns
+ * the derivative of the state vector.
+ * @param h A vector-valued function of x and u that returns
+ * the measurement vector.
+ * @param stateStdDevs Standard deviations of model states.
+ * @param measurementStdDevs Standard deviations of measurements.
+ * @param residualFuncY A function that computes the residual of two
+ * measurement vectors (i.e. it subtracts them.)
+ * @param addFuncX A function that adds two state vectors.
+ * @param dt Nominal discretization timestep.
+ */
+ ExtendedKalmanFilter(std::function<Eigen::Vector<double, States>(
+ const Eigen::Vector<double, States>&,
+ const Eigen::Vector<double, Inputs>&)>
+ f,
+ std::function<Eigen::Vector<double, Outputs>(
+ const Eigen::Vector<double, States>&,
+ const Eigen::Vector<double, Inputs>&)>
+ h,
+ const wpi::array<double, States>& stateStdDevs,
+ const wpi::array<double, Outputs>& measurementStdDevs,
+ std::function<Eigen::Vector<double, Outputs>(
+ const Eigen::Vector<double, Outputs>&,
+ const Eigen::Vector<double, Outputs>&)>
+ residualFuncY,
+ std::function<Eigen::Vector<double, States>(
+ const Eigen::Vector<double, States>&,
+ const Eigen::Vector<double, States>&)>
+ addFuncX,
+ units::second_t dt)
+ : m_f(f), m_h(h), m_residualFuncY(residualFuncY), m_addFuncX(addFuncX) {
+ m_contQ = MakeCovMatrix(stateStdDevs);
+ m_contR = MakeCovMatrix(measurementStdDevs);
+ m_dt = dt;
+
+ Reset();
+
+ Eigen::Matrix<double, States, States> contA =
+ NumericalJacobianX<States, States, Inputs>(
+ m_f, m_xHat, Eigen::Vector<double, Inputs>::Zero());
+ Eigen::Matrix<double, Outputs, States> C =
+ NumericalJacobianX<Outputs, States, Inputs>(
+ m_h, m_xHat, Eigen::Vector<double, Inputs>::Zero());
+
+ Eigen::Matrix<double, States, States> discA;
+ Eigen::Matrix<double, States, States> discQ;
+ DiscretizeAQTaylor<States>(contA, m_contQ, dt, &discA, &discQ);
+
+ Eigen::Matrix<double, Outputs, Outputs> discR =
+ DiscretizeR<Outputs>(m_contR, dt);
+
+ if (IsDetectable<States, Outputs>(discA, C) && Outputs <= States) {
m_initP = drake::math::DiscreteAlgebraicRiccatiEquation(
discA.transpose(), C.transpose(), discQ, discR);
} else {
@@ -102,7 +190,7 @@
/**
* Returns the state estimate x-hat.
*/
- const Eigen::Matrix<double, States, 1>& Xhat() const { return m_xHat; }
+  const Eigen::Vector<double, States>& Xhat() const { return this->m_xHat; }
/**
* Returns an element of the state estimate x-hat.
@@ -116,7 +204,7 @@
*
* @param xHat The state estimate x-hat.
*/
- void SetXhat(const Eigen::Matrix<double, States, 1>& xHat) { m_xHat = xHat; }
+  void SetXhat(const Eigen::Vector<double, States>& xHat) { this->m_xHat = xHat; }
/**
* Set an element of the initial state estimate x-hat.
@@ -140,9 +228,7 @@
* @param u New control input from controller.
* @param dt Timestep for prediction.
*/
- void Predict(const Eigen::Matrix<double, Inputs, 1>& u, units::second_t dt) {
- m_dt = dt;
-
+ void Predict(const Eigen::Vector<double, Inputs>& u, units::second_t dt) {
// Find continuous A
Eigen::Matrix<double, States, States> contA =
NumericalJacobianX<States, States, Inputs>(m_f, m_xHat, u);
@@ -152,8 +238,12 @@
Eigen::Matrix<double, States, States> discQ;
DiscretizeAQTaylor<States>(contA, m_contQ, dt, &discA, &discQ);
- m_xHat = RungeKutta(m_f, m_xHat, u, dt);
+ m_xHat = RK4(m_f, m_xHat, u, dt);
+
+ // Pₖ₊₁⁻ = APₖ⁻Aᵀ + Q
m_P = discA * m_P * discA.transpose() + discQ;
+
+ m_dt = dt;
}
/**
@@ -162,9 +252,26 @@
* @param u Same control input used in the predict step.
* @param y Measurement vector.
*/
- void Correct(const Eigen::Matrix<double, Inputs, 1>& u,
- const Eigen::Matrix<double, Outputs, 1>& y) {
- Correct<Outputs>(u, y, m_h, m_contR);
+  void Correct(const Eigen::Vector<double, Inputs>& u,
+               const Eigen::Vector<double, Outputs>& y) {
+    Correct<Outputs>(u, y, m_h, m_contR, m_residualFuncY, m_addFuncX);
+  }
+
+  /// Correct the state estimate using y with a caller-supplied h(x, u) and
+  /// measurement noise covariance R. The residual and state update use plain
+  /// vector subtraction and addition.
+  template <int Rows>
+  void Correct(const Eigen::Vector<double, Inputs>& u,
+               const Eigen::Vector<double, Rows>& y,
+               std::function<Eigen::Vector<double, Rows>(
+                   const Eigen::Vector<double, States>&,
+                   const Eigen::Vector<double, Inputs>&)> h,
+               const Eigen::Matrix<double, Rows, Rows>& R) {
+    auto residualFuncY = [](const auto& a, const auto& b)
+        -> Eigen::Vector<double, Rows> { return a - b; };
+    auto addFuncX = [](const auto& a, const auto& b)
+        -> Eigen::Vector<double, States> { return a + b; };
+    Correct<Rows>(u, y, h, R, residualFuncY, addFuncX);
+  }
/**
@@ -174,55 +281,77 @@
* Correct() call vary. The h(x, u) passed to the constructor is used if one
* is not provided (the two-argument version of this function).
*
- * @param u Same control input used in the predict step.
- * @param y Measurement vector.
- * @param h A vector-valued function of x and u that returns
- * the measurement vector.
- * @param R Discrete measurement noise covariance matrix.
+ * @param u Same control input used in the predict step.
+ * @param y Measurement vector.
+ * @param h A vector-valued function of x and u that returns
+ * the measurement vector.
+ * @param R Discrete measurement noise covariance matrix.
+ * @param residualFuncY A function that computes the residual of two
+ * measurement vectors (i.e. it subtracts them.)
+ * @param addFuncX A function that adds two state vectors.
*/
template <int Rows>
- void Correct(const Eigen::Matrix<double, Inputs, 1>& u,
- const Eigen::Matrix<double, Rows, 1>& y,
- std::function<Eigen::Matrix<double, Rows, 1>(
- const Eigen::Matrix<double, States, 1>&,
- const Eigen::Matrix<double, Inputs, 1>&)>
+ void Correct(const Eigen::Vector<double, Inputs>& u,
+ const Eigen::Vector<double, Rows>& y,
+ std::function<Eigen::Vector<double, Rows>(
+ const Eigen::Vector<double, States>&,
+ const Eigen::Vector<double, Inputs>&)>
h,
- const Eigen::Matrix<double, Rows, Rows>& R) {
+ const Eigen::Matrix<double, Rows, Rows>& R,
+ std::function<Eigen::Vector<double, Rows>(
+ const Eigen::Vector<double, Rows>&,
+ const Eigen::Vector<double, Rows>&)>
+ residualFuncY,
+ std::function<Eigen::Vector<double, States>(
+ const Eigen::Vector<double, States>&,
+ const Eigen::Vector<double, States>)>
+ addFuncX) {
const Eigen::Matrix<double, Rows, States> C =
NumericalJacobianX<Rows, States, Inputs>(h, m_xHat, u);
const Eigen::Matrix<double, Rows, Rows> discR = DiscretizeR<Rows>(R, m_dt);
Eigen::Matrix<double, Rows, Rows> S = C * m_P * C.transpose() + discR;
- // We want to put K = PC^T S^-1 into Ax = b form so we can solve it more
+ // We want to put K = PCᵀS⁻¹ into Ax = b form so we can solve it more
// efficiently.
//
- // K = PC^T S^-1
- // KS = PC^T
- // (KS)^T = (PC^T)^T
- // S^T K^T = CP^T
+ // K = PCᵀS⁻¹
+ // KS = PCᵀ
+ // (KS)ᵀ = (PCᵀ)ᵀ
+ // SᵀKᵀ = CPᵀ
//
// The solution of Ax = b can be found via x = A.solve(b).
//
- // K^T = S^T.solve(CP^T)
- // K = (S^T.solve(CP^T))^T
+ // Kᵀ = Sᵀ.solve(CPᵀ)
+ // K = (Sᵀ.solve(CPᵀ))ᵀ
Eigen::Matrix<double, States, Rows> K =
S.transpose().ldlt().solve(C * m_P.transpose()).transpose();
- m_xHat += K * (y - h(m_xHat, u));
+ // x̂ₖ₊₁⁺ = x̂ₖ₊₁⁻ + K(y − h(x̂ₖ₊₁⁻, uₖ₊₁))
+ m_xHat = addFuncX(m_xHat, K * residualFuncY(y, h(m_xHat, u)));
+
+ // Pₖ₊₁⁺ = (I − KC)Pₖ₊₁⁻
m_P = (Eigen::Matrix<double, States, States>::Identity() - K * C) * m_P;
}
private:
- std::function<Eigen::Matrix<double, States, 1>(
- const Eigen::Matrix<double, States, 1>&,
- const Eigen::Matrix<double, Inputs, 1>&)>
+ std::function<Eigen::Vector<double, States>(
+ const Eigen::Vector<double, States>&,
+ const Eigen::Vector<double, Inputs>&)>
m_f;
- std::function<Eigen::Matrix<double, Outputs, 1>(
- const Eigen::Matrix<double, States, 1>&,
- const Eigen::Matrix<double, Inputs, 1>&)>
+ std::function<Eigen::Vector<double, Outputs>(
+ const Eigen::Vector<double, States>&,
+ const Eigen::Vector<double, Inputs>&)>
m_h;
- Eigen::Matrix<double, States, 1> m_xHat;
+  std::function<Eigen::Vector<double, Outputs>(
+      const Eigen::Vector<double, Outputs>&,
+      const Eigen::Vector<double, Outputs>&)>
+      m_residualFuncY;  // Residual of two measurement vectors.
+  std::function<Eigen::Vector<double, States>(
+      const Eigen::Vector<double, States>&,
+      const Eigen::Vector<double, States>&)>
+      m_addFuncX;  // Adds two state vectors.
+ Eigen::Vector<double, States> m_xHat;
Eigen::Matrix<double, States, States> m_P;
Eigen::Matrix<double, States, States> m_contQ;
Eigen::Matrix<double, Outputs, Outputs> m_contR;
diff --git a/wpimath/src/main/native/include/frc/estimator/KalmanFilter.h b/wpimath/src/main/native/include/frc/estimator/KalmanFilter.h
index c395080..3aa4dbd 100644
--- a/wpimath/src/main/native/include/frc/estimator/KalmanFilter.h
+++ b/wpimath/src/main/native/include/frc/estimator/KalmanFilter.h
@@ -1,17 +1,19 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2018-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
-#include <array>
-#include <cmath>
+#include <frc/fmt/Eigen.h>
+#include <cmath>
+#include <string>
+
+#include <wpi/SymbolExports.h>
+#include <wpi/array.h>
+
+#include "Eigen/Cholesky"
#include "Eigen/Core"
-#include "Eigen/src/Cholesky/LDLT.h"
#include "drake/math/discrete_algebraic_riccati_equation.h"
#include "frc/StateSpaceUtil.h"
#include "frc/system/Discretization.h"
@@ -37,6 +39,10 @@
* For more on the underlying math, read
* https://file.tavsys.net/control/controls-engineering-in-frc.pdf chapter 9
* "Stochastic control theory".
+ *
+ * @tparam States The number of states.
+ * @tparam Inputs The number of inputs.
+ * @tparam Outputs The number of outputs.
*/
template <int States, int Inputs, int Outputs>
class KalmanFilterImpl {
@@ -50,8 +56,8 @@
* @param dt Nominal discretization timestep.
*/
KalmanFilterImpl(LinearSystem<States, Inputs, Outputs>& plant,
- const std::array<double, States>& stateStdDevs,
- const std::array<double, Outputs>& measurementStdDevs,
+ const wpi::array<double, States>& stateStdDevs,
+ const wpi::array<double, Outputs>& measurementStdDevs,
units::second_t dt) {
m_plant = &plant;
@@ -66,34 +72,35 @@
const auto& C = plant.C();
- // IsStabilizable(A^T, C^T) will tell us if the system is observable.
- bool isObservable =
- IsStabilizable<States, Outputs>(discA.transpose(), C.transpose());
- if (!isObservable) {
- wpi::math::MathSharedStore::ReportError(
- "The system passed to the Kalman filter is not observable!");
- throw std::invalid_argument(
- "The system passed to the Kalman filter is not observable!");
+ if (!IsDetectable<States, Outputs>(discA, C)) {
+ std::string msg = fmt::format(
+ "The system passed to the Kalman filter is "
+ "unobservable!\n\nA =\n{}\nC =\n{}\n",
+ discA, C);
+
+ wpi::math::MathSharedStore::ReportError(msg);
+ throw std::invalid_argument(msg);
}
Eigen::Matrix<double, States, States> P =
drake::math::DiscreteAlgebraicRiccatiEquation(
discA.transpose(), C.transpose(), discQ, discR);
+ // S = CPCᵀ + R
Eigen::Matrix<double, Outputs, Outputs> S = C * P * C.transpose() + discR;
- // We want to put K = PC^T S^-1 into Ax = b form so we can solve it more
+ // We want to put K = PCᵀS⁻¹ into Ax = b form so we can solve it more
// efficiently.
//
- // K = PC^T S^-1
- // KS = PC^T
- // (KS)^T = (PC^T)^T
- // S^T K^T = CP^T
+ // K = PCᵀS⁻¹
+ // KS = PCᵀ
+ // (KS)ᵀ = (PCᵀ)ᵀ
+ // SᵀKᵀ = CPᵀ
//
// The solution of Ax = b can be found via x = A.solve(b).
//
- // K^T = S^T.solve(CP^T)
- // K = (S^T.solve(CP^T))^T
+ // Kᵀ = Sᵀ.solve(CPᵀ)
+ // K = (Sᵀ.solve(CPᵀ))ᵀ
m_K = S.transpose().ldlt().solve(C * P.transpose()).transpose();
Reset();
@@ -118,7 +125,7 @@
/**
* Returns the state estimate x-hat.
*/
- const Eigen::Matrix<double, States, 1>& Xhat() const { return m_xHat; }
+  const Eigen::Vector<double, States>& Xhat() const { return this->m_xHat; }
/**
* Returns an element of the state estimate x-hat.
@@ -132,7 +139,7 @@
*
* @param xHat The state estimate x-hat.
*/
- void SetXhat(const Eigen::Matrix<double, States, 1>& xHat) { m_xHat = xHat; }
+  void SetXhat(const Eigen::Vector<double, States>& xHat) { this->m_xHat = xHat; }
/**
* Set an element of the initial state estimate x-hat.
@@ -153,7 +160,7 @@
* @param u New control input from controller.
* @param dt Timestep for prediction.
*/
- void Predict(const Eigen::Matrix<double, Inputs, 1>& u, units::second_t dt) {
+ void Predict(const Eigen::Vector<double, Inputs>& u, units::second_t dt) {
m_xHat = m_plant->CalculateX(m_xHat, u, dt);
}
@@ -163,8 +170,9 @@
* @param u Same control input used in the last predict step.
* @param y Measurement vector.
*/
- void Correct(const Eigen::Matrix<double, Inputs, 1>& u,
- const Eigen::Matrix<double, Outputs, 1>& y) {
+ void Correct(const Eigen::Vector<double, Inputs>& u,
+ const Eigen::Vector<double, Outputs>& y) {
+ // x̂ₖ₊₁⁺ = x̂ₖ₊₁⁻ + K(y − (Cx̂ₖ₊₁⁻ + Duₖ₊₁))
m_xHat += m_K * (y - (m_plant->C() * m_xHat + m_plant->D() * u));
}
@@ -179,7 +187,7 @@
/**
* The state estimate.
*/
- Eigen::Matrix<double, States, 1> m_xHat;
+ Eigen::Vector<double, States> m_xHat;
};
} // namespace detail
@@ -196,8 +204,8 @@
* @param dt Nominal discretization timestep.
*/
KalmanFilter(LinearSystem<States, Inputs, Outputs>& plant,
- const std::array<double, States>& stateStdDevs,
- const std::array<double, Outputs>& measurementStdDevs,
+ const wpi::array<double, States>& stateStdDevs,
+ const wpi::array<double, Outputs>& measurementStdDevs,
units::second_t dt)
: detail::KalmanFilterImpl<States, Inputs, Outputs>{
plant, stateStdDevs, measurementStdDevs, dt} {}
@@ -209,11 +217,12 @@
// Template specializations are used here to make common state-input-output
// triplets compile faster.
template <>
-class KalmanFilter<1, 1, 1> : public detail::KalmanFilterImpl<1, 1, 1> {
+class WPILIB_DLLEXPORT KalmanFilter<1, 1, 1>
+ : public detail::KalmanFilterImpl<1, 1, 1> {
public:
KalmanFilter(LinearSystem<1, 1, 1>& plant,
- const std::array<double, 1>& stateStdDevs,
- const std::array<double, 1>& measurementStdDevs,
+ const wpi::array<double, 1>& stateStdDevs,
+ const wpi::array<double, 1>& measurementStdDevs,
units::second_t dt);
KalmanFilter(KalmanFilter&&) = default;
@@ -223,11 +232,12 @@
// Template specializations are used here to make common state-input-output
// triplets compile faster.
template <>
-class KalmanFilter<2, 1, 1> : public detail::KalmanFilterImpl<2, 1, 1> {
+class WPILIB_DLLEXPORT KalmanFilter<2, 1, 1>
+ : public detail::KalmanFilterImpl<2, 1, 1> {
public:
KalmanFilter(LinearSystem<2, 1, 1>& plant,
- const std::array<double, 2>& stateStdDevs,
- const std::array<double, 1>& measurementStdDevs,
+ const wpi::array<double, 2>& stateStdDevs,
+ const wpi::array<double, 1>& measurementStdDevs,
units::second_t dt);
KalmanFilter(KalmanFilter&&) = default;
diff --git a/wpimath/src/main/native/include/frc/estimator/KalmanFilterLatencyCompensator.h b/wpimath/src/main/native/include/frc/estimator/KalmanFilterLatencyCompensator.h
new file mode 100644
index 0000000..aabb8ec
--- /dev/null
+++ b/wpimath/src/main/native/include/frc/estimator/KalmanFilterLatencyCompensator.h
@@ -0,0 +1,191 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+#pragma once
+
+#include <algorithm>
+#include <array>
+#include <functional>
+#include <iterator>
+#include <utility>
+#include <vector>
+
+#include "Eigen/Core"
+#include "units/math.h"
+#include "units/time.h"
+
+namespace frc {
+
+/**
+ * Keeps a bounded history of observer snapshots so that a global measurement
+ * captured in the past (for example, a vision pose) can be applied at
+ * approximately the time it was taken and then replayed forward to the present.
+ *
+ * @tparam States           The number of states.
+ * @tparam Inputs           The number of inputs.
+ * @tparam Outputs          The number of local outputs.
+ * @tparam KalmanFilterType The observer type. This class calls its Xhat(),
+ *                          P(), SetXhat(), SetP(), Predict(u, dt), and
+ *                          Correct(u, y) members.
+ */
+template <int States, int Inputs, int Outputs, typename KalmanFilterType>
+class KalmanFilterLatencyCompensator {
+ public:
+  /**
+   * The observer's state estimate and error covariance, together with the
+   * input and local measurement associated with one timestep.
+   */
+  struct ObserverSnapshot {
+    Eigen::Vector<double, States> xHat;
+    Eigen::Matrix<double, States, States> errorCovariances;
+    Eigen::Vector<double, Inputs> inputs;
+    Eigen::Vector<double, Outputs> localMeasurements;
+
+    ObserverSnapshot(const KalmanFilterType& observer,
+                     const Eigen::Vector<double, Inputs>& u,
+                     const Eigen::Vector<double, Outputs>& localY)
+        : xHat(observer.Xhat()),
+          errorCovariances(observer.P()),
+          inputs(u),
+          localMeasurements(localY) {}
+  };
+
+  /**
+   * Clears the observer snapshot buffer.
+   */
+  void Reset() { m_pastObserverSnapshots.clear(); }
+
+  /**
+   * Add past observer states to the observer snapshots list.
+   *
+   * ApplyPastGlobalMeasurement() binary-searches the buffer by timestamp, so
+   * states should be added in nondecreasing timestamp order. Only the newest
+   * kMaxPastObserverStates snapshots are retained.
+   *
+   * @param observer  The observer.
+   * @param u         The input at the timestamp.
+   * @param localY    The local output at the timestamp.
+   * @param timestamp The timestamp of the state.
+   */
+  void AddObserverState(const KalmanFilterType& observer,
+                        Eigen::Vector<double, Inputs> u,
+                        Eigen::Vector<double, Outputs> localY,
+                        units::second_t timestamp) {
+    // Add the new state into the vector.
+    m_pastObserverSnapshots.emplace_back(timestamp,
+                                         ObserverSnapshot{observer, u, localY});
+
+    // Remove the oldest snapshot if the vector exceeds our maximum size.
+    if (m_pastObserverSnapshots.size() > kMaxPastObserverStates) {
+      m_pastObserverSnapshots.erase(m_pastObserverSnapshots.begin());
+    }
+  }
+
+  /**
+   * Add past global measurements (such as from vision) to the estimator.
+   *
+   * The observer is rewound to the stored snapshot closest in time to the
+   * measurement, corrected with the measurement, and replayed forward through
+   * every newer snapshot. If the measurement is newer than every snapshot, the
+   * newest snapshot is used; if it is older than every snapshot, the oldest
+   * snapshot is used. Nothing happens if no snapshots have been recorded.
+   *
+   * @tparam Rows The number of rows in the global measurement.
+   * @param observer                 The observer to apply the past global
+   *                                 measurement.
+   * @param nominalDt                The nominal timestep.
+   * @param y                        The measurement.
+   * @param globalMeasurementCorrect The function that calls Correct() on the
+   *                                 observer.
+   * @param timestamp                The timestamp of the measurement.
+   */
+  template <int Rows>
+  void ApplyPastGlobalMeasurement(
+      KalmanFilterType* observer, units::second_t nominalDt,
+      Eigen::Vector<double, Rows> y,
+      std::function<void(const Eigen::Vector<double, Inputs>& u,
+                         const Eigen::Vector<double, Rows>& y)>
+          globalMeasurementCorrect,
+      units::second_t timestamp) {
+    if (m_pastObserverSnapshots.empty()) {
+      // State map was empty, which means that we got a measurement right at
+      // startup. The only thing we can do is ignore the measurement.
+      return;
+    }
+
+    // Binary search for the index of the first snapshot whose timestamp is
+    // greater than or equal to the measurement timestamp. This equals size()
+    // when every snapshot is older than the measurement.
+    auto lowerBoundIter = std::lower_bound(
+        m_pastObserverSnapshots.cbegin(), m_pastObserverSnapshots.cend(),
+        timestamp,
+        [](const auto& entry, const auto& ts) { return entry.first < ts; });
+    size_t nextIndex = static_cast<size_t>(
+        std::distance(m_pastObserverSnapshots.cbegin(), lowerBoundIter));
+
+    // Find the snapshot closest in time to the measurement.
+    size_t indexOfClosestEntry;
+    if (nextIndex == m_pastObserverSnapshots.size()) {
+      // Every snapshot is older than the measurement, so the newest one is
+      // closest. Reading m_pastObserverSnapshots[nextIndex] would be out of
+      // bounds here.
+      indexOfClosestEntry = nextIndex - 1;
+    } else if (nextIndex == 0) {
+      // The oldest snapshot is at or after the measurement.
+      indexOfClosestEntry = 0;
+    } else {
+      // The measurement lies between two snapshots; use the closer one,
+      // preferring the later snapshot on a tie.
+      size_t prevIndex = nextIndex - 1;
+      units::second_t prevDiff = units::math::abs(
+          timestamp - m_pastObserverSnapshots[prevIndex].first);
+      units::second_t nextDiff = units::math::abs(
+          timestamp - m_pastObserverSnapshots[nextIndex].first);
+      indexOfClosestEntry = prevDiff < nextDiff ? prevIndex : nextIndex;
+    }
+
+    units::second_t lastTimestamp =
+        m_pastObserverSnapshots[indexOfClosestEntry].first - nominalDt;
+
+    // We will now go back in time to the state of the system at the time when
+    // the measurement was captured. We will reset the observer to that state,
+    // and apply correction based on the measurement. Then, we will go back
+    // through all observer states until the present and apply past inputs to
+    // get the present estimated state.
+    for (size_t i = indexOfClosestEntry; i < m_pastObserverSnapshots.size();
+         ++i) {
+      auto& [key, snapshot] = m_pastObserverSnapshots[i];
+
+      if (i == indexOfClosestEntry) {
+        observer->SetP(snapshot.errorCovariances);
+        observer->SetXhat(snapshot.xHat);
+      }
+
+      observer->Predict(snapshot.inputs, key - lastTimestamp);
+      observer->Correct(snapshot.inputs, snapshot.localMeasurements);
+
+      if (i == indexOfClosestEntry) {
+        // Note that the measurement is at a timestep close but probably not
+        // exactly equal to the timestep for which we called predict. This makes
+        // the assumption that the dt is small enough that the difference
+        // between the measurement time and the time that the inputs were
+        // captured at is very small.
+        globalMeasurementCorrect(snapshot.inputs, y);
+      }
+
+      lastTimestamp = key;
+      // NOTE(review): this stores the estimate *after* Predict()/Correct() at
+      // `key`, while AddObserverState() stores whatever state the caller
+      // passed in; confirm both describe the same point in time.
+      snapshot = ObserverSnapshot{*observer, snapshot.inputs,
+                                  snapshot.localMeasurements};
+    }
+  }
+
+ private:
+  static constexpr size_t kMaxPastObserverStates = 300;
+  std::vector<std::pair<units::second_t, ObserverSnapshot>>
+      m_pastObserverSnapshots;
+};
+}  // namespace frc
diff --git a/wpimath/src/main/native/include/frc/estimator/MecanumDrivePoseEstimator.h b/wpimath/src/main/native/include/frc/estimator/MecanumDrivePoseEstimator.h
new file mode 100644
index 0000000..93c9f1e
--- /dev/null
+++ b/wpimath/src/main/native/include/frc/estimator/MecanumDrivePoseEstimator.h
@@ -0,0 +1,236 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+#pragma once
+
+#include <functional>
+
+#include <wpi/SymbolExports.h>
+#include <wpi/array.h>
+
+#include "Eigen/Core"
+#include "frc/estimator/KalmanFilterLatencyCompensator.h"
+#include "frc/estimator/UnscentedKalmanFilter.h"
+#include "frc/geometry/Pose2d.h"
+#include "frc/geometry/Rotation2d.h"
+#include "frc/kinematics/MecanumDriveKinematics.h"
+#include "units/time.h"
+
+namespace frc {
+/**
+ * This class wraps an Unscented Kalman Filter to fuse latency-compensated
+ * vision measurements with mecanum drive encoder velocity measurements. It will
+ * correct for noisy measurements and encoder drift. It is intended to be an
+ * easy but more accurate drop-in for MecanumDriveOdometry.
+ *
+ * Update() should be called every robot loop. If your loops are faster or
+ * slower than the default of 0.02s, then you should change the nominal delta
+ * time by specifying it in the constructor.
+ *
+ * AddVisionMeasurement() can be called as infrequently as you want; if you
+ * never call it, then this class will behave mostly like regular encoder
+ * odometry.
+ *
+ * The state-space system used internally has the following states (x), inputs
+ * (u), and outputs (y):
+ *
+ * <strong> x = [x, y, theta]ᵀ </strong> in the field coordinate system
+ * containing x position, y position, and heading.
+ *
+ * <strong> u = [v_x, v_y, omega]ᵀ </strong> containing x velocity, y
+ * velocity, and angular rate in the field coordinate system.
+ *
+ * <strong> y = [x, y, theta]ᵀ </strong> from vision containing x position, y
+ * position, and heading; or <strong> y = [theta]ᵀ </strong> containing gyro
+ * heading.
+ */
+class WPILIB_DLLEXPORT MecanumDrivePoseEstimator {
+ public:
+  /**
+   * Constructs a MecanumDrivePoseEstimator.
+   *
+   * @param gyroAngle                The current gyro angle.
+   * @param initialPose              The starting pose estimate.
+   * @param kinematics               A correctly-configured kinematics object
+   *                                 for your drivetrain.
+   * @param stateStdDevs             Standard deviations of model states.
+   *                                 Increase these numbers to trust your
+   *                                 model's state estimates less. This matrix
+   *                                 is in the form [x, y, theta]ᵀ, with units
+   *                                 in meters and radians.
+   * @param localMeasurementStdDevs  Standard deviations of the encoder and gyro
+   *                                 measurements. Increase these numbers to
+   *                                 trust sensor readings from encoders
+   *                                 and gyros less. This matrix is in the form
+   *                                 [theta], with units in radians.
+   * @param visionMeasurementStdDevs Standard deviations of the vision
+   *                                 measurements. Increase these numbers to
+   *                                 trust global measurements from vision
+   *                                 less. This matrix is in the form
+   *                                 [x, y, theta]ᵀ, with units in meters and
+   *                                 radians.
+   * @param nominalDt                The time in seconds between each robot
+   *                                 loop.
+   */
+  MecanumDrivePoseEstimator(
+      const Rotation2d& gyroAngle, const Pose2d& initialPose,
+      MecanumDriveKinematics kinematics,
+      const wpi::array<double, 3>& stateStdDevs,
+      const wpi::array<double, 1>& localMeasurementStdDevs,
+      const wpi::array<double, 3>& visionMeasurementStdDevs,
+      units::second_t nominalDt = 0.02_s);
+
+  /**
+   * Sets the pose estimator's trust of global measurements. This might be used
+   * to change trust in vision measurements after the autonomous period, or to
+   * change trust as distance to a vision target increases.
+   *
+   * @param visionMeasurementStdDevs Standard deviations of the vision
+   *                                 measurements. Increase these numbers to
+   *                                 trust global measurements from vision
+   *                                 less. This matrix is in the form
+   *                                 [x, y, theta]ᵀ, with units in meters and
+   *                                 radians.
+   */
+  void SetVisionMeasurementStdDevs(
+      const wpi::array<double, 3>& visionMeasurementStdDevs);
+
+  /**
+   * Resets the robot's position on the field.
+   *
+   * <p>You NEED to reset your encoders (to zero) when calling this method.
+   *
+   * <p>The gyroscope angle does not need to be reset in the user's robot code.
+   * The library automatically takes care of offsetting the gyro angle.
+   *
+   * @param pose      The position on the field that your robot is at.
+   * @param gyroAngle The angle reported by the gyroscope.
+   */
+  void ResetPosition(const Pose2d& pose, const Rotation2d& gyroAngle);
+
+  /**
+   * Gets the pose of the robot at the current time as estimated by the
+   * Unscented Kalman Filter.
+   *
+   * @return The estimated robot pose in meters.
+   */
+  Pose2d GetEstimatedPosition() const;
+
+  /**
+   * Add a vision measurement to the Unscented Kalman Filter. This will correct
+   * the odometry pose estimate while still accounting for measurement noise.
+   *
+   * This method can be called as infrequently as you want, as long as you are
+   * calling Update() every loop.
+   *
+   * @param visionRobotPose The pose of the robot as measured by the vision
+   *                        camera.
+   * @param timestamp       The timestamp of the vision measurement in seconds.
+   *                        Note that if you don't use your own time source by
+   *                        calling UpdateWithTime() then you must use a
+   *                        timestamp with an epoch since FPGA startup
+   *                        (i.e. the epoch of this timestamp is the same
+   *                        epoch as Timer#GetFPGATimestamp.) This means
+   *                        that you should use Timer#GetFPGATimestamp as your
+   *                        time source or sync the epochs.
+   */
+  void AddVisionMeasurement(const Pose2d& visionRobotPose,
+                            units::second_t timestamp);
+
+  /**
+   * Adds a vision measurement to the Unscented Kalman Filter. This will correct
+   * the odometry pose estimate while still accounting for measurement noise.
+   *
+   * This method can be called as infrequently as you want, as long as you are
+   * calling Update() every loop.
+   *
+   * Note that the vision measurement standard deviations passed into this
+   * method will continue to apply to future measurements until a subsequent
+   * call to SetVisionMeasurementStdDevs() or this method.
+   *
+   * @param visionRobotPose          The pose of the robot as measured by the
+   *                                 vision camera.
+   * @param timestamp                The timestamp of the vision measurement in
+   *                                 seconds. Note that if you don't use your
+   *                                 own time source by calling
+   *                                 UpdateWithTime(), then you must use a
+   *                                 timestamp with an epoch since FPGA startup
+   *                                 (i.e. the epoch of this timestamp is the
+   *                                 same epoch as
+   *                                 frc::Timer::GetFPGATimestamp()). This means
+   *                                 that you should use
+   *                                 frc::Timer::GetFPGATimestamp() as your
+   *                                 time source in this case.
+   * @param visionMeasurementStdDevs Standard deviations of the vision
+   *                                 measurements. Increase these numbers to
+   *                                 trust global measurements from vision
+   *                                 less. This matrix is in the form
+   *                                 [x, y, theta]ᵀ, with units in meters and
+   *                                 radians.
+   */
+  void AddVisionMeasurement(
+      const Pose2d& visionRobotPose, units::second_t timestamp,
+      const wpi::array<double, 3>& visionMeasurementStdDevs) {
+    SetVisionMeasurementStdDevs(visionMeasurementStdDevs);
+    AddVisionMeasurement(visionRobotPose, timestamp);
+  }
+
+  /**
+   * Updates the Unscented Kalman Filter using only wheel encoder
+   * information. This should be called every loop, and the correct loop period
+   * must be passed into the constructor of this class.
+   *
+   * @param gyroAngle   The current gyro angle.
+   * @param wheelSpeeds The current speeds of the mecanum drive wheels.
+   * @return The estimated pose of the robot in meters.
+   */
+  Pose2d Update(const Rotation2d& gyroAngle,
+                const MecanumDriveWheelSpeeds& wheelSpeeds);
+
+  /**
+   * Updates the Unscented Kalman Filter using only wheel encoder
+   * information. This should be called every loop, and the correct loop period
+   * must be passed into the constructor of this class.
+   *
+   * @param currentTime Time at which this method was called, in seconds.
+   * @param gyroAngle   The current gyroscope angle.
+   * @param wheelSpeeds The current speeds of the mecanum drive wheels.
+   * @return The estimated pose of the robot in meters.
+   */
+  Pose2d UpdateWithTime(units::second_t currentTime,
+                        const Rotation2d& gyroAngle,
+                        const MecanumDriveWheelSpeeds& wheelSpeeds);
+
+ private:
+  UnscentedKalmanFilter<3, 3, 1> m_observer;
+  MecanumDriveKinematics m_kinematics;
+  KalmanFilterLatencyCompensator<3, 3, 1, UnscentedKalmanFilter<3, 3, 1>>
+      m_latencyCompensator;
+  // NOTE(review): if the constructor assigns a lambda that captures `this`
+  // here, a copy of the estimator would still correct the original object's
+  // observer; confirm against the .cpp.
+  std::function<void(const Eigen::Vector<double, 3>& u,
+                     const Eigen::Vector<double, 3>& y)>
+      m_visionCorrect;
+
+  Eigen::Matrix3d m_visionContR;
+
+  units::second_t m_nominalDt;
+  units::second_t m_prevTime = -1_s;
+
+  Rotation2d m_gyroOffset;
+  Rotation2d m_previousAngle;
+
+  template <int Dim>
+  static wpi::array<double, Dim> StdDevMatrixToArray(
+      const Eigen::Vector<double, Dim>& vector) {
+    wpi::array<double, Dim> array;
+    for (int i = 0; i < Dim; ++i) {
+      array[i] = vector(i);
+    }
+    return array;
+  }
+};
+
+}  // namespace frc
diff --git a/wpimath/src/main/native/include/frc/estimator/MerweScaledSigmaPoints.h b/wpimath/src/main/native/include/frc/estimator/MerweScaledSigmaPoints.h
index 72abb80..42f5593 100644
--- a/wpimath/src/main/native/include/frc/estimator/MerweScaledSigmaPoints.h
+++ b/wpimath/src/main/native/include/frc/estimator/MerweScaledSigmaPoints.h
@@ -1,16 +1,13 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
#include <cmath>
+#include "Eigen/Cholesky"
#include "Eigen/Core"
-#include "Eigen/src/Cholesky/LLT.h"
namespace frc {
@@ -22,11 +19,11 @@
* version seen in most publications. Unless you know better, this should be
* your default choice.
*
- * @tparam States The dimensionality of the state. 2*States+1 weights will be
- * generated.
- *
* [1] R. Van der Merwe "Sigma-Point Kalman Filters for Probabilitic
* Inference in Dynamic State-Space Models" (Doctoral dissertation)
+ *
+ * @tparam States The dimensionality of the state. 2*States+1 weights will be
+ * generated.
*/
template <int States>
class MerweScaledSigmaPoints {
@@ -40,8 +37,8 @@
* For Gaussian distributions, beta = 2 is optimal.
* @param kappa Secondary scaling parameter usually set to 0 or 3 - States.
*/
- MerweScaledSigmaPoints(double alpha = 1e-3, double beta = 2,
- int kappa = 3 - States) {
+ explicit MerweScaledSigmaPoints(double alpha = 1e-3, double beta = 2,
+ int kappa = 3 - States) {
m_alpha = alpha;
m_kappa = kappa;
@@ -66,7 +63,7 @@
*
*/
Eigen::Matrix<double, States, 2 * States + 1> SigmaPoints(
- const Eigen::Matrix<double, States, 1>& x,
+ const Eigen::Vector<double, States>& x,
const Eigen::Matrix<double, States, States>& P) {
double lambda = std::pow(m_alpha, 2) * (States + m_kappa) - States;
Eigen::Matrix<double, States, States> U =
@@ -87,7 +84,7 @@
/**
* Returns the weight for each sigma point for the mean.
*/
- const Eigen::Matrix<double, 2 * States + 1, 1>& Wm() const { return m_Wm; }
+ const Eigen::Vector<double, 2 * States + 1>& Wm() const { return m_Wm; }
/**
* Returns an element of the weight for each sigma point for the mean.
@@ -99,7 +96,7 @@
/**
* Returns the weight for each sigma point for the covariance.
*/
- const Eigen::Matrix<double, 2 * States + 1, 1>& Wc() const { return m_Wc; }
+ const Eigen::Vector<double, 2 * States + 1>& Wc() const { return m_Wc; }
/**
* Returns an element of the weight for each sigma point for the covariance.
@@ -109,8 +106,8 @@
double Wc(int i) const { return m_Wc(i, 0); }
private:
- Eigen::Matrix<double, 2 * States + 1, 1> m_Wm;
- Eigen::Matrix<double, 2 * States + 1, 1> m_Wc;
+ Eigen::Vector<double, 2 * States + 1> m_Wm;
+ Eigen::Vector<double, 2 * States + 1> m_Wc;
double m_alpha;
int m_kappa;
@@ -123,8 +120,8 @@
double lambda = std::pow(m_alpha, 2) * (States + m_kappa) - States;
double c = 0.5 / (States + lambda);
- m_Wm = Eigen::Matrix<double, 2 * States + 1, 1>::Constant(c);
- m_Wc = Eigen::Matrix<double, 2 * States + 1, 1>::Constant(c);
+ m_Wm = Eigen::Vector<double, 2 * States + 1>::Constant(c);
+ m_Wc = Eigen::Vector<double, 2 * States + 1>::Constant(c);
m_Wm(0) = lambda / (States + lambda);
m_Wc(0) = lambda / (States + lambda) + (1 - std::pow(m_alpha, 2) + beta);
diff --git a/wpimath/src/main/native/include/frc/estimator/SwerveDrivePoseEstimator.h b/wpimath/src/main/native/include/frc/estimator/SwerveDrivePoseEstimator.h
new file mode 100644
index 0000000..91e0e50
--- /dev/null
+++ b/wpimath/src/main/native/include/frc/estimator/SwerveDrivePoseEstimator.h
@@ -0,0 +1,316 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+#pragma once
+
+#include <limits>
+
+#include <wpi/array.h>
+#include <wpi/timestamp.h>
+
+#include "Eigen/Core"
+#include "frc/StateSpaceUtil.h"
+#include "frc/estimator/AngleStatistics.h"
+#include "frc/estimator/KalmanFilterLatencyCompensator.h"
+#include "frc/estimator/UnscentedKalmanFilter.h"
+#include "frc/geometry/Pose2d.h"
+#include "frc/geometry/Rotation2d.h"
+#include "frc/kinematics/SwerveDriveKinematics.h"
+#include "units/time.h"
+
+namespace frc {
+/**
+ * This class wraps an Unscented Kalman Filter to fuse latency-compensated
+ * vision measurements with swerve drive encoder velocity measurements. It will
+ * correct for noisy measurements and encoder drift. It is intended to be an
+ * easy but more accurate drop-in for SwerveDriveOdometry.
+ *
+ * Update() should be called every robot loop. If your loops are faster or
+ * slower than the default of 0.02s, then you should change the nominal delta
+ * time by specifying it in the constructor.
+ *
+ * AddVisionMeasurement() can be called as infrequently as you want; if you
+ * never call it, then this class will behave mostly like regular encoder
+ * odometry.
+ *
+ * The state-space system used internally has the following states (x), inputs
+ * (u), and outputs (y):
+ *
+ * <strong> x = [x, y, theta]ᵀ </strong> in the field coordinate system
+ * containing x position, y position, and heading.
+ *
+ * <strong> u = [v_x, v_y, omega]ᵀ </strong> containing x velocity, y
+ * velocity, and angular rate in the field coordinate system.
+ *
+ * <strong> y = [x, y, theta]ᵀ </strong> from vision containing x position, y
+ * position, and heading; or <strong> y = [theta]ᵀ </strong> containing gyro
+ * heading.
+ *
+ * @tparam NumModules The number of swerve modules.
+ */
+template <size_t NumModules>
+class SwerveDrivePoseEstimator {
+ public:
+  /**
+   * Constructs a SwerveDrivePoseEstimator.
+   *
+   * @param gyroAngle                The current gyro angle.
+   * @param initialPose              The starting pose estimate.
+   * @param kinematics               A correctly-configured kinematics object
+   *                                 for your drivetrain. It is stored by
+   *                                 reference, so it must outlive this
+   *                                 estimator.
+   * @param stateStdDevs             Standard deviations of model states.
+   *                                 Increase these numbers to trust your
+   *                                 model's state estimates less. This matrix
+   *                                 is in the form [x, y, theta]ᵀ, with units
+   *                                 in meters and radians.
+   * @param localMeasurementStdDevs  Standard deviations of the encoder and gyro
+   *                                 measurements. Increase these numbers to
+   *                                 trust sensor readings from encoders
+   *                                 and gyros less. This matrix is in the form
+   *                                 [theta], with units in radians.
+   * @param visionMeasurementStdDevs Standard deviations of the vision
+   *                                 measurements. Increase these numbers to
+   *                                 trust global measurements from vision
+   *                                 less. This matrix is in the form
+   *                                 [x, y, theta]ᵀ, with units in meters and
+   *                                 radians.
+   * @param nominalDt                The time in seconds between each robot
+   *                                 loop.
+   */
+  SwerveDrivePoseEstimator(
+      const Rotation2d& gyroAngle, const Pose2d& initialPose,
+      SwerveDriveKinematics<NumModules>& kinematics,
+      const wpi::array<double, 3>& stateStdDevs,
+      const wpi::array<double, 1>& localMeasurementStdDevs,
+      const wpi::array<double, 3>& visionMeasurementStdDevs,
+      units::second_t nominalDt = 0.02_s)
+      : m_observer([](const Eigen::Vector<double, 3>& x,
+                      const Eigen::Vector<double, 3>& u) { return u; },
+                   [](const Eigen::Vector<double, 3>& x,
+                      const Eigen::Vector<double, 3>& u) {
+                     return x.block<1, 1>(2, 0);
+                   },
+                   stateStdDevs, localMeasurementStdDevs,
+                   frc::AngleMean<3, 3>(2), frc::AngleMean<1, 3>(0),
+                   frc::AngleResidual<3>(2), frc::AngleResidual<1>(0),
+                   frc::AngleAdd<3>(2), nominalDt),
+        m_kinematics(kinematics),
+        m_nominalDt(nominalDt) {
+    SetVisionMeasurementStdDevs(visionMeasurementStdDevs);
+
+    // Set initial state.
+    m_observer.SetXhat(PoseTo3dVector(initialPose));
+
+    // Calculate offsets.
+    m_gyroOffset = initialPose.Rotation() - gyroAngle;
+    m_previousAngle = initialPose.Rotation();
+  }
+
+  /**
+   * Resets the robot's position on the field.
+   *
+   * You NEED to reset your encoders (to zero) when calling this method.
+   *
+   * The gyroscope angle does not need to be reset in the user's robot code.
+   * The library automatically takes care of offsetting the gyro angle.
+   *
+   * This also resets the observer and discards all stored observer snapshots.
+   *
+   * @param pose      The position on the field that your robot is at.
+   * @param gyroAngle The angle reported by the gyroscope.
+   */
+  void ResetPosition(const Pose2d& pose, const Rotation2d& gyroAngle) {
+    // Reset state estimate and error covariance
+    m_observer.Reset();
+    m_latencyCompensator.Reset();
+
+    m_observer.SetXhat(PoseTo3dVector(pose));
+
+    m_gyroOffset = pose.Rotation() - gyroAngle;
+    m_previousAngle = pose.Rotation();
+  }
+
+  /**
+   * Gets the pose of the robot at the current time as estimated by the
+   * Unscented Kalman Filter.
+   *
+   * @return The estimated robot pose in meters.
+   */
+  Pose2d GetEstimatedPosition() const {
+    return Pose2d(m_observer.Xhat(0) * 1_m, m_observer.Xhat(1) * 1_m,
+                  Rotation2d(units::radian_t{m_observer.Xhat(2)}));
+  }
+
+  /**
+   * Sets the pose estimator's trust of global measurements. This might be used
+   * to change trust in vision measurements after the autonomous period, or to
+   * change trust as distance to a vision target increases.
+   *
+   * @param visionMeasurementStdDevs Standard deviations of the vision
+   *                                 measurements. Increase these numbers to
+   *                                 trust global measurements from vision
+   *                                 less. This matrix is in the form
+   *                                 [x, y, theta]ᵀ, with units in meters and
+   *                                 radians.
+   */
+  void SetVisionMeasurementStdDevs(
+      const wpi::array<double, 3>& visionMeasurementStdDevs) {
+    // Create R (covariances) for vision measurements.
+    m_visionContR = frc::MakeCovMatrix(visionMeasurementStdDevs);
+  }
+
+  /**
+   * Add a vision measurement to the Unscented Kalman Filter. This will correct
+   * the odometry pose estimate while still accounting for measurement noise.
+   *
+   * This method can be called as infrequently as you want, as long as you are
+   * calling Update() every loop.
+   *
+   * @param visionRobotPose The pose of the robot as measured by the vision
+   *                        camera.
+   * @param timestamp       The timestamp of the vision measurement in seconds.
+   *                        Note that if you don't use your own time source by
+   *                        calling UpdateWithTime() then you must use a
+   *                        timestamp with an epoch since FPGA startup
+   *                        (i.e. the epoch of this timestamp is the same
+   *                        epoch as Timer#GetFPGATimestamp.) This means
+   *                        that you should use Timer#GetFPGATimestamp as your
+   *                        time source or sync the epochs.
+   */
+  void AddVisionMeasurement(const Pose2d& visionRobotPose,
+                            units::second_t timestamp) {
+    // Build the correction callback per call instead of storing it in a
+    // member: a stored closure capturing `this` would keep pointing at the
+    // original object after the estimator is copied or moved.
+    auto visionCorrect = [this](const Eigen::Vector<double, 3>& u,
+                                const Eigen::Vector<double, 3>& y) {
+      m_observer.Correct<3>(
+          u, y,
+          [](const Eigen::Vector<double, 3>& x,
+             const Eigen::Vector<double, 3>&) { return x; },
+          m_visionContR, frc::AngleMean<3, 3>(2), frc::AngleResidual<3>(2),
+          frc::AngleResidual<3>(2), frc::AngleAdd<3>(2));
+    };
+
+    m_latencyCompensator.ApplyPastGlobalMeasurement<3>(
+        &m_observer, m_nominalDt, PoseTo3dVector(visionRobotPose),
+        visionCorrect, timestamp);
+  }
+
+  /**
+   * Adds a vision measurement to the Unscented Kalman Filter. This will correct
+   * the odometry pose estimate while still accounting for measurement noise.
+   *
+   * This method can be called as infrequently as you want, as long as you are
+   * calling Update() every loop.
+   *
+   * Note that the vision measurement standard deviations passed into this
+   * method will continue to apply to future measurements until a subsequent
+   * call to SetVisionMeasurementStdDevs() or this method.
+   *
+   * @param visionRobotPose          The pose of the robot as measured by the
+   *                                 vision camera.
+   * @param timestamp                The timestamp of the vision measurement in
+   *                                 seconds. Note that if you don't use your
+   *                                 own time source by calling
+   *                                 UpdateWithTime(), then you must use a
+   *                                 timestamp with an epoch since FPGA startup
+   *                                 (i.e. the epoch of this timestamp is the
+   *                                 same epoch as
+   *                                 frc::Timer::GetFPGATimestamp()). This means
+   *                                 that you should use
+   *                                 frc::Timer::GetFPGATimestamp() as your
+   *                                 time source in this case.
+   * @param visionMeasurementStdDevs Standard deviations of the vision
+   *                                 measurements. Increase these numbers to
+   *                                 trust global measurements from vision
+   *                                 less. This matrix is in the form
+   *                                 [x, y, theta]ᵀ, with units in meters and
+   *                                 radians.
+   */
+  void AddVisionMeasurement(
+      const Pose2d& visionRobotPose, units::second_t timestamp,
+      const wpi::array<double, 3>& visionMeasurementStdDevs) {
+    SetVisionMeasurementStdDevs(visionMeasurementStdDevs);
+    AddVisionMeasurement(visionRobotPose, timestamp);
+  }
+
+  /**
+   * Updates the Unscented Kalman Filter using only wheel encoder
+   * information. This should be called every loop, and the correct loop period
+   * must be passed into the constructor of this class.
+   *
+   * @param gyroAngle    The current gyro angle.
+   * @param moduleStates The current velocities and rotations of the swerve
+   *                     modules.
+   * @return The estimated pose of the robot in meters.
+   */
+  template <typename... ModuleState>
+  Pose2d Update(const Rotation2d& gyroAngle, ModuleState&&... moduleStates) {
+    return UpdateWithTime(units::microsecond_t(wpi::Now()), gyroAngle,
+                          moduleStates...);
+  }
+
+  /**
+   * Updates the Unscented Kalman Filter using only wheel encoder
+   * information. This should be called every loop, and the correct loop period
+   * must be passed into the constructor of this class.
+   *
+   * @param currentTime  Time at which this method was called, in seconds.
+   * @param gyroAngle    The current gyroscope angle.
+   * @param moduleStates The current velocities and rotations of the swerve
+   *                     modules.
+   * @return The estimated pose of the robot in meters.
+   */
+  template <typename... ModuleState>
+  Pose2d UpdateWithTime(units::second_t currentTime,
+                        const Rotation2d& gyroAngle,
+                        ModuleState&&... moduleStates) {
+    // The first update has no previous timestamp, so use the nominal dt.
+    auto dt = m_prevTime >= 0_s ? currentTime - m_prevTime : m_nominalDt;
+    m_prevTime = currentTime;
+
+    auto angle = gyroAngle + m_gyroOffset;
+    auto omega = (angle - m_previousAngle).Radians() / dt;
+
+    // Rotate the chassis velocity by the heading to express it in field frame.
+    auto chassisSpeeds = m_kinematics.ToChassisSpeeds(moduleStates...);
+    auto fieldRelativeSpeeds =
+        Translation2d(chassisSpeeds.vx * 1_s, chassisSpeeds.vy * 1_s)
+            .RotateBy(angle);
+
+    Eigen::Vector<double, 3> u{fieldRelativeSpeeds.X().value(),
+                               fieldRelativeSpeeds.Y().value(), omega.value()};
+
+    Eigen::Vector<double, 1> localY{angle.Radians().value()};
+    m_previousAngle = angle;
+
+    m_latencyCompensator.AddObserverState(m_observer, u, localY, currentTime);
+
+    m_observer.Predict(u, dt);
+    m_observer.Correct(u, localY);
+
+    return GetEstimatedPosition();
+  }
+
+ private:
+  UnscentedKalmanFilter<3, 3, 1> m_observer;
+  // Stored by reference; the caller's kinematics object must outlive this
+  // estimator.
+  SwerveDriveKinematics<NumModules>& m_kinematics;
+  KalmanFilterLatencyCompensator<3, 3, 1, UnscentedKalmanFilter<3, 3, 1>>
+      m_latencyCompensator;
+
+  Eigen::Matrix3d m_visionContR;
+
+  units::second_t m_nominalDt;
+  units::second_t m_prevTime = -1_s;
+
+  Rotation2d m_gyroOffset;
+  Rotation2d m_previousAngle;
+};
+
+}  // namespace frc
diff --git a/wpimath/src/main/native/include/frc/estimator/UnscentedKalmanFilter.h b/wpimath/src/main/native/include/frc/estimator/UnscentedKalmanFilter.h
index 8c2c31f..3aa3e59 100644
--- a/wpimath/src/main/native/include/frc/estimator/UnscentedKalmanFilter.h
+++ b/wpimath/src/main/native/include/frc/estimator/UnscentedKalmanFilter.h
@@ -1,27 +1,49 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
-#include <array>
#include <functional>
+#include <wpi/array.h>
+
+#include "Eigen/Cholesky"
#include "Eigen/Core"
-#include "Eigen/src/Cholesky/LDLT.h"
#include "frc/StateSpaceUtil.h"
#include "frc/estimator/MerweScaledSigmaPoints.h"
#include "frc/estimator/UnscentedTransform.h"
#include "frc/system/Discretization.h"
+#include "frc/system/NumericalIntegration.h"
#include "frc/system/NumericalJacobian.h"
-#include "frc/system/RungeKutta.h"
#include "units/time.h"
namespace frc {
+/**
+ * A Kalman filter combines predictions from a model and measurements to give an
+ * estimate of the true system state. This is useful because many states cannot
+ * be measured directly as a result of sensor noise, or because the state is
+ * "hidden".
+ *
+ * Kalman filters use a K gain matrix to determine whether to trust the model or
+ * measurements more. Kalman filter theory uses statistics to compute an optimal
+ * K gain which minimizes the sum of squares error in the state estimate. This K
+ * gain is used to correct the state estimate by some amount of the difference
+ * between the actual measurements and the measurements predicted by the model.
+ *
+ * An unscented Kalman filter uses nonlinear state and measurement models. It
+ * propagates the error covariance using sigma points chosen to approximate the
+ * true probability distribution.
+ *
+ * For more on the underlying math, read
+ * https://file.tavsys.net/control/controls-engineering-in-frc.pdf chapter 9
+ * "Stochastic control theory".
+ *
+ * @tparam States The number of states.
+ * @tparam Inputs The number of inputs.
+ * @tparam Outputs The number of outputs.
+ */
template <int States, int Inputs, int Outputs>
class UnscentedKalmanFilter {
public:
@@ -36,20 +58,105 @@
* @param measurementStdDevs Standard deviations of measurements.
* @param dt Nominal discretization timestep.
*/
- UnscentedKalmanFilter(std::function<Eigen::Matrix<double, States, 1>(
- const Eigen::Matrix<double, States, 1>&,
- const Eigen::Matrix<double, Inputs, 1>&)>
+ UnscentedKalmanFilter(std::function<Eigen::Vector<double, States>(
+ const Eigen::Vector<double, States>&,
+ const Eigen::Vector<double, Inputs>&)>
f,
- std::function<Eigen::Matrix<double, Outputs, 1>(
- const Eigen::Matrix<double, States, 1>&,
- const Eigen::Matrix<double, Inputs, 1>&)>
+ std::function<Eigen::Vector<double, Outputs>(
+ const Eigen::Vector<double, States>&,
+ const Eigen::Vector<double, Inputs>&)>
h,
- const std::array<double, States>& stateStdDevs,
- const std::array<double, Outputs>& measurementStdDevs,
+ const wpi::array<double, States>& stateStdDevs,
+ const wpi::array<double, Outputs>& measurementStdDevs,
units::second_t dt)
: m_f(f), m_h(h) {
m_contQ = MakeCovMatrix(stateStdDevs);
m_contR = MakeCovMatrix(measurementStdDevs);
+ m_meanFuncX = [](auto sigmas, auto Wm) -> Eigen::Vector<double, States> {
+ return sigmas * Wm;
+ };
+ m_meanFuncY = [](auto sigmas, auto Wc) -> Eigen::Vector<double, Outputs> {
+ return sigmas * Wc;
+ };
+ m_residualFuncX = [](auto a, auto b) -> Eigen::Vector<double, States> {
+ return a - b;
+ };
+ m_residualFuncY = [](auto a, auto b) -> Eigen::Vector<double, Outputs> {
+ return a - b;
+ };
+ m_addFuncX = [](auto a, auto b) -> Eigen::Vector<double, States> {
+ return a + b;
+ };
+ m_dt = dt;
+
+ Reset();
+ }
+
+ /**
+ * Constructs an unscented Kalman filter with custom mean, residual, and
+ * addition functions. Using custom functions for arithmetic can be useful if
+ * you have angles in the state or measurements, because they allow you to
+ * correctly account for the modular nature of angle arithmetic.
+ *
+ * @param f A vector-valued function of x and u that returns
+ * the derivative of the state vector.
+ * @param h A vector-valued function of x and u that returns
+ * the measurement vector.
+ * @param stateStdDevs Standard deviations of model states.
+ * @param measurementStdDevs Standard deviations of measurements.
+ * @param meanFuncX A function that computes the mean of 2 * States +
+ * 1 state vectors using a given set of weights.
+ * @param meanFuncY A function that computes the mean of 2 * States +
+ * 1 measurement vectors using a given set of
+ * weights.
+ * @param residualFuncX A function that computes the residual of two
+ * state vectors (i.e. it subtracts them.)
+ * @param residualFuncY A function that computes the residual of two
+ * measurement vectors (i.e. it subtracts them.)
+ * @param addFuncX A function that adds two state vectors.
+ * @param dt Nominal discretization timestep.
+ */
+ UnscentedKalmanFilter(
+ std::function<
+ Eigen::Vector<double, States>(const Eigen::Vector<double, States>&,
+ const Eigen::Vector<double, Inputs>&)>
+ f,
+ std::function<
+ Eigen::Vector<double, Outputs>(const Eigen::Vector<double, States>&,
+ const Eigen::Vector<double, Inputs>&)>
+ h,
+ const wpi::array<double, States>& stateStdDevs,
+ const wpi::array<double, Outputs>& measurementStdDevs,
+ std::function<Eigen::Vector<double, States>(
+ const Eigen::Matrix<double, States, 2 * States + 1>&,
+ const Eigen::Vector<double, 2 * States + 1>&)>
+ meanFuncX,
+ std::function<Eigen::Vector<double, Outputs>(
+ const Eigen::Matrix<double, Outputs, 2 * States + 1>&,
+ const Eigen::Vector<double, 2 * States + 1>&)>
+ meanFuncY,
+ std::function<
+ Eigen::Vector<double, States>(const Eigen::Vector<double, States>&,
+ const Eigen::Vector<double, States>&)>
+ residualFuncX,
+ std::function<
+ Eigen::Vector<double, Outputs>(const Eigen::Vector<double, Outputs>&,
+ const Eigen::Vector<double, Outputs>&)>
+ residualFuncY,
+ std::function<
+ Eigen::Vector<double, States>(const Eigen::Vector<double, States>&,
+ const Eigen::Vector<double, States>&)>
+ addFuncX,
+ units::second_t dt)
+ : m_f(f),
+ m_h(h),
+ m_meanFuncX(meanFuncX),
+ m_meanFuncY(meanFuncY),
+ m_residualFuncX(residualFuncX),
+ m_residualFuncY(residualFuncY),
+ m_addFuncX(addFuncX) {
+ m_contQ = MakeCovMatrix(stateStdDevs);
+ m_contR = MakeCovMatrix(measurementStdDevs);
m_dt = dt;
Reset();
@@ -78,7 +185,7 @@
/**
* Returns the state estimate x-hat.
*/
- const Eigen::Matrix<double, States, 1>& Xhat() const { return m_xHat; }
+ const Eigen::Vector<double, States>& Xhat() const { return m_xHat; }
/**
* Returns an element of the state estimate x-hat.
@@ -92,7 +199,7 @@
*
* @param xHat The state estimate x-hat.
*/
- void SetXhat(const Eigen::Matrix<double, States, 1>& xHat) { m_xHat = xHat; }
+ void SetXhat(const Eigen::Vector<double, States>& xHat) { m_xHat = xHat; }
/**
* Set an element of the initial state estimate x-hat.
@@ -117,7 +224,7 @@
* @param u New control input from controller.
* @param dt Timestep for prediction.
*/
- void Predict(const Eigen::Matrix<double, Inputs, 1>& u, units::second_t dt) {
+ void Predict(const Eigen::Vector<double, Inputs>& u, units::second_t dt) {
m_dt = dt;
// Discretize Q before projecting mean and covariance forward
@@ -131,13 +238,12 @@
m_pts.SigmaPoints(m_xHat, m_P);
for (int i = 0; i < m_pts.NumSigmas(); ++i) {
- Eigen::Matrix<double, States, 1> x =
- sigmas.template block<States, 1>(0, i);
- m_sigmasF.template block<States, 1>(0, i) = RungeKutta(m_f, x, u, dt);
+ Eigen::Vector<double, States> x = sigmas.template block<States, 1>(0, i);
+ m_sigmasF.template block<States, 1>(0, i) = RK4(m_f, x, u, dt);
}
- auto ret =
- UnscentedTransform<States, States>(m_sigmasF, m_pts.Wm(), m_pts.Wc());
+ auto ret = UnscentedTransform<States, States>(
+ m_sigmasF, m_pts.Wm(), m_pts.Wc(), m_meanFuncX, m_residualFuncX);
m_xHat = std::get<0>(ret);
m_P = std::get<1>(ret);
@@ -150,9 +256,10 @@
* @param u Same control input used in the predict step.
* @param y Measurement vector.
*/
- void Correct(const Eigen::Matrix<double, Inputs, 1>& u,
- const Eigen::Matrix<double, Outputs, 1>& y) {
- Correct<Outputs>(u, y, m_h, m_contR);
+ void Correct(const Eigen::Vector<double, Inputs>& u,
+ const Eigen::Vector<double, Outputs>& y) {
+ Correct<Outputs>(u, y, m_h, m_contR, m_meanFuncY, m_residualFuncY,
+ m_residualFuncX, m_addFuncX);
}
/**
@@ -164,18 +271,78 @@
*
* @param u Same control input used in the predict step.
* @param y Measurement vector.
- * @param h A vector-valued function of x and u that returns
- * the measurement vector.
- * @param R Measurement noise covariance matrix.
+ * @param h A vector-valued function of x and u that returns the measurement
+ * vector.
+ * @param R Measurement noise covariance matrix (continuous-time).
*/
template <int Rows>
- void Correct(const Eigen::Matrix<double, Inputs, 1>& u,
- const Eigen::Matrix<double, Rows, 1>& y,
- std::function<Eigen::Matrix<double, Rows, 1>(
- const Eigen::Matrix<double, States, 1>&,
- const Eigen::Matrix<double, Inputs, 1>&)>
+ void Correct(const Eigen::Vector<double, Inputs>& u,
+ const Eigen::Vector<double, Rows>& y,
+ std::function<Eigen::Vector<double, Rows>(
+ const Eigen::Vector<double, States>&,
+ const Eigen::Vector<double, Inputs>&)>
h,
const Eigen::Matrix<double, Rows, Rows>& R) {
+ auto meanFuncY = [](auto sigmas, auto Wc) -> Eigen::Vector<double, Rows> {
+ return sigmas * Wc;
+ };
+ auto residualFuncX = [](auto a, auto b) -> Eigen::Vector<double, States> {
+ return a - b;
+ };
+ auto residualFuncY = [](auto a, auto b) -> Eigen::Vector<double, Rows> {
+ return a - b;
+ };
+ auto addFuncX = [](auto a, auto b) -> Eigen::Vector<double, States> {
+ return a + b;
+ };
+ Correct<Rows>(u, y, h, R, meanFuncY, residualFuncY, residualFuncX,
+ addFuncX);
+ }
+
+ /**
+ * Correct the state estimate x-hat using the measurements in y.
+ *
+ * This is useful for when the measurements available during a timestep's
+ * Correct() call vary. The h(x, u) passed to the constructor is used if one
+ * is not provided (the two-argument version of this function).
+ *
+ * @param u Same control input used in the predict step.
+ * @param y Measurement vector.
+ * @param h A vector-valued function of x and u that returns the
+ * measurement vector.
+ * @param R Measurement noise covariance matrix (continuous-time).
+ * @param meanFuncY A function that computes the mean of 2 * States + 1
+ * measurement vectors using a given set of weights.
+ * @param residualFuncY A function that computes the residual of two
+ * measurement vectors (i.e. it subtracts them.)
+ * @param residualFuncX A function that computes the residual of two state
+ * vectors (i.e. it subtracts them.)
+ * @param addFuncX A function that adds two state vectors.
+ */
+ template <int Rows>
+ void Correct(const Eigen::Vector<double, Inputs>& u,
+ const Eigen::Vector<double, Rows>& y,
+ std::function<Eigen::Vector<double, Rows>(
+ const Eigen::Vector<double, States>&,
+ const Eigen::Vector<double, Inputs>&)>
+ h,
+ const Eigen::Matrix<double, Rows, Rows>& R,
+ std::function<Eigen::Vector<double, Rows>(
+ const Eigen::Matrix<double, Rows, 2 * States + 1>&,
+ const Eigen::Vector<double, 2 * States + 1>&)>
+ meanFuncY,
+ std::function<Eigen::Vector<double, Rows>(
+ const Eigen::Vector<double, Rows>&,
+ const Eigen::Vector<double, Rows>&)>
+ residualFuncY,
+ std::function<Eigen::Vector<double, States>(
+ const Eigen::Vector<double, States>&,
+ const Eigen::Vector<double, States>&)>
+ residualFuncX,
+ std::function<Eigen::Vector<double, States>(
+ const Eigen::Vector<double, States>&,
+ const Eigen::Vector<double, States>)>
+ addFuncX) {
const Eigen::Matrix<double, Rows, Rows> discR = DiscretizeR<Rows>(R, m_dt);
// Transform sigma points into measurement space
@@ -188,41 +355,67 @@
}
// Mean and covariance of prediction passed through UT
- auto [yHat, Py] =
- UnscentedTransform<States, Rows>(sigmasH, m_pts.Wm(), m_pts.Wc());
+ auto [yHat, Py] = UnscentedTransform<Rows, States>(
+ sigmasH, m_pts.Wm(), m_pts.Wc(), meanFuncY, residualFuncY);
Py += discR;
// Compute cross covariance of the state and the measurements
Eigen::Matrix<double, States, Rows> Pxy;
Pxy.setZero();
for (int i = 0; i < m_pts.NumSigmas(); ++i) {
- Pxy += m_pts.Wc(i) *
- (m_sigmasF.template block<States, 1>(0, i) - m_xHat) *
- (sigmasH.template block<Rows, 1>(0, i) - yHat).transpose();
+ // Pxy += (sigmas_f[:, i] - x̂)(sigmas_h[:, i] - ŷ)ᵀ W_c[i]
+ Pxy +=
+ m_pts.Wc(i) *
+ (residualFuncX(m_sigmasF.template block<States, 1>(0, i), m_xHat)) *
+ (residualFuncY(sigmasH.template block<Rows, 1>(0, i), yHat))
+ .transpose();
}
- // K = P_{xy} Py^-1
- // K^T = P_y^T^-1 P_{xy}^T
- // P_y^T K^T = P_{xy}^T
- // K^T = P_y^T.solve(P_{xy}^T)
- // K = (P_y^T.solve(P_{xy}^T)^T
+ // K = P_{xy} P_y⁻¹
+ // Kᵀ = P_yᵀ⁻¹ P_{xy}ᵀ
+ // P_yᵀKᵀ = P_{xy}ᵀ
+ // Kᵀ = P_yᵀ.solve(P_{xy}ᵀ)
+ // K = (P_yᵀ.solve(P_{xy}ᵀ)ᵀ
Eigen::Matrix<double, States, Rows> K =
Py.transpose().ldlt().solve(Pxy.transpose()).transpose();
- m_xHat += K * (y - yHat);
+ // x̂ₖ₊₁⁺ = x̂ₖ₊₁⁻ + K(y − ŷ)
+ m_xHat = addFuncX(m_xHat, K * residualFuncY(y, yHat));
+
+ // Pₖ₊₁⁺ = Pₖ₊₁⁻ − KP_yKᵀ
m_P -= K * Py * K.transpose();
}
private:
- std::function<Eigen::Matrix<double, States, 1>(
- const Eigen::Matrix<double, States, 1>&,
- const Eigen::Matrix<double, Inputs, 1>&)>
+ std::function<Eigen::Vector<double, States>(
+ const Eigen::Vector<double, States>&,
+ const Eigen::Vector<double, Inputs>&)>
m_f;
- std::function<Eigen::Matrix<double, Outputs, 1>(
- const Eigen::Matrix<double, States, 1>&,
- const Eigen::Matrix<double, Inputs, 1>&)>
+ std::function<Eigen::Vector<double, Outputs>(
+ const Eigen::Vector<double, States>&,
+ const Eigen::Vector<double, Inputs>&)>
m_h;
- Eigen::Matrix<double, States, 1> m_xHat;
+ std::function<Eigen::Vector<double, States>(
+ const Eigen::Matrix<double, States, 2 * States + 1>&,
+ const Eigen::Vector<double, 2 * States + 1>&)>
+ m_meanFuncX;
+ std::function<Eigen::Vector<double, Outputs>(
+ const Eigen::Matrix<double, Outputs, 2 * States + 1>&,
+ const Eigen::Vector<double, 2 * States + 1>&)>
+ m_meanFuncY;
+ std::function<Eigen::Vector<double, States>(
+ const Eigen::Vector<double, States>&,
+ const Eigen::Vector<double, States>&)>
+ m_residualFuncX;
+ std::function<Eigen::Vector<double, Outputs>(
+ const Eigen::Vector<double, Outputs>&,
+ const Eigen::Vector<double, Outputs>)>
+ m_residualFuncY;
+ std::function<Eigen::Vector<double, States>(
+ const Eigen::Vector<double, States>&,
+ const Eigen::Vector<double, States>)>
+ m_addFuncX;
+ Eigen::Vector<double, States> m_xHat;
Eigen::Matrix<double, States, States> m_P;
Eigen::Matrix<double, States, States> m_contQ;
Eigen::Matrix<double, Outputs, Outputs> m_contR;
diff --git a/wpimath/src/main/native/include/frc/estimator/UnscentedTransform.h b/wpimath/src/main/native/include/frc/estimator/UnscentedTransform.h
index 22b32ce..2df0a47 100644
--- a/wpimath/src/main/native/include/frc/estimator/UnscentedTransform.h
+++ b/wpimath/src/main/native/include/frc/estimator/UnscentedTransform.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
@@ -14,37 +11,51 @@
namespace frc {
/**
- * Computes unscented transform of a set of sigma points and weights. CovDimurns
- * the mean and covariance in a tuple.
+ * Computes unscented transform of a set of sigma points and weights. Returns
+ * the mean and covariance in a tuple.
*
* This works in conjunction with the UnscentedKalmanFilter class.
*
- * @tparam States Number of states.
- * @tparam CovDim Dimension of covariance of sigma points after passing through
- * the transform.
- * @param sigmas List of sigma points.
- * @param Wm Weights for the mean.
- * @param Wc Weights for the covariance.
+ * @tparam CovDim Dimension of covariance of sigma points after passing
+ * through the transform.
+ * @tparam States Number of states.
+ * @param sigmas List of sigma points.
+ * @param Wm Weights for the mean.
+ * @param Wc Weights for the covariance.
+ * @param meanFunc A function that computes the mean of 2 * States + 1 state
+ * vectors using a given set of weights.
+ * @param residualFunc A function that computes the residual of two state
+ * vectors (i.e. it subtracts them.)
*
* @return Tuple of x, mean of sigma points; P, covariance of sigma points after
* passing through the transform.
*/
-template <int States, int CovDim>
-std::tuple<Eigen::Matrix<double, CovDim, 1>,
- Eigen::Matrix<double, CovDim, CovDim>>
+template <int CovDim, int States>
+std::tuple<Eigen::Vector<double, CovDim>, Eigen::Matrix<double, CovDim, CovDim>>
UnscentedTransform(const Eigen::Matrix<double, CovDim, 2 * States + 1>& sigmas,
- const Eigen::Matrix<double, 2 * States + 1, 1>& Wm,
- const Eigen::Matrix<double, 2 * States + 1, 1>& Wc) {
- // New mean is just the sum of the sigmas * weight
- // dot = \Sigma^n_1 (W[k]*Xi[k])
- Eigen::Matrix<double, CovDim, 1> x = sigmas * Wm;
+ const Eigen::Vector<double, 2 * States + 1>& Wm,
+ const Eigen::Vector<double, 2 * States + 1>& Wc,
+ std::function<Eigen::Vector<double, CovDim>(
+ const Eigen::Matrix<double, CovDim, 2 * States + 1>&,
+ const Eigen::Vector<double, 2 * States + 1>&)>
+ meanFunc,
+ std::function<Eigen::Vector<double, CovDim>(
+ const Eigen::Vector<double, CovDim>&,
+ const Eigen::Vector<double, CovDim>&)>
+ residualFunc) {
+ // New mean is usually just the sum of the sigmas * weight:
+ // n
+ // dot = Σ W[k] Xᵢ[k]
+ // k=1
+ Eigen::Vector<double, CovDim> x = meanFunc(sigmas, Wm);
// New covariance is the sum of the outer product of the residuals times the
// weights
Eigen::Matrix<double, CovDim, 2 * States + 1> y;
for (int i = 0; i < 2 * States + 1; ++i) {
+ // y[:, i] = sigmas[:, i] - x
y.template block<CovDim, 1>(0, i) =
- sigmas.template block<CovDim, 1>(0, i) - x;
+ residualFunc(sigmas.template block<CovDim, 1>(0, i), x);
}
Eigen::Matrix<double, CovDim, CovDim> P =
y * Eigen::DiagonalMatrix<double, 2 * States + 1>(Wc) * y.transpose();
diff --git a/wpimath/src/main/native/include/frc/filter/LinearFilter.h b/wpimath/src/main/native/include/frc/filter/LinearFilter.h
new file mode 100644
index 0000000..92d8bdc
--- /dev/null
+++ b/wpimath/src/main/native/include/frc/filter/LinearFilter.h
@@ -0,0 +1,303 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+#pragma once
+
+#include <algorithm>
+#include <cmath>
+#include <initializer_list>
+#include <stdexcept>
+#include <vector>
+
+#include <wpi/circular_buffer.h>
+#include <wpi/span.h>
+
+#include "Eigen/Core"
+#include "Eigen/QR"
+#include "units/time.h"
+#include "wpimath/MathShared.h"
+
+namespace frc {
+
+/**
+ * This class implements a linear, digital filter. All types of FIR and IIR
+ * filters are supported. Static factory methods are provided to create commonly
+ * used types of filters.
+ *
+ * Filters are of the form:<br>
+ *  y[n] = (b0 x[n] + b1 x[n-1] + … + bP x[n-P]) -
+ *         (a1 y[n-1] + a2 y[n-2] + … + aQ y[n-Q])
+ *
+ * Where:<br>
+ *  y[n] is the output at time "n"<br>
+ *  x[n] is the input at time "n"<br>
+ *  y[n-1] is the output from the LAST time step ("n-1")<br>
+ *  x[n-1] is the input from the LAST time step ("n-1")<br>
+ *  b0 … bP are the "feedforward" (FIR) gains<br>
+ *  a1 … aQ are the "feedback" (IIR) gains<br>
+ * IMPORTANT! Note the "-" sign in front of the feedback term! This is a common
+ *            convention in signal processing.
+ *
+ * What can linear filters do? Basically, they can filter, or diminish, the
+ * effects of undesirable input frequencies. High frequencies, or rapid changes,
+ * can be indicative of sensor noise or be otherwise undesirable. A "low pass"
+ * filter smooths out the signal, reducing the impact of these high frequency
+ * components. Likewise, a "high pass" filter gets rid of slow-moving signal
+ * components, letting you detect large changes more easily.
+ *
+ * Example FRC applications of filters:
+ *  - Getting rid of noise from an analog sensor input (note: the roboRIO's FPGA
+ *    can do this faster in hardware)
+ *  - Smoothing out joystick input to prevent the wheels from slipping or the
+ *    robot from tipping
+ *  - Smoothing motor commands so that unnecessary strain isn't put on
+ *    electrical or mechanical components
+ *  - If you use clever gains, you can make a PID controller out of this class!
+ *
+ * For more on filters, we highly recommend the following articles:<br>
+ * https://en.wikipedia.org/wiki/Linear_filter<br>
+ * https://en.wikipedia.org/wiki/Iir_filter<br>
+ * https://en.wikipedia.org/wiki/Fir_filter<br>
+ *
+ * Note 1: Calculate() should be called by the user on a known, regular period.
+ * You can use a Notifier for this or do it "inline" with code in a
+ * periodic function.
+ *
+ * Note 2: For ALL filters, gains are necessarily a function of frequency. If
+ * you make a filter that works well for you at, say, 100Hz, you will most
+ * definitely need to adjust the gains if you then want to run it at 200Hz!
+ * Combining this with Note 1 - the onus is on YOU as a developer to make
+ * sure Calculate() gets called at the desired, constant frequency!
+ */
+template <class T>
+class LinearFilter {
+ public:
+  /**
+   * Create a linear FIR or IIR filter.
+   *
+   * @param ffGains The "feedforward" or FIR gains.
+   * @param fbGains The "feedback" or IIR gains.
+   */
+  LinearFilter(wpi::span<const double> ffGains, wpi::span<const double> fbGains)
+      : m_inputs(ffGains.size()),
+        m_outputs(fbGains.size()),
+        m_inputGains(ffGains.begin(), ffGains.end()),
+        m_outputGains(fbGains.begin(), fbGains.end()) {
+    // Zero-fill both histories so every tap read by Calculate() is defined.
+    for (size_t i = 0; i < ffGains.size(); ++i) {
+      m_inputs.emplace_front(0.0);
+    }
+    for (size_t i = 0; i < fbGains.size(); ++i) {
+      m_outputs.emplace_front(0.0);
+    }
+
+    wpi::math::MathSharedStore::ReportUsage(
+        wpi::math::MathUsageId::kFilter_Linear, 1);
+  }
+
+  /**
+   * Create a linear FIR or IIR filter.
+   *
+   * @param ffGains The "feedforward" or FIR gains.
+   * @param fbGains The "feedback" or IIR gains.
+   */
+  LinearFilter(std::initializer_list<double> ffGains,
+               std::initializer_list<double> fbGains)
+      : LinearFilter({ffGains.begin(), ffGains.end()},
+                     {fbGains.begin(), fbGains.end()}) {}
+
+  // Static methods to create commonly used filters
+  /**
+   * Creates a one-pole IIR low-pass filter of the form:<br>
+   *   y[n] = (1 - gain) x[n] + gain y[n-1]<br>
+   * where gain = e<sup>-dt / T</sup>, T is the time constant in seconds
+   *
+   * Note: T = 1 / (2 pi f) where f is the cutoff frequency in Hz, the frequency
+   * above which the input starts to attenuate.
+   *
+   * This filter is stable for time constants greater than zero.
+   *
+   * @param timeConstant The discrete-time time constant in seconds.
+   * @param period       The period in seconds between samples taken by the
+   *                     user.
+   * @return A one-pole IIR low-pass filter.
+   */
+  static LinearFilter<T> SinglePoleIIR(double timeConstant,
+                                       units::second_t period) {
+    const double gain = std::exp(-period.value() / timeConstant);
+    return LinearFilter({1.0 - gain}, {-gain});
+  }
+
+  /**
+   * Creates a first-order high-pass filter of the form:<br>
+   *   y[n] = gain x[n] + (-gain) x[n-1] + gain y[n-1]<br>
+   * where gain = e<sup>-dt / T</sup>, T is the time constant in seconds
+   *
+   * Note: T = 1 / (2 pi f) where f is the cutoff frequency in Hz, the frequency
+   * below which the input starts to attenuate.
+   *
+   * This filter is stable for time constants greater than zero.
+   *
+   * @param timeConstant The discrete-time time constant in seconds.
+   * @param period       The period in seconds between samples taken by the
+   *                     user.
+   * @return A first-order high-pass filter.
+   */
+  static LinearFilter<T> HighPass(double timeConstant, units::second_t period) {
+    const double gain = std::exp(-period.value() / timeConstant);
+    return LinearFilter({gain, -gain}, {-gain});
+  }
+
+  /**
+   * Creates a K-tap FIR moving average filter of the form:<br>
+   *   y[n] = 1/k (x[n] + x[n-1] + … + x[n-k+1]) where k = taps
+   *
+   * This filter is always stable.
+   *
+   * @param taps The number of samples to average over. Higher = smoother but
+   *             slower
+   * @return A K-tap FIR moving average filter.
+   * @throws std::runtime_error if taps is not greater than zero.
+   */
+  static LinearFilter<T> MovingAverage(int taps) {
+    if (taps <= 0) {
+      throw std::runtime_error("Number of taps must be greater than zero.");
+    }
+
+    std::vector<double> gains(taps, 1.0 / taps);
+    return LinearFilter(gains, {});
+  }
+
+  /**
+   * Creates a backward finite difference filter that computes the nth
+   * derivative of the input given the specified number of samples.
+   *
+   * For example, a first derivative filter that uses two samples and a sample
+   * period of 20 ms would be
+   *
+   * <pre><code>
+   * LinearFilter<double>::BackwardFiniteDifference<1, 2>(20_ms);
+   * </code></pre>
+   *
+   * @tparam Derivative The order of the derivative to compute.
+   * @tparam Samples    The number of samples to use to compute the given
+   *                    derivative. This must be one more than the order of
+   *                    derivative or higher.
+   * @param period      The period in seconds between samples taken by the user.
+   * @return A backward finite difference filter.
+   */
+  template <int Derivative, int Samples>
+  static auto BackwardFiniteDifference(units::second_t period) {
+    // See
+    // https://en.wikipedia.org/wiki/Finite_difference_coefficient#Arbitrary_stencil_points
+    //
+    // For a given list of stencil points s of length n and the order of
+    // derivative d < n, the finite difference coefficients can be obtained by
+    // solving the following linear system for the vector a.
+    //
+    // @verbatim
+    // [s₁⁰   ⋯  sₙ⁰  ][a₁]      [ δ₀,d  ]
+    // [ ⋮    ⋱  ⋮    ][⋮ ] = d! [  ⋮    ]
+    // [s₁ⁿ⁻¹ ⋯ sₙⁿ⁻¹][aₙ]      [δₙ₋₁,d]
+    // @endverbatim
+    //
+    // where δᵢ,ⱼ are the Kronecker delta. For backward finite difference, the
+    // stencil points are the range [-n + 1, 0]. The FIR gains are the elements
+    // of the vector a in reverse order divided by hᵈ.
+    //
+    // The order of accuracy of the approximation is of the form O(hⁿ⁻ᵈ).
+
+    static_assert(Derivative >= 1,
+                  "Order of derivative must be greater than or equal to one.");
+    static_assert(Samples > 0, "Number of samples must be greater than zero.");
+    static_assert(Derivative < Samples,
+                  "Order of derivative must be less than number of samples.");
+
+    Eigen::Matrix<double, Samples, Samples> S;
+    for (int row = 0; row < Samples; ++row) {
+      for (int col = 0; col < Samples; ++col) {
+        double s = 1 - Samples + col;
+        S(row, col) = std::pow(s, row);
+      }
+    }
+
+    // Fill in Kronecker deltas: https://en.wikipedia.org/wiki/Kronecker_delta
+    Eigen::Vector<double, Samples> d;
+    for (int i = 0; i < Samples; ++i) {
+      d(i) = (i == Derivative) ? Factorial(Derivative) : 0.0;
+    }
+
+    Eigen::Vector<double, Samples> a =
+        S.householderQr().solve(d) / std::pow(period.value(), Derivative);
+
+    // Reverse gains list
+    std::vector<double> gains;
+    for (int i = Samples - 1; i >= 0; --i) {
+      gains.push_back(a(i));
+    }
+
+    return LinearFilter(gains, {});
+  }
+
+  /**
+   * Reset the filter state by zeroing the input and output histories.
+   */
+  void Reset() {
+    std::fill(m_inputs.begin(), m_inputs.end(), T{0.0});
+    std::fill(m_outputs.begin(), m_outputs.end(), T{0.0});
+  }
+
+  /**
+   * Calculates the next value of the filter.
+   *
+   * @param input Current input value.
+   *
+   * @return The filtered value at this step
+   */
+  T Calculate(T input) {
+    T retVal{0.0};
+
+    // Rotate the inputs
+    if (!m_inputGains.empty()) {
+      m_inputs.push_front(input);
+    }
+
+    // Calculate the new value
+    for (size_t i = 0; i < m_inputGains.size(); ++i) {
+      retVal += m_inputs[i] * m_inputGains[i];
+    }
+    for (size_t i = 0; i < m_outputGains.size(); ++i) {
+      retVal -= m_outputs[i] * m_outputGains[i];
+    }
+
+    // Rotate the outputs
+    if (!m_outputGains.empty()) {
+      m_outputs.push_front(retVal);
+    }
+
+    return retVal;
+  }
+
+ private:
+  wpi::circular_buffer<T> m_inputs;
+  wpi::circular_buffer<T> m_outputs;
+  std::vector<double> m_inputGains;
+  std::vector<double> m_outputGains;
+
+  /**
+   * Factorial of n.
+   *
+   * @param n Argument of which to take factorial.
+   * @return n! (1 for any n less than 2).
+   */
+  static constexpr int Factorial(int n) {
+    if (n < 2) {
+      return 1;
+    } else {
+      return n * Factorial(n - 1);
+    }
+  }
+};
+
+}  // namespace frc
diff --git a/wpimath/src/main/native/include/frc/MedianFilter.h b/wpimath/src/main/native/include/frc/filter/MedianFilter.h
similarity index 78%
rename from wpimath/src/main/native/include/frc/MedianFilter.h
rename to wpimath/src/main/native/include/frc/filter/MedianFilter.h
index 3ccccbf..40422a6 100644
--- a/wpimath/src/main/native/include/frc/MedianFilter.h
+++ b/wpimath/src/main/native/include/frc/filter/MedianFilter.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
@@ -48,13 +45,13 @@
m_orderedValues.erase(std::find(m_orderedValues.begin(),
m_orderedValues.end(),
m_valueBuffer.pop_back()));
- curSize = curSize - 1;
+ --curSize;
}
// Add next value to circular buffer
m_valueBuffer.push_front(next);
- if (curSize % 2 == 1) {
+ if (curSize % 2 != 0) {
// If size is odd, return middle element of sorted list
return m_orderedValues[curSize / 2];
} else {
diff --git a/wpimath/src/main/native/include/frc/filter/SlewRateLimiter.h b/wpimath/src/main/native/include/frc/filter/SlewRateLimiter.h
new file mode 100644
index 0000000..f99c1af
--- /dev/null
+++ b/wpimath/src/main/native/include/frc/filter/SlewRateLimiter.h
@@ -0,0 +1,74 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+#pragma once
+
+#include <algorithm>
+
+#include <wpi/timestamp.h>
+
+#include "units/time.h"
+
+namespace frc {
+/**
+ * Limits how fast an input value is allowed to change. Useful for building
+ * voltage, setpoint, and/or output ramps. A slew-rate limit suits quantities
+ * such as velocities or voltages best; when the controlled quantity is a
+ * position, a TrapezoidProfile is usually a better fit.
+ *
+ * @see TrapezoidProfile
+ */
+template <class Unit>
+class SlewRateLimiter {
+ public:
+  using Unit_t = units::unit_t<Unit>;
+  using Rate = units::compound_unit<Unit, units::inverse<units::seconds>>;
+  using Rate_t = units::unit_t<Rate>;
+
+  /**
+   * Builds a SlewRateLimiter from a rate limit and a starting value.
+   *
+   * @param rateLimit The maximum allowed rate of change.
+   * @param initialValue The value the limiter starts from.
+   */
+  explicit SlewRateLimiter(Rate_t rateLimit, Unit_t initialValue = Unit_t{0})
+      : m_rateLimit{rateLimit},
+        m_prevVal{initialValue},
+        m_prevTime{units::microsecond_t(wpi::Now())} {}
+
+  /**
+   * Steps the stored value toward the input by at most the rate limit times
+   * the time elapsed since the previous Calculate(), Reset(), or construction.
+   *
+   * @param input The input value whose slew rate is to be limited.
+   * @return The filtered value, which will not change faster than the slew
+   *         rate.
+   */
+  Unit_t Calculate(Unit_t input) {
+    const units::second_t now = units::microsecond_t(wpi::Now());
+    const units::second_t dt = now - m_prevTime;
+    // Largest change permitted during this step, in either direction.
+    const auto maxStep = m_rateLimit * dt;
+    m_prevVal += std::clamp(input - m_prevVal, -maxStep, maxStep);
+    m_prevTime = now;
+    return m_prevVal;
+  }
+
+  /**
+   * Sets the limiter to the given value immediately, ignoring the rate limit,
+   * and restarts the elapsed-time measurement from now.
+   *
+   * @param value The value to reset to.
+   */
+  void Reset(Unit_t value) {
+    m_prevTime = units::microsecond_t(wpi::Now());
+    m_prevVal = value;
+  }
+
+ private:
+  Rate_t m_rateLimit;
+  Unit_t m_prevVal;
+  units::second_t m_prevTime;
+};
+}  // namespace frc
diff --git a/wpimath/src/main/native/include/frc/fmt/Eigen.h b/wpimath/src/main/native/include/frc/fmt/Eigen.h
new file mode 100644
index 0000000..f6cc7b6
--- /dev/null
+++ b/wpimath/src/main/native/include/frc/fmt/Eigen.h
@@ -0,0 +1,74 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+#pragma once
+
+#include <fmt/format.h>
+
+#include "Eigen/Core"
+
+/**
+ * Formatter for Eigen::Matrix<>.
+ *
+ * @tparam Rows Number of rows.
+ * @tparam Cols Number of columns.
+ * @tparam Args Defaulted template arguments to Eigen::Matrix<>.
+ */
+template <int Rows, int Cols, int... Args>
+struct fmt::formatter<Eigen::Matrix<double, Rows, Cols, Args...>> {
+  /**
+   * Presentation specifier: 'f' (fixed-point, the default) or 'e' (exponent).
+   */
+  char presentation = 'f';
+
+  /**
+   * Format string parser. Accepts an optional 'f' or 'e' specifier.
+   *
+   * @param ctx Format string context.
+   * @return Iterator to the end of the parsed format specification.
+   */
+  constexpr auto parse(fmt::format_parse_context& ctx) {
+    auto it = ctx.begin(), end = ctx.end();
+    if (it != end && (*it == 'f' || *it == 'e')) {
+      presentation = *it++;
+    }
+
+    if (it != end && *it != '}') {
+      throw fmt::format_error("invalid format");
+    }
+
+    return it;
+  }
+
+  /**
+   * Writes out a formatted matrix, one row per line, printing each element
+   * with the presentation specifier chosen in parse().
+   *
+   * @tparam FormatContext Format string context type.
+   * @param mat Matrix to format.
+   * @param ctx Format string context.
+   * @return Output iterator past the written text.
+   */
+  template <typename FormatContext>
+  auto format(const Eigen::Matrix<double, Rows, Cols, Args...>& mat,
+              FormatContext& ctx) {
+    auto out = ctx.out();
+    // Use runtime dimensions so dynamically sized matrices are printed too.
+    for (Eigen::Index i = 0; i < mat.rows(); ++i) {
+      for (Eigen::Index j = 0; j < mat.cols(); ++j) {
+        if (presentation == 'e') {
+          out = fmt::format_to(out, " {:e}", mat(i, j));
+        } else {
+          out = fmt::format_to(out, " {:f}", mat(i, j));
+        }
+      }
+
+      if (i < mat.rows() - 1) {
+        out = fmt::format_to(out, "\n");
+      }
+    }
+
+    return out;
+  }
+};
diff --git a/wpimath/src/main/native/include/frc/fmt/Units.h b/wpimath/src/main/native/include/frc/fmt/Units.h
new file mode 100644
index 0000000..1ec61ca
--- /dev/null
+++ b/wpimath/src/main/native/include/frc/fmt/Units.h
@@ -0,0 +1,90 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+#pragma once
+
+#include <fmt/format.h>
+
+#include "units/base.h"
+
+/**
+ * Formatter for unit types.
+ *
+ * The value is converted to its base units and written with the
+ * fmt::formatter<double> specifier, followed by each base dimension that has a
+ * nonzero exponent (e.g. "9.81 m s^-2"). Fractional exponents are
+ * parenthesized (e.g. "m^(1/2)").
+ *
+ * @tparam Units Unit tag for which type of units the `unit_t` represents (e.g.
+ *               meters).
+ * @tparam T Underlying type of the storage. Defaults to double.
+ * @tparam NonLinearScale Optional scale class for the units. Defaults to linear
+ *                        (i.e. does not scale the unit value). Examples of
+ *                        non-linear scales could be logarithmic, decibel, or
+ *                        Richter scales. Non-linear scales must adhere to the
+ *                        non-linear-scale concept.
+ */
+template <class Units, typename T, template <typename> class NonLinearScale>
+struct fmt::formatter<units::unit_t<Units, T, NonLinearScale>>
+    : fmt::formatter<double> {
+  /**
+   * Writes out a formatted unit.
+   *
+   * @tparam FormatContext Format string context type.
+   * @param obj Unit instance.
+   * @param ctx Format string context.
+   * @return Output iterator past the last character written.
+   */
+  template <typename FormatContext>
+  auto format(const units::unit_t<Units, T, NonLinearScale>& obj,
+              FormatContext& ctx) {
+    using BaseUnitType =
+        typename units::traits::unit_traits<Units>::base_unit_type;
+    using BaseUnits = units::unit<std::ratio<1>, BaseUnitType>;
+
+    // Numeric part; the inherited formatter<double> applies the format spec.
+    auto out = fmt::formatter<double>::format(
+        units::convert<Units, BaseUnits>(obj()), ctx);
+
+    out = FormatDimension<typename BaseUnitType::meter_ratio>(out, "m");
+    out = FormatDimension<typename BaseUnitType::kilogram_ratio>(out, "kg");
+    out = FormatDimension<typename BaseUnitType::second_ratio>(out, "s");
+    out = FormatDimension<typename BaseUnitType::ampere_ratio>(out, "A");
+    out = FormatDimension<typename BaseUnitType::kelvin_ratio>(out, "K");
+    out = FormatDimension<typename BaseUnitType::mole_ratio>(out, "mol");
+    out = FormatDimension<typename BaseUnitType::candela_ratio>(out, "cd");
+    out = FormatDimension<typename BaseUnitType::radian_ratio>(out, "rad");
+    out = FormatDimension<typename BaseUnitType::byte_ratio>(out, "b");
+
+    return out;
+  }
+
+ private:
+  /**
+   * Writes " <symbol>" and its exponent if the exponent is nonzero.
+   *
+   * An exponent of 1 is omitted, an integer exponent is written as "^n", and a
+   * fractional exponent is written as "^(n/d)".
+   *
+   * @tparam Exponent std::ratio exponent of the base dimension.
+   * @tparam OutputIt Output iterator type.
+   * @param out Output iterator to write to.
+   * @param symbol Symbol of the base unit (e.g. "m").
+   * @return Output iterator past the last character written.
+   */
+  template <typename Exponent, typename OutputIt>
+  static OutputIt FormatDimension(OutputIt out, const char* symbol) {
+    // std::ratio is always reduced, so num == 0 implies den == 1.
+    if constexpr (Exponent::num != 0) {
+      out = fmt::format_to(out, " {}", symbol);
+      if constexpr (Exponent::den != 1) {
+        // Parenthesized so "m^(1/2)" is not misread as "meters divided by 2".
+        out = fmt::format_to(out, "^({}/{})", Exponent::num, Exponent::den);
+      } else if constexpr (Exponent::num != 1) {
+        out = fmt::format_to(out, "^{}", Exponent::num);
+      }
+    }
+    return out;
+  }
+};
diff --git a/wpimath/src/main/native/include/frc/geometry/Pose2d.h b/wpimath/src/main/native/include/frc/geometry/Pose2d.h
index 43f4756..ebaa7c1 100644
--- a/wpimath/src/main/native/include/frc/geometry/Pose2d.h
+++ b/wpimath/src/main/native/include/frc/geometry/Pose2d.h
@@ -1,12 +1,11 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
+#include <wpi/SymbolExports.h>
+
#include "Transform2d.h"
#include "Translation2d.h"
#include "Twist2d.h"
@@ -20,7 +19,7 @@
/**
* Represents a 2d pose containing translational and rotational elements.
*/
-class Pose2d {
+class WPILIB_DLLEXPORT Pose2d {
public:
/**
* Constructs a pose at the origin facing toward the positive X axis.
@@ -61,18 +60,6 @@
Pose2d operator+(const Transform2d& other) const;
/**
- * Transforms the current pose by the transformation.
- *
- * This is similar to the + operator, except that it mutates the current
- * object.
- *
- * @param other The transform to transform the pose by.
- *
- * @return Reference to the new mutated object.
- */
- Pose2d& operator+=(const Transform2d& other);
-
- /**
* Returns the Transform2d that maps the one pose to another.
*
* @param other The initial pose of the transformation.
@@ -186,8 +173,10 @@
Rotation2d m_rotation;
};
+WPILIB_DLLEXPORT
void to_json(wpi::json& json, const Pose2d& pose);
+WPILIB_DLLEXPORT
void from_json(const wpi::json& json, Pose2d& pose);
} // namespace frc
diff --git a/wpimath/src/main/native/include/frc/geometry/Rotation2d.h b/wpimath/src/main/native/include/frc/geometry/Rotation2d.h
index 914eba4..94a17fc 100644
--- a/wpimath/src/main/native/include/frc/geometry/Rotation2d.h
+++ b/wpimath/src/main/native/include/frc/geometry/Rotation2d.h
@@ -1,12 +1,11 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
+#include <wpi/SymbolExports.h>
+
#include "units/angle.h"
namespace wpi {
@@ -19,7 +18,7 @@
* A rotation in a 2d coordinate frame represented a point on the unit circle
* (cosine and sine).
*/
-class Rotation2d {
+class WPILIB_DLLEXPORT Rotation2d {
public:
/**
* Constructs a Rotation2d with a default angle of 0 degrees.
@@ -31,14 +30,14 @@
*
* @param value The value of the angle in radians.
*/
- Rotation2d(units::radian_t value); // NOLINT(runtime/explicit)
+ Rotation2d(units::radian_t value); // NOLINT
/**
* Constructs a Rotation2d with the given degree value.
*
* @param value The value of the angle in degrees.
*/
- Rotation2d(units::degree_t value); // NOLINT(runtime/explicit)
+ Rotation2d(units::degree_t value); // NOLINT
/**
* Constructs a Rotation2d with the given x and y (cosine and sine)
@@ -63,18 +62,6 @@
Rotation2d operator+(const Rotation2d& other) const;
/**
- * Adds a rotation to the current rotation.
- *
- * This is similar to the + operator except that it mutates the current
- * object.
- *
- * @param other The rotation to add.
- *
- * @return The reference to the new mutated object.
- */
- Rotation2d& operator+=(const Rotation2d& other);
-
- /**
* Subtracts the new rotation from the current rotation and returns the new
* rotation.
*
@@ -88,18 +75,6 @@
Rotation2d operator-(const Rotation2d& other) const;
/**
- * Subtracts the new rotation from the current rotation.
- *
- * This is similar to the - operator except that it mutates the current
- * object.
- *
- * @param other The rotation to subtract.
- *
- * @return The reference to the new mutated object.
- */
- Rotation2d& operator-=(const Rotation2d& other);
-
- /**
* Takes the inverse of the current rotation. This is simply the negative of
* the current angular value.
*
@@ -134,10 +109,11 @@
/**
* Adds the new rotation to the current rotation using a rotation matrix.
*
+ * <pre>
* [cos_new] [other.cos, -other.sin][cos]
* [sin_new] = [other.sin, other.cos][sin]
- *
- * value_new = std::atan2(cos_new, sin_new)
+ * value_new = std::atan2(sin_new, cos_new)
+ * </pre>
*
* @param other The rotation to rotate by.
*
@@ -186,8 +162,10 @@
double m_sin = 0;
};
+WPILIB_DLLEXPORT
void to_json(wpi::json& json, const Rotation2d& rotation);
+WPILIB_DLLEXPORT
void from_json(const wpi::json& json, Rotation2d& rotation);
} // namespace frc
diff --git a/wpimath/src/main/native/include/frc/geometry/Transform2d.h b/wpimath/src/main/native/include/frc/geometry/Transform2d.h
index 8f05413..3d5e1d6 100644
--- a/wpimath/src/main/native/include/frc/geometry/Transform2d.h
+++ b/wpimath/src/main/native/include/frc/geometry/Transform2d.h
@@ -1,22 +1,21 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
+#include <wpi/SymbolExports.h>
+
#include "Translation2d.h"
namespace frc {
-class Pose2d;
+class WPILIB_DLLEXPORT Pose2d;
/**
* Represents a transformation for a Pose2d.
*/
-class Transform2d {
+class WPILIB_DLLEXPORT Transform2d {
public:
/**
* Constructs the transform that maps the initial pose to the final pose.
@@ -85,6 +84,14 @@
}
/**
+ * Composes two transformations.
+ *
+ * @param other The transform to compose with this one.
+ * @return The composition of the two transformations.
+ */
+ Transform2d operator+(const Transform2d& other) const;
+
+ /**
* Checks equality between this Transform2d and another object.
*
* @param other The other object.
diff --git a/wpimath/src/main/native/include/frc/geometry/Translation2d.h b/wpimath/src/main/native/include/frc/geometry/Translation2d.h
index 0c3ee3d..204da30 100644
--- a/wpimath/src/main/native/include/frc/geometry/Translation2d.h
+++ b/wpimath/src/main/native/include/frc/geometry/Translation2d.h
@@ -1,12 +1,11 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
+#include <wpi/SymbolExports.h>
+
#include "Rotation2d.h"
#include "units/length.h"
@@ -24,7 +23,7 @@
* When the robot is placed on the origin, facing toward the X direction,
* moving forward increases the X, whereas moving to the left increases the Y.
*/
-class Translation2d {
+class WPILIB_DLLEXPORT Translation2d {
public:
/**
* Constructs a Translation2d with X and Y components equal to zero.
@@ -114,18 +113,6 @@
Translation2d operator+(const Translation2d& other) const;
/**
- * Adds the new translation to the current translation.
- *
- * This is similar to the + operator, except that the current object is
- * mutated.
- *
- * @param other The translation to add.
- *
- * @return The reference to the new mutated object.
- */
- Translation2d& operator+=(const Translation2d& other);
-
- /**
* Subtracts the other translation from the other translation and returns the
* difference.
*
@@ -139,18 +126,6 @@
Translation2d operator-(const Translation2d& other) const;
/**
- * Subtracts the new translation from the current translation.
- *
- * This is similar to the - operator, except that the current object is
- * mutated.
- *
- * @param other The translation to subtract.
- *
- * @return The reference to the new mutated object.
- */
- Translation2d& operator-=(const Translation2d& other);
-
- /**
* Returns the inverse of the current translation. This is equivalent to
* rotating by 180 degrees, flipping the point over both axes, or simply
* negating both components of the translation.
@@ -171,17 +146,6 @@
Translation2d operator*(double scalar) const;
/**
- * Multiplies the current translation by a scalar.
- *
- * This is similar to the * operator, except that current object is mutated.
- *
- * @param scalar The scalar to multiply by.
- *
- * @return The reference to the new mutated object.
- */
- Translation2d& operator*=(double scalar);
-
- /**
* Divides the translation by a scalar and returns the new translation.
*
* For example, Translation2d{2.0, 2.5} / 2 = Translation2d{1.0, 1.25}
@@ -208,24 +172,15 @@
*/
bool operator!=(const Translation2d& other) const;
- /*
- * Divides the current translation by a scalar.
- *
- * This is similar to the / operator, except that current object is mutated.
- *
- * @param scalar The scalar to divide by.
- *
- * @return The reference to the new mutated object.
- */
- Translation2d& operator/=(double scalar);
-
private:
units::meter_t m_x = 0_m;
units::meter_t m_y = 0_m;
};
+WPILIB_DLLEXPORT
void to_json(wpi::json& json, const Translation2d& state);
+WPILIB_DLLEXPORT
void from_json(const wpi::json& json, Translation2d& state);
} // namespace frc
diff --git a/wpimath/src/main/native/include/frc/geometry/Twist2d.h b/wpimath/src/main/native/include/frc/geometry/Twist2d.h
index b71ee56..9d7a856 100644
--- a/wpimath/src/main/native/include/frc/geometry/Twist2d.h
+++ b/wpimath/src/main/native/include/frc/geometry/Twist2d.h
@@ -1,12 +1,11 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
+#include <wpi/SymbolExports.h>
+
#include "units/angle.h"
#include "units/length.h"
#include "units/math.h"
@@ -19,7 +18,7 @@
*
* A Twist can be used to represent a difference between two poses.
*/
-struct Twist2d {
+struct WPILIB_DLLEXPORT Twist2d {
/**
* Linear "dx" component
*/
diff --git a/wpimath/src/main/native/include/frc/kinematics/ChassisSpeeds.h b/wpimath/src/main/native/include/frc/kinematics/ChassisSpeeds.h
index 1716ca7..7414dec 100644
--- a/wpimath/src/main/native/include/frc/kinematics/ChassisSpeeds.h
+++ b/wpimath/src/main/native/include/frc/kinematics/ChassisSpeeds.h
@@ -1,12 +1,11 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
+#include <wpi/SymbolExports.h>
+
#include "frc/geometry/Rotation2d.h"
#include "units/angular_velocity.h"
#include "units/velocity.h"
@@ -23,7 +22,7 @@
* never have a dy component because it can never move sideways. Holonomic
* drivetrains such as swerve and mecanum will often have all three components.
*/
-struct ChassisSpeeds {
+struct WPILIB_DLLEXPORT ChassisSpeeds {
/**
* Represents forward velocity w.r.t the robot frame of reference. (Fwd is +)
*/
diff --git a/wpimath/src/main/native/include/frc/kinematics/DifferentialDriveKinematics.h b/wpimath/src/main/native/include/frc/kinematics/DifferentialDriveKinematics.h
index 9e48b5e..4bf27ff 100644
--- a/wpimath/src/main/native/include/frc/kinematics/DifferentialDriveKinematics.h
+++ b/wpimath/src/main/native/include/frc/kinematics/DifferentialDriveKinematics.h
@@ -1,12 +1,11 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
+#include <wpi/SymbolExports.h>
+
#include "frc/kinematics/ChassisSpeeds.h"
#include "frc/kinematics/DifferentialDriveWheelSpeeds.h"
#include "units/angle.h"
@@ -22,7 +21,7 @@
* velocity components whereas forward kinematics converts left and right
* component velocities into a linear and angular chassis speed.
*/
-class DifferentialDriveKinematics {
+class WPILIB_DLLEXPORT DifferentialDriveKinematics {
public:
/**
* Constructs a differential drive kinematics object.
diff --git a/wpimath/src/main/native/include/frc/kinematics/DifferentialDriveOdometry.h b/wpimath/src/main/native/include/frc/kinematics/DifferentialDriveOdometry.h
index a65b52a..70179de 100644
--- a/wpimath/src/main/native/include/frc/kinematics/DifferentialDriveOdometry.h
+++ b/wpimath/src/main/native/include/frc/kinematics/DifferentialDriveOdometry.h
@@ -1,12 +1,11 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
+#include <wpi/SymbolExports.h>
+
#include "DifferentialDriveKinematics.h"
#include "frc/geometry/Pose2d.h"
#include "units/length.h"
@@ -24,7 +23,7 @@
* It is important that you reset your encoders to zero before using this class.
* Any subsequent pose resets also require the encoders to be reset to zero.
*/
-class DifferentialDriveOdometry {
+class WPILIB_DLLEXPORT DifferentialDriveOdometry {
public:
/**
* Constructs a DifferentialDriveOdometry object.
diff --git a/wpimath/src/main/native/include/frc/kinematics/DifferentialDriveWheelSpeeds.h b/wpimath/src/main/native/include/frc/kinematics/DifferentialDriveWheelSpeeds.h
index 20085bb..2bf9fb9 100644
--- a/wpimath/src/main/native/include/frc/kinematics/DifferentialDriveWheelSpeeds.h
+++ b/wpimath/src/main/native/include/frc/kinematics/DifferentialDriveWheelSpeeds.h
@@ -1,19 +1,18 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
+#include <wpi/SymbolExports.h>
+
#include "units/velocity.h"
namespace frc {
/**
* Represents the wheel speeds for a differential drive drivetrain.
*/
-struct DifferentialDriveWheelSpeeds {
+struct WPILIB_DLLEXPORT DifferentialDriveWheelSpeeds {
/**
* Speed of the left side of the robot.
*/
diff --git a/wpimath/src/main/native/include/frc/kinematics/MecanumDriveKinematics.h b/wpimath/src/main/native/include/frc/kinematics/MecanumDriveKinematics.h
index a0ccb52..9a7cef9 100644
--- a/wpimath/src/main/native/include/frc/kinematics/MecanumDriveKinematics.h
+++ b/wpimath/src/main/native/include/frc/kinematics/MecanumDriveKinematics.h
@@ -1,12 +1,11 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
+#include <wpi/SymbolExports.h>
+
#include "Eigen/Core"
#include "Eigen/QR"
#include "frc/geometry/Translation2d.h"
@@ -38,7 +37,7 @@
* Forward kinematics is also used for odometry -- determining the position of
* the robot on the field using encoders and a gyro.
*/
-class MecanumDriveKinematics {
+class WPILIB_DLLEXPORT MecanumDriveKinematics {
public:
/**
* Constructs a mecanum drive kinematics object.
diff --git a/wpimath/src/main/native/include/frc/kinematics/MecanumDriveOdometry.h b/wpimath/src/main/native/include/frc/kinematics/MecanumDriveOdometry.h
index 546a498..bdd1239 100644
--- a/wpimath/src/main/native/include/frc/kinematics/MecanumDriveOdometry.h
+++ b/wpimath/src/main/native/include/frc/kinematics/MecanumDriveOdometry.h
@@ -1,12 +1,10 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
+#include <wpi/SymbolExports.h>
#include <wpi/timestamp.h>
#include "frc/geometry/Pose2d.h"
@@ -25,7 +23,7 @@
* path following. Furthermore, odometry can be used for latency compensation
* when using computer-vision systems.
*/
-class MecanumDriveOdometry {
+class WPILIB_DLLEXPORT MecanumDriveOdometry {
public:
/**
* Constructs a MecanumDriveOdometry object.
diff --git a/wpimath/src/main/native/include/frc/kinematics/MecanumDriveWheelSpeeds.h b/wpimath/src/main/native/include/frc/kinematics/MecanumDriveWheelSpeeds.h
index aa82b99..c24b134 100644
--- a/wpimath/src/main/native/include/frc/kinematics/MecanumDriveWheelSpeeds.h
+++ b/wpimath/src/main/native/include/frc/kinematics/MecanumDriveWheelSpeeds.h
@@ -1,19 +1,18 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
+#include <wpi/SymbolExports.h>
+
#include "units/velocity.h"
namespace frc {
/**
* Represents the wheel speeds for a mecanum drive drivetrain.
*/
-struct MecanumDriveWheelSpeeds {
+struct WPILIB_DLLEXPORT MecanumDriveWheelSpeeds {
/**
* Speed of the front-left wheel.
*/
diff --git a/wpimath/src/main/native/include/frc/kinematics/SwerveDriveKinematics.h b/wpimath/src/main/native/include/frc/kinematics/SwerveDriveKinematics.h
index 0ed50b4..84057ac 100644
--- a/wpimath/src/main/native/include/frc/kinematics/SwerveDriveKinematics.h
+++ b/wpimath/src/main/native/include/frc/kinematics/SwerveDriveKinematics.h
@@ -1,15 +1,13 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
-#include <array>
#include <cstddef>
+#include <wpi/array.h>
+
#include "Eigen/Core"
#include "Eigen/QR"
#include "frc/geometry/Rotation2d.h"
@@ -53,8 +51,10 @@
* pass in the module states in the same order when calling the forward
* kinematics methods.
*
- * @param wheels The locations of the wheels relative to the physical center
- * of the robot.
+ * @param wheel The location of the first wheel relative to the physical
+ * center of the robot.
+ * @param wheels The locations of the other wheels relative to the physical
+ * center of the robot.
*/
template <typename... Wheels>
explicit SwerveDriveKinematics(Translation2d wheel, Wheels&&... wheels)
@@ -65,8 +65,25 @@
for (size_t i = 0; i < NumModules; i++) {
// clang-format off
m_inverseKinematics.template block<2, 3>(i * 2, 0) <<
- 1, 0, (-m_modules[i].Y()).template to<double>(),
- 0, 1, (+m_modules[i].X()).template to<double>();
+ 1, 0, (-m_modules[i].Y()).value(),
+ 0, 1, (+m_modules[i].X()).value();
+ // clang-format on
+ }
+
+ m_forwardKinematics = m_inverseKinematics.householderQr();
+
+ wpi::math::MathSharedStore::ReportUsage(
+ wpi::math::MathUsageId::kKinematics_SwerveDrive, 1);
+ }
+
+ explicit SwerveDriveKinematics(
+ const wpi::array<Translation2d, NumModules>& wheels)
+ : m_modules{wheels} {
+ for (size_t i = 0; i < NumModules; i++) {
+ // clang-format off
+ m_inverseKinematics.template block<2, 3>(i * 2, 0) <<
+ 1, 0, (-m_modules[i].Y()).value(),
+ 0, 1, (+m_modules[i].X()).value();
// clang-format on
}
@@ -97,15 +114,16 @@
* @return An array containing the module states. Use caution because these
* module states are not normalized. Sometimes, a user input may cause one of
* the module speeds to go above the attainable max velocity. Use the
- * <NormalizeWheelSpeeds> function to rectify this issue. In addition, you can
- * leverage the power of C++17 to directly assign the module states to
+ * NormalizeWheelSpeeds(wpi::array<SwerveModuleState, NumModules>*,
+ * units::meters_per_second_t) function to rectify this issue. In addition,
+ * you can leverage the power of C++17 to directly assign the module states to
* variables:
*
* @code{.cpp}
* auto [fl, fr, bl, br] = kinematics.ToSwerveModuleStates(chassisSpeeds);
* @endcode
*/
- std::array<SwerveModuleState, NumModules> ToSwerveModuleStates(
+ wpi::array<SwerveModuleState, NumModules> ToSwerveModuleStates(
const ChassisSpeeds& chassisSpeeds,
const Translation2d& centerOfRotation = Translation2d()) const;
@@ -130,7 +148,7 @@
* the robot's position on the field using data from the real-world speed and
* angle of each module on the robot.
*
- * @param moduleStates The state of the modules as an std::array of type
+ * @param moduleStates The state of the modules as a wpi::array of type
* SwerveModuleState, NumModules long as measured from respective encoders
* and gyros. The order of the swerve module states should be same as passed
* into the constructor of this class.
@@ -138,7 +156,7 @@
* @return The resulting chassis speed.
*/
ChassisSpeeds ToChassisSpeeds(
- std::array<SwerveModuleState, NumModules> moduleStates) const;
+ wpi::array<SwerveModuleState, NumModules> moduleStates) const;
/**
* Normalizes the wheel speeds using some max attainable speed. Sometimes,
@@ -153,14 +171,14 @@
* @param attainableMaxSpeed The absolute max speed that a module can reach.
*/
static void NormalizeWheelSpeeds(
- std::array<SwerveModuleState, NumModules>* moduleStates,
+ wpi::array<SwerveModuleState, NumModules>* moduleStates,
units::meters_per_second_t attainableMaxSpeed);
private:
mutable Eigen::Matrix<double, NumModules * 2, 3> m_inverseKinematics;
Eigen::HouseholderQR<Eigen::Matrix<double, NumModules * 2, 3>>
m_forwardKinematics;
- std::array<Translation2d, NumModules> m_modules;
+ wpi::array<Translation2d, NumModules> m_modules;
mutable Translation2d m_previousCoR;
};
diff --git a/wpimath/src/main/native/include/frc/kinematics/SwerveDriveKinematics.inc b/wpimath/src/main/native/include/frc/kinematics/SwerveDriveKinematics.inc
index 08eba50..1747453 100644
--- a/wpimath/src/main/native/include/frc/kinematics/SwerveDriveKinematics.inc
+++ b/wpimath/src/main/native/include/frc/kinematics/SwerveDriveKinematics.inc
@@ -1,14 +1,12 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
#include <algorithm>
+#include "frc/kinematics/SwerveDriveKinematics.h"
#include "units/math.h"
namespace frc {
@@ -18,7 +16,7 @@
-> SwerveDriveKinematics<1 + sizeof...(Wheels)>;
template <size_t NumModules>
-std::array<SwerveModuleState, NumModules>
+wpi::array<SwerveModuleState, NumModules>
SwerveDriveKinematics<NumModules>::ToSwerveModuleStates(
const ChassisSpeeds& chassisSpeeds,
const Translation2d& centerOfRotation) const {
@@ -26,30 +24,29 @@
if (centerOfRotation != m_previousCoR) {
for (size_t i = 0; i < NumModules; i++) {
// clang-format off
- m_inverseKinematics.template block<2, 3>(i * 2, 0) <<
- 1, 0, (-m_modules[i].Y() + centerOfRotation.Y()).template to<double>(),
- 0, 1, (+m_modules[i].X() - centerOfRotation.X()).template to<double>();
+ m_inverseKinematics.template block<2, 3>(i * 2, 0) =
+ Eigen::Matrix<double, 2, 3>{
+ {1, 0, (-m_modules[i].Y() + centerOfRotation.Y()).value()},
+ {0, 1, (+m_modules[i].X() - centerOfRotation.X()).value()}};
// clang-format on
}
m_previousCoR = centerOfRotation;
}
- Eigen::Vector3d chassisSpeedsVector;
- chassisSpeedsVector << chassisSpeeds.vx.to<double>(),
- chassisSpeeds.vy.to<double>(), chassisSpeeds.omega.to<double>();
+ Eigen::Vector3d chassisSpeedsVector{chassisSpeeds.vx.value(),
+ chassisSpeeds.vy.value(),
+ chassisSpeeds.omega.value()};
Eigen::Matrix<double, NumModules * 2, 1> moduleStatesMatrix =
m_inverseKinematics * chassisSpeedsVector;
- std::array<SwerveModuleState, NumModules> moduleStates;
+ wpi::array<SwerveModuleState, NumModules> moduleStates{wpi::empty_array};
for (size_t i = 0; i < NumModules; i++) {
- units::meters_per_second_t x =
- units::meters_per_second_t{moduleStatesMatrix(i * 2, 0)};
- units::meters_per_second_t y =
- units::meters_per_second_t{moduleStatesMatrix(i * 2 + 1, 0)};
+ units::meters_per_second_t x{moduleStatesMatrix(i * 2, 0)};
+ units::meters_per_second_t y{moduleStatesMatrix(i * 2 + 1, 0)};
auto speed = units::math::hypot(x, y);
- Rotation2d rotation{x.to<double>(), y.to<double>()};
+ Rotation2d rotation{x.value(), y.value()};
moduleStates[i] = {speed, rotation};
}
@@ -65,22 +62,21 @@
"Number of modules is not consistent with number of wheel "
"locations provided in constructor.");
- std::array<SwerveModuleState, NumModules> moduleStates{wheelStates...};
+ wpi::array<SwerveModuleState, NumModules> moduleStates{wheelStates...};
return this->ToChassisSpeeds(moduleStates);
}
template <size_t NumModules>
ChassisSpeeds SwerveDriveKinematics<NumModules>::ToChassisSpeeds(
- std::array<SwerveModuleState, NumModules> moduleStates) const {
+ wpi::array<SwerveModuleState, NumModules> moduleStates) const {
Eigen::Matrix<double, NumModules * 2, 1> moduleStatesMatrix;
- for (size_t i = 0; i < NumModules; i++) {
+ for (size_t i = 0; i < NumModules; ++i) {
SwerveModuleState module = moduleStates[i];
- moduleStatesMatrix.row(i * 2)
- << module.speed.to<double>() * module.angle.Cos();
- moduleStatesMatrix.row(i * 2 + 1)
- << module.speed.to<double>() * module.angle.Sin();
+ moduleStatesMatrix(i * 2, 0) = module.speed.value() * module.angle.Cos();
+ moduleStatesMatrix(i * 2 + 1, 0) =
+ module.speed.value() * module.angle.Sin();
}
Eigen::Vector3d chassisSpeedsVector =
@@ -93,7 +89,7 @@
template <size_t NumModules>
void SwerveDriveKinematics<NumModules>::NormalizeWheelSpeeds(
- std::array<SwerveModuleState, NumModules>* moduleStates,
+ wpi::array<SwerveModuleState, NumModules>* moduleStates,
units::meters_per_second_t attainableMaxSpeed) {
auto& states = *moduleStates;
auto realMaxSpeed = std::max_element(states.begin(), states.end(),
diff --git a/wpimath/src/main/native/include/frc/kinematics/SwerveDriveOdometry.h b/wpimath/src/main/native/include/frc/kinematics/SwerveDriveOdometry.h
index 03591da..d1a4958 100644
--- a/wpimath/src/main/native/include/frc/kinematics/SwerveDriveOdometry.h
+++ b/wpimath/src/main/native/include/frc/kinematics/SwerveDriveOdometry.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
diff --git a/wpimath/src/main/native/include/frc/kinematics/SwerveDriveOdometry.inc b/wpimath/src/main/native/include/frc/kinematics/SwerveDriveOdometry.inc
index e7bb093..96db930 100644
--- a/wpimath/src/main/native/include/frc/kinematics/SwerveDriveOdometry.inc
+++ b/wpimath/src/main/native/include/frc/kinematics/SwerveDriveOdometry.inc
@@ -1,12 +1,10 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
+#include "frc/kinematics/SwerveDriveOdometry.h"
#include "wpimath/MathShared.h"
namespace frc {
diff --git a/wpimath/src/main/native/include/frc/kinematics/SwerveModuleState.h b/wpimath/src/main/native/include/frc/kinematics/SwerveModuleState.h
index b5ae7f1..cae2d53 100644
--- a/wpimath/src/main/native/include/frc/kinematics/SwerveModuleState.h
+++ b/wpimath/src/main/native/include/frc/kinematics/SwerveModuleState.h
@@ -1,20 +1,21 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
+#include <wpi/SymbolExports.h>
+
#include "frc/geometry/Rotation2d.h"
+#include "units/angle.h"
+#include "units/math.h"
#include "units/velocity.h"
namespace frc {
/**
* Represents the state of one swerve module.
*/
-struct SwerveModuleState {
+struct WPILIB_DLLEXPORT SwerveModuleState {
/**
* Speed of the wheel of the module.
*/
@@ -24,5 +25,24 @@
* Angle of the module.
*/
Rotation2d angle;
+
+ /**
+ * Minimize the change in heading the desired swerve module state would
+ * require by potentially reversing the direction the wheel spins. If this is
+ * used with the PIDController class's continuous input functionality, the
+ * furthest a wheel will ever rotate is 90 degrees.
+ *
+ * @param desiredState The desired state.
+ * @param currentAngle The current module angle.
+ */
+ static SwerveModuleState Optimize(const SwerveModuleState& desiredState,
+ const Rotation2d& currentAngle) {
+ auto delta = desiredState.angle - currentAngle;
+ if (units::math::abs(delta.Degrees()) > 90_deg) {
+ return {-desiredState.speed, desiredState.angle + Rotation2d{180_deg}};
+ } else {
+ return {desiredState.speed, desiredState.angle};
+ }
+ }
};
} // namespace frc
diff --git a/wpimath/src/main/native/include/frc/spline/CubicHermiteSpline.h b/wpimath/src/main/native/include/frc/spline/CubicHermiteSpline.h
index c9cf2d0..8126349 100644
--- a/wpimath/src/main/native/include/frc/spline/CubicHermiteSpline.h
+++ b/wpimath/src/main/native/include/frc/spline/CubicHermiteSpline.h
@@ -1,13 +1,11 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
-#include <array>
+#include <wpi/SymbolExports.h>
+#include <wpi/array.h>
#include "Eigen/Core"
#include "frc/spline/Spline.h"
@@ -16,7 +14,7 @@
/**
* Represents a hermite spline of degree 3.
*/
-class CubicHermiteSpline : public Spline<3> {
+class WPILIB_DLLEXPORT CubicHermiteSpline : public Spline<3> {
public:
/**
* Constructs a cubic hermite spline with the specified control vectors. Each
@@ -32,10 +30,10 @@
* @param yFinalControlVector The control vector for the final point in
* the y dimension.
*/
- CubicHermiteSpline(std::array<double, 2> xInitialControlVector,
- std::array<double, 2> xFinalControlVector,
- std::array<double, 2> yInitialControlVector,
- std::array<double, 2> yFinalControlVector);
+ CubicHermiteSpline(wpi::array<double, 2> xInitialControlVector,
+ wpi::array<double, 2> xFinalControlVector,
+ wpi::array<double, 2> yInitialControlVector,
+ wpi::array<double, 2> yFinalControlVector);
protected:
/**
@@ -55,13 +53,31 @@
* @return The hermite basis matrix for cubic hermite spline interpolation.
*/
static Eigen::Matrix<double, 4, 4> MakeHermiteBasis() {
- // clang-format off
- static auto basis = (Eigen::Matrix<double, 4, 4>() <<
- +2.0, +1.0, -2.0, +1.0,
- -3.0, -2.0, +3.0, -1.0,
- +0.0, +1.0, +0.0, +0.0,
- +1.0, +0.0, +0.0, +0.0).finished();
- // clang-format on
+ // Given P(i), P'(i), P(i+1), P'(i+1), the control vectors, we want to find
+ // the coefficients of the spline P(t) = a3 * t^3 + a2 * t^2 + a1 * t + a0.
+ //
+ // P(i) = P(0) = a0
+ // P'(i) = P'(0) = a1
+ // P(i+1) = P(1) = a3 + a2 + a1 + a0
+ // P'(i+1) = P'(1) = 3 * a3 + 2 * a2 + a1
+ //
+ // [ P(i) ] = [ 0 0 0 1 ][ a3 ]
+ // [ P'(i) ] = [ 0 0 1 0 ][ a2 ]
+ // [ P(i+1) ] = [ 1 1 1 1 ][ a1 ]
+ // [ P'(i+1) ] = [ 3 2 1 0 ][ a0 ]
+ //
+ // To solve for the coefficients, we can invert the 4x4 matrix and move it
+ // to the other side of the equation.
+ //
+ // [ a3 ] = [ 2 1 -2 1 ][ P(i) ]
+ // [ a2 ] = [ -3 -2 3 -1 ][ P'(i) ]
+ // [ a1 ] = [ 0 1 0 0 ][ P(i+1) ]
+ // [ a0 ] = [ 1 0 0 0 ][ P'(i+1) ]
+
+ static const Eigen::Matrix<double, 4, 4> basis{{+2.0, +1.0, -2.0, +1.0},
+ {-3.0, -2.0, +3.0, -1.0},
+ {+0.0, +1.0, +0.0, +0.0},
+ {+1.0, +0.0, +0.0, +0.0}};
return basis;
}
@@ -75,10 +91,9 @@
* @return The control vector matrix for a dimension.
*/
static Eigen::Vector4d ControlVectorFromArrays(
- std::array<double, 2> initialVector, std::array<double, 2> finalVector) {
- return (Eigen::Vector4d() << initialVector[0], initialVector[1],
- finalVector[0], finalVector[1])
- .finished();
+ wpi::array<double, 2> initialVector, wpi::array<double, 2> finalVector) {
+ return Eigen::Vector4d{initialVector[0], initialVector[1], finalVector[0],
+ finalVector[1]};
}
};
} // namespace frc
diff --git a/wpimath/src/main/native/include/frc/spline/QuinticHermiteSpline.h b/wpimath/src/main/native/include/frc/spline/QuinticHermiteSpline.h
index 201c402..5ba3e2a 100644
--- a/wpimath/src/main/native/include/frc/spline/QuinticHermiteSpline.h
+++ b/wpimath/src/main/native/include/frc/spline/QuinticHermiteSpline.h
@@ -1,13 +1,11 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
-#include <array>
+#include <wpi/SymbolExports.h>
+#include <wpi/array.h>
#include "Eigen/Core"
#include "frc/spline/Spline.h"
@@ -16,7 +14,7 @@
/**
* Represents a hermite spline of degree 5.
*/
-class QuinticHermiteSpline : public Spline<5> {
+class WPILIB_DLLEXPORT QuinticHermiteSpline : public Spline<5> {
public:
/**
* Constructs a quintic hermite spline with the specified control vectors.
@@ -32,10 +30,10 @@
* @param yFinalControlVector The control vector for the final point in
* the y dimension.
*/
- QuinticHermiteSpline(std::array<double, 3> xInitialControlVector,
- std::array<double, 3> xFinalControlVector,
- std::array<double, 3> yInitialControlVector,
- std::array<double, 3> yFinalControlVector);
+ QuinticHermiteSpline(wpi::array<double, 3> xInitialControlVector,
+ wpi::array<double, 3> xFinalControlVector,
+ wpi::array<double, 3> yInitialControlVector,
+ wpi::array<double, 3> yFinalControlVector);
protected:
/**
@@ -55,15 +53,41 @@
* @return The hermite basis matrix for quintic hermite spline interpolation.
*/
static Eigen::Matrix<double, 6, 6> MakeHermiteBasis() {
- // clang-format off
- static const auto basis = (Eigen::Matrix<double, 6, 6>() <<
- -06.0, -03.0, -00.5, +06.0, -03.0, +00.5,
- +15.0, +08.0, +01.5, -15.0, +07.0, +01.0,
- -10.0, -06.0, -01.5, +10.0, -04.0, +00.5,
- +00.0, +00.0, +00.5, +00.0, +00.0, +00.0,
- +00.0, +01.0, +00.0, +00.0, +00.0, +00.0,
- +01.0, +00.0, +00.0, +00.0, +00.0, +00.0).finished();
- // clang-format on
+ // Given P(i), P'(i), P''(i), P(i+1), P'(i+1), P''(i+1), the control
+ // vectors, we want to find the coefficients of the spline
+ // P(t) = a5 * t^5 + a4 * t^4 + a3 * t^3 + a2 * t^2 + a1 * t + a0.
+ //
+ // P(i) = P(0) = a0
+ // P'(i) = P'(0) = a1
+ // P''(i) = P''(0) = 2 * a2
+ // P(i+1) = P(1) = a5 + a4 + a3 + a2 + a1 + a0
+ // P'(i+1) = P'(1) = 5 * a5 + 4 * a4 + 3 * a3 + 2 * a2 + a1
+ // P''(i+1) = P''(1) = 20 * a5 + 12 * a4 + 6 * a3 + 2 * a2
+ //
+ // [ P(i) ] = [ 0 0 0 0 0 1 ][ a5 ]
+ // [ P'(i) ] = [ 0 0 0 0 1 0 ][ a4 ]
+ // [ P''(i) ] = [ 0 0 0 2 0 0 ][ a3 ]
+ // [ P(i+1) ] = [ 1 1 1 1 1 1 ][ a2 ]
+ // [ P'(i+1) ] = [ 5 4 3 2 1 0 ][ a1 ]
+ // [ P''(i+1) ] = [ 20 12 6 2 0 0 ][ a0 ]
+ //
+ // To solve for the coefficients, we can invert the 6x6 matrix and move it
+ // to the other side of the equation.
+ //
+ // [ a5 ] = [ -6.0 -3.0 -0.5 6.0 -3.0 0.5 ][ P(i) ]
+ // [ a4 ] = [ 15.0 8.0 1.5 -15.0 7.0 -1.0 ][ P'(i) ]
+ // [ a3 ] = [ -10.0 -6.0 -1.5 10.0 -4.0 0.5 ][ P''(i) ]
+ // [ a2 ] = [ 0.0 0.0 0.5 0.0 0.0 0.0 ][ P(i+1) ]
+ // [ a1 ] = [ 0.0 1.0 0.0 0.0 0.0 0.0 ][ P'(i+1) ]
+ // [ a0 ] = [ 1.0 0.0 0.0 0.0 0.0 0.0 ][ P''(i+1) ]
+
+ static const Eigen::Matrix<double, 6, 6> basis{
+ {-06.0, -03.0, -00.5, +06.0, -03.0, +00.5},
+ {+15.0, +08.0, +01.5, -15.0, +07.0, -01.0},
+ {-10.0, -06.0, -01.5, +10.0, -04.0, +00.5},
+ {+00.0, +00.0, +00.5, +00.0, +00.0, +00.0},
+ {+00.0, +01.0, +00.0, +00.0, +00.0, +00.0},
+ {+01.0, +00.0, +00.0, +00.0, +00.0, +00.0}};
return basis;
}
@@ -76,11 +100,11 @@
*
* @return The control vector matrix for a dimension.
*/
- static Eigen::Matrix<double, 6, 1> ControlVectorFromArrays(
- std::array<double, 3> initialVector, std::array<double, 3> finalVector) {
- return (Eigen::Matrix<double, 6, 1>() << initialVector[0], initialVector[1],
- initialVector[2], finalVector[0], finalVector[1], finalVector[2])
- .finished();
+ static Eigen::Vector<double, 6> ControlVectorFromArrays(
+ wpi::array<double, 3> initialVector, wpi::array<double, 3> finalVector) {
+ return Eigen::Vector<double, 6>{initialVector[0], initialVector[1],
+ initialVector[2], finalVector[0],
+ finalVector[1], finalVector[2]};
}
};
} // namespace frc
diff --git a/wpimath/src/main/native/include/frc/spline/Spline.h b/wpimath/src/main/native/include/frc/spline/Spline.h
index 2964476..2dd248a 100644
--- a/wpimath/src/main/native/include/frc/spline/Spline.h
+++ b/wpimath/src/main/native/include/frc/spline/Spline.h
@@ -1,16 +1,14 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
-#include <array>
#include <utility>
#include <vector>
+#include <wpi/array.h>
+
#include "Eigen/Core"
#include "frc/geometry/Pose2d.h"
#include "units/curvature.h"
@@ -46,8 +44,8 @@
* dimension.
*/
struct ControlVector {
- std::array<double, (Degree + 1) / 2> x;
- std::array<double, (Degree + 1) / 2> y;
+ wpi::array<double, (Degree + 1) / 2> x;
+ wpi::array<double, (Degree + 1) / 2> y;
};
/**
@@ -57,7 +55,7 @@
* @return The pose and curvature at that point.
*/
PoseWithCurvature GetPoint(double t) const {
- Eigen::Matrix<double, Degree + 1, 1> polynomialBases;
+ Eigen::Vector<double, Degree + 1> polynomialBases;
// Populate the polynomial bases
for (int i = 0; i <= Degree; i++) {
@@ -66,7 +64,7 @@
// This simply multiplies by the coefficients. We need to divide out t some
// n number of times where n is the derivative we want to take.
- Eigen::Matrix<double, 6, 1> combined = Coefficients() * polynomialBases;
+ Eigen::Vector<double, 6> combined = Coefficients() * polynomialBases;
double dx, dy, ddx, ddy;
@@ -111,9 +109,7 @@
* @return The vector.
*/
static Eigen::Vector2d ToVector(const Translation2d& translation) {
- return (Eigen::Vector2d() << translation.X().to<double>(),
- translation.Y().to<double>())
- .finished();
+ return Eigen::Vector2d{translation.X().value(), translation.Y().value()};
}
/**
diff --git a/wpimath/src/main/native/include/frc/spline/SplineHelper.h b/wpimath/src/main/native/include/frc/spline/SplineHelper.h
index e04fa45..90b6107 100644
--- a/wpimath/src/main/native/include/frc/spline/SplineHelper.h
+++ b/wpimath/src/main/native/include/frc/spline/SplineHelper.h
@@ -1,16 +1,15 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
-#include <array>
#include <utility>
#include <vector>
+#include <wpi/SymbolExports.h>
+#include <wpi/array.h>
+
#include "frc/spline/CubicHermiteSpline.h"
#include "frc/spline/QuinticHermiteSpline.h"
@@ -19,7 +18,7 @@
* Helper class that is used to generate cubic and quintic splines from user
* provided waypoints.
*/
-class SplineHelper {
+class WPILIB_DLLEXPORT SplineHelper {
public:
/**
* Returns 2 cubic control vectors from a set of exterior waypoints and
@@ -30,7 +29,7 @@
* @param end The ending pose.
* @return 2 cubic control vectors.
*/
- static std::array<Spline<3>::ControlVector, 2>
+ static wpi::array<Spline<3>::ControlVector, 2>
CubicControlVectorsFromWaypoints(
const Pose2d& start, const std::vector<Translation2d>& interiorWaypoints,
const Pose2d& end);
@@ -81,14 +80,14 @@
private:
static Spline<3>::ControlVector CubicControlVector(double scalar,
const Pose2d& point) {
- return {{point.X().to<double>(), scalar * point.Rotation().Cos()},
- {point.Y().to<double>(), scalar * point.Rotation().Sin()}};
+ return {{point.X().value(), scalar * point.Rotation().Cos()},
+ {point.Y().value(), scalar * point.Rotation().Sin()}};
}
static Spline<5>::ControlVector QuinticControlVector(double scalar,
const Pose2d& point) {
- return {{point.X().to<double>(), scalar * point.Rotation().Cos(), 0.0},
- {point.Y().to<double>(), scalar * point.Rotation().Sin(), 0.0}};
+ return {{point.X().value(), scalar * point.Rotation().Cos(), 0.0},
+ {point.Y().value(), scalar * point.Rotation().Sin(), 0.0}};
}
/**
diff --git a/wpimath/src/main/native/include/frc/spline/SplineParameterizer.h b/wpimath/src/main/native/include/frc/spline/SplineParameterizer.h
index 8e7079c..0720cd1 100644
--- a/wpimath/src/main/native/include/frc/spline/SplineParameterizer.h
+++ b/wpimath/src/main/native/include/frc/spline/SplineParameterizer.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
/*
* MIT License
@@ -36,7 +33,7 @@
#include <utility>
#include <vector>
-#include <wpi/Twine.h>
+#include <wpi/SymbolExports.h>
#include "frc/spline/Spline.h"
#include "units/angle.h"
@@ -49,7 +46,7 @@
/**
* Class used to parameterize a spline by its arc length.
*/
-class SplineParameterizer {
+class WPILIB_DLLEXPORT SplineParameterizer {
public:
using PoseWithCurvature = std::pair<Pose2d, units::curvature_t>;
diff --git a/wpimath/src/main/native/include/frc/system/Discretization.h b/wpimath/src/main/native/include/frc/system/Discretization.h
index 72d0226..722bb5f 100644
--- a/wpimath/src/main/native/include/frc/system/Discretization.h
+++ b/wpimath/src/main/native/include/frc/system/Discretization.h
@@ -1,22 +1,19 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
#include "Eigen/Core"
-#include "Eigen/src/LU/PartialPivLU.h"
#include "units/time.h"
-#include "unsupported/Eigen/src/MatrixFunctions/MatrixExponential.h"
+#include "unsupported/Eigen/MatrixFunctions"
namespace frc {
/**
* Discretizes the given continuous A matrix.
*
+ * @tparam States Number of states.
* @param contA Continuous system matrix.
* @param dt Discretization timestep.
* @param discA Storage for discrete system matrix.
@@ -25,12 +22,14 @@
void DiscretizeA(const Eigen::Matrix<double, States, States>& contA,
units::second_t dt,
Eigen::Matrix<double, States, States>* discA) {
- *discA = (contA * dt.to<double>()).exp();
+ *discA = (contA * dt.value()).exp();
}
/**
* Discretizes the given continuous A and B matrices.
*
+ * @tparam States Number of states.
+ * @tparam Inputs Number of inputs.
* @param contA Continuous system matrix.
* @param contB Continuous input matrix.
* @param dt Discretization timestep.
@@ -46,8 +45,8 @@
// Matrices are blocked here to minimize matrix exponentiation calculations
Eigen::Matrix<double, States + Inputs, States + Inputs> Mcont;
Mcont.setZero();
- Mcont.template block<States, States>(0, 0) = contA * dt.to<double>();
- Mcont.template block<States, Inputs>(0, States) = contB * dt.to<double>();
+ Mcont.template block<States, States>(0, 0) = contA * dt.value();
+ Mcont.template block<States, Inputs>(0, States) = contB * dt.value();
// Discretize A and B with the given timestep
Eigen::Matrix<double, States + Inputs, States + Inputs> Mdisc = Mcont.exp();
@@ -58,6 +57,7 @@
/**
* Discretizes the given continuous A and Q matrices.
*
+ * @tparam States Number of states.
* @param contA Continuous system matrix.
* @param contQ Continuous process noise covariance matrix.
* @param dt Discretization timestep.
@@ -80,8 +80,7 @@
M.template block<States, States>(States, 0).setZero();
M.template block<States, States>(States, States) = contA.transpose();
- Eigen::Matrix<double, 2 * States, 2 * States> phi =
- (M * dt.to<double>()).exp();
+ Eigen::Matrix<double, 2 * States, 2 * States> phi = (M * dt.value()).exp();
// Phi12 = phi[0:States, States:2*States]
// Phi22 = phi[States:2*States, States:2*States]
@@ -110,6 +109,7 @@
* using a taylor series to several terms and still be substantially cheaper
* than taking the big exponential.
*
+ * @tparam States Number of states.
* @param contA Continuous system matrix.
* @param contQ Continuous process noise covariance matrix.
* @param dt Discretization timestep.
@@ -126,9 +126,9 @@
Eigen::Matrix<double, States, States> Q = (contQ + contQ.transpose()) / 2.0;
Eigen::Matrix<double, States, States> lastTerm = Q;
- double lastCoeff = dt.to<double>();
+ double lastCoeff = dt.value();
- // A^T^n
+ // Aᵀⁿ
Eigen::Matrix<double, States, States> Atn = contA.transpose();
Eigen::Matrix<double, States, States> phi12 = lastTerm * lastCoeff;
@@ -136,7 +136,7 @@
// i = 6 i.e. 5th order should be enough precision
for (int i = 2; i < 6; ++i) {
lastTerm = -contA * lastTerm + Q * Atn;
- lastCoeff *= dt.to<double>() / static_cast<double>(i);
+ lastCoeff *= dt.value() / static_cast<double>(i);
phi12 += lastTerm * lastCoeff;
@@ -154,13 +154,14 @@
* Returns a discretized version of the provided continuous measurement noise
* covariance matrix.
*
+ * @tparam Outputs Number of outputs.
* @param R Continuous measurement noise covariance matrix.
* @param dt Discretization timestep.
*/
template <int Outputs>
Eigen::Matrix<double, Outputs, Outputs> DiscretizeR(
const Eigen::Matrix<double, Outputs, Outputs>& R, units::second_t dt) {
- return R / dt.to<double>();
+ return R / dt.value();
}
} // namespace frc
diff --git a/wpimath/src/main/native/include/frc/system/LinearSystem.h b/wpimath/src/main/native/include/frc/system/LinearSystem.h
index 975fa0e..bd3ff40 100644
--- a/wpimath/src/main/native/include/frc/system/LinearSystem.h
+++ b/wpimath/src/main/native/include/frc/system/LinearSystem.h
@@ -1,14 +1,12 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2018-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
#include <algorithm>
#include <functional>
+#include <stdexcept>
#include "Eigen/Core"
#include "frc/StateSpaceUtil.h"
@@ -24,6 +22,10 @@
*
* For more on the underlying math, read
* https://file.tavsys.net/control/controls-engineering-in-frc.pdf.
+ *
+ * @tparam States Number of states.
+ * @tparam Inputs Number of inputs.
+ * @tparam Outputs Number of outputs.
*/
template <int States, int Inputs, int Outputs>
class LinearSystem {
@@ -35,11 +37,33 @@
* @param B Input matrix.
* @param C Output matrix.
* @param D Feedthrough matrix.
+ * @throws std::domain_error if any matrix element isn't finite.
*/
LinearSystem(const Eigen::Matrix<double, States, States>& A,
const Eigen::Matrix<double, States, Inputs>& B,
const Eigen::Matrix<double, Outputs, States>& C,
const Eigen::Matrix<double, Outputs, Inputs>& D) {
+ if (!A.allFinite()) {
+ throw std::domain_error(
+ "Elements of A aren't finite. This is usually due to model "
+ "implementation errors.");
+ }
+ if (!B.allFinite()) {
+ throw std::domain_error(
+ "Elements of B aren't finite. This is usually due to model "
+ "implementation errors.");
+ }
+ if (!C.allFinite()) {
+ throw std::domain_error(
+ "Elements of C aren't finite. This is usually due to model "
+ "implementation errors.");
+ }
+ if (!D.allFinite()) {
+ throw std::domain_error(
+ "Elements of D aren't finite. This is usually due to model "
+ "implementation errors.");
+ }
+
m_A = A;
m_B = B;
m_C = C;
@@ -109,14 +133,13 @@
* This is used by state observers directly to run updates based on state
* estimate.
*
- * @param x The current state.
- * @param u The control input.
- * @param dt Timestep for model update.
+ * @param x The current state.
+ * @param clampedU The control input.
+ * @param dt Timestep for model update.
*/
- Eigen::Matrix<double, States, 1> CalculateX(
- const Eigen::Matrix<double, States, 1>& x,
- const Eigen::Matrix<double, Inputs, 1>& clampedU,
- units::second_t dt) const {
+ Eigen::Vector<double, States> CalculateX(
+ const Eigen::Vector<double, States>& x,
+ const Eigen::Vector<double, Inputs>& clampedU, units::second_t dt) const {
Eigen::Matrix<double, States, States> discA;
Eigen::Matrix<double, States, Inputs> discB;
DiscretizeAB<States, Inputs>(m_A, m_B, dt, &discA, &discB);
@@ -133,9 +156,9 @@
* @param x The current state.
* @param clampedU The control input.
*/
- Eigen::Matrix<double, Outputs, 1> CalculateY(
- const Eigen::Matrix<double, States, 1>& x,
- const Eigen::Matrix<double, Inputs, 1>& clampedU) const {
+ Eigen::Vector<double, Outputs> CalculateY(
+ const Eigen::Vector<double, States>& x,
+ const Eigen::Vector<double, Inputs>& clampedU) const {
return m_C * x + m_D * clampedU;
}
diff --git a/wpimath/src/main/native/include/frc/system/LinearSystemLoop.h b/wpimath/src/main/native/include/frc/system/LinearSystemLoop.h
index d5f25fb..9ee2ea2 100644
--- a/wpimath/src/main/native/include/frc/system/LinearSystemLoop.h
+++ b/wpimath/src/main/native/include/frc/system/LinearSystemLoop.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2018-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
@@ -18,8 +15,8 @@
namespace frc {
/**
- * Combines a plant, controller, and observer for controlling a mechanism with
- * full state feedback.
+ * Combines a controller, feedforward, and observer for controlling a mechanism
+ * with full state feedback.
*
* For everything in this file, "inputs" and "outputs" are defined from the
* perspective of the plant. This means U is an input and Y is an output
@@ -30,6 +27,10 @@
*
* For more on the underlying math, read
* https://file.tavsys.net/control/controls-engineering-in-frc.pdf.
+ *
+ * @tparam States Number of states.
+ * @tparam Inputs Number of inputs.
+ * @tparam Outputs Number of outputs.
*/
template <int States, int Inputs, int Outputs>
class LinearSystemLoop {
@@ -52,9 +53,8 @@
units::volt_t maxVoltage, units::second_t dt)
: LinearSystemLoop(
plant, controller, observer,
- [=](Eigen::Matrix<double, Inputs, 1> u) {
- return frc::NormalizeInputVector<Inputs>(
- u, maxVoltage.template to<double>());
+ [=](const Eigen::Vector<double, Inputs>& u) {
+ return frc::NormalizeInputVector<Inputs>(u, maxVoltage.value());
},
dt) {}
@@ -73,21 +73,20 @@
LinearSystemLoop(LinearSystem<States, Inputs, Outputs>& plant,
LinearQuadraticRegulator<States, Inputs>& controller,
KalmanFilter<States, Inputs, Outputs>& observer,
- std::function<Eigen::Matrix<double, Inputs, 1>(
- const Eigen::Matrix<double, Inputs, 1>&)>
+ std::function<Eigen::Vector<double, Inputs>(
+ const Eigen::Vector<double, Inputs>&)>
clampFunction,
units::second_t dt)
: LinearSystemLoop(
- plant, controller,
+ controller,
LinearPlantInversionFeedforward<States, Inputs>{plant, dt},
observer, clampFunction) {}
/**
- * Constructs a state-space loop with the given plant, controller, and
+ * Constructs a state-space loop with the given controller, feedforward and
* observer. By default, the initial reference is all zeros. Users should
- * call reset with the initial system state before enabling the loop.
+ * call reset with the initial system state.
*
- * @param plant State-space plant.
* @param controller State-space controller.
* @param feedforward Plant inversion feedforward.
* @param observer State-space observer.
@@ -95,48 +94,48 @@
* that the inputs are voltages.
*/
LinearSystemLoop(
- LinearSystem<States, Inputs, Outputs>& plant,
LinearQuadraticRegulator<States, Inputs>& controller,
const LinearPlantInversionFeedforward<States, Inputs>& feedforward,
KalmanFilter<States, Inputs, Outputs>& observer, units::volt_t maxVoltage)
- : LinearSystemLoop(plant, controller, feedforward, observer,
- [=](Eigen::Matrix<double, Inputs, 1> u) {
+ : LinearSystemLoop(controller, feedforward, observer,
+ [=](const Eigen::Vector<double, Inputs>& u) {
return frc::NormalizeInputVector<Inputs>(
- u, maxVoltage.template to<double>());
+ u, maxVoltage.value());
}) {}
/**
- * Constructs a state-space loop with the given plant, controller, and
- * observer.
+ * Constructs a state-space loop with the given controller, feedforward,
+ * observer and clamp function. By default, the initial reference is all
+ * zeros. Users should call reset with the initial system state.
*
- * @param plant State-space plant.
* @param controller State-space controller.
* @param feedforward Plant-inversion feedforward.
* @param observer State-space observer.
* @param clampFunction The function used to clamp the input vector.
*/
LinearSystemLoop(
- LinearSystem<States, Inputs, Outputs>& plant,
LinearQuadraticRegulator<States, Inputs>& controller,
const LinearPlantInversionFeedforward<States, Inputs>& feedforward,
KalmanFilter<States, Inputs, Outputs>& observer,
- std::function<Eigen::Matrix<double, Inputs, 1>(
- const Eigen::Matrix<double, Inputs, 1>&)>
+ std::function<
+ Eigen::Vector<double, Inputs>(const Eigen::Vector<double, Inputs>&)>
clampFunction)
- : m_plant(plant),
- m_controller(controller),
+ : m_controller(&controller),
m_feedforward(feedforward),
- m_observer(observer),
+ m_observer(&observer),
m_clampFunc(clampFunction) {
m_nextR.setZero();
Reset(m_nextR);
}
+ LinearSystemLoop(LinearSystemLoop&&) = default;
+ LinearSystemLoop& operator=(LinearSystemLoop&&) = default;
+
/**
* Returns the observer's state estimate x-hat.
*/
- const Eigen::Matrix<double, States, 1>& Xhat() const {
- return m_observer.Xhat();
+ const Eigen::Vector<double, States>& Xhat() const {
+ return m_observer->Xhat();
}
/**
@@ -144,12 +143,12 @@
*
* @param i Row of x-hat.
*/
- double Xhat(int i) const { return m_observer.Xhat(i); }
+ double Xhat(int i) const { return m_observer->Xhat(i); }
/**
* Returns the controller's next reference r.
*/
- const Eigen::Matrix<double, States, 1>& NextR() const { return m_nextR; }
+ const Eigen::Vector<double, States>& NextR() const { return m_nextR; }
/**
* Returns an element of the controller's next reference r.
@@ -161,8 +160,8 @@
/**
* Returns the controller's calculated control input u.
*/
- Eigen::Matrix<double, Inputs, 1> U() const {
- return ClampInput(m_controller.U() + m_feedforward.Uff());
+ Eigen::Vector<double, Inputs> U() const {
+ return ClampInput(m_controller->U() + m_feedforward.Uff());
}
/**
@@ -177,8 +176,8 @@
*
* @param xHat The initial state estimate x-hat.
*/
- void SetXhat(const Eigen::Matrix<double, States, 1>& xHat) {
- m_observer.SetXhat(xHat);
+ void SetXhat(const Eigen::Vector<double, States>& xHat) {
+ m_observer->SetXhat(xHat);
}
/**
@@ -187,27 +186,20 @@
* @param i Row of x-hat.
* @param value Value for element of x-hat.
*/
- void SetXhat(int i, double value) { m_observer.SetXhat(i, value); }
+ void SetXhat(int i, double value) { m_observer->SetXhat(i, value); }
/**
* Set the next reference r.
*
* @param nextR Next reference.
*/
- void SetNextR(const Eigen::Matrix<double, States, 1>& nextR) {
- m_nextR = nextR;
- }
-
- /**
- * Return the plant used internally.
- */
- const LinearSystem<States, Inputs, Outputs>& Plant() const { return m_plant; }
+ void SetNextR(const Eigen::Vector<double, States>& nextR) { m_nextR = nextR; }
/**
* Return the controller used internally.
*/
const LinearQuadraticRegulator<States, Inputs>& Controller() const {
- return m_controller;
+ return *m_controller;
}
/**
@@ -233,18 +225,18 @@
*
* @param initialState The initial state.
*/
- void Reset(Eigen::Matrix<double, States, 1> initialState) {
+ void Reset(const Eigen::Vector<double, States>& initialState) {
m_nextR.setZero();
- m_controller.Reset();
+ m_controller->Reset();
m_feedforward.Reset(initialState);
- m_observer.SetXhat(initialState);
+ m_observer->SetXhat(initialState);
}
/**
* Returns difference between reference r and current state x-hat.
*/
- const Eigen::Matrix<double, States, 1> Error() const {
- return m_controller.R() - m_observer.Xhat();
+ Eigen::Vector<double, States> Error() const {
+ return m_controller->R() - m_observer->Xhat();
}
/**
@@ -252,8 +244,8 @@
*
* @param y Measurement vector.
*/
- void Correct(const Eigen::Matrix<double, Outputs, 1>& y) {
- m_observer.Correct(U(), y);
+ void Correct(const Eigen::Vector<double, Outputs>& y) {
+ m_observer->Correct(U(), y);
}
/**
@@ -266,10 +258,10 @@
* @param dt Timestep for model update.
*/
void Predict(units::second_t dt) {
- Eigen::Matrix<double, Inputs, 1> u =
- ClampInput(m_controller.Calculate(m_observer.Xhat(), m_nextR) +
+ Eigen::Vector<double, Inputs> u =
+ ClampInput(m_controller->Calculate(m_observer->Xhat(), m_nextR) +
m_feedforward.Calculate(m_nextR));
- m_observer.Predict(u, dt);
+ m_observer->Predict(u, dt);
}
/**
@@ -278,26 +270,25 @@
* @param u Input vector to clamp.
* @return Clamped input vector.
*/
- Eigen::Matrix<double, Inputs, 1> ClampInput(
- const Eigen::Matrix<double, Inputs, 1>& u) const {
+ Eigen::Vector<double, Inputs> ClampInput(
+ const Eigen::Vector<double, Inputs>& u) const {
return m_clampFunc(u);
}
protected:
- LinearSystem<States, Inputs, Outputs>& m_plant;
- LinearQuadraticRegulator<States, Inputs>& m_controller;
+ LinearQuadraticRegulator<States, Inputs>* m_controller;
LinearPlantInversionFeedforward<States, Inputs> m_feedforward;
- KalmanFilter<States, Inputs, Outputs>& m_observer;
+ KalmanFilter<States, Inputs, Outputs>* m_observer;
/**
* Clamping function.
*/
- std::function<Eigen::Matrix<double, Inputs, 1>(
- const Eigen::Matrix<double, Inputs, 1>&)>
+ std::function<Eigen::Vector<double, Inputs>(
+ const Eigen::Vector<double, Inputs>&)>
m_clampFunc;
// Reference to go to in the next cycle (used by feedforward controller).
- Eigen::Matrix<double, States, 1> m_nextR;
+ Eigen::Vector<double, States> m_nextR;
// These are accessible from non-templated subclasses.
static constexpr int kStates = States;
diff --git a/wpimath/src/main/native/include/frc/system/NumericalIntegration.h b/wpimath/src/main/native/include/frc/system/NumericalIntegration.h
new file mode 100644
index 0000000..68d047f
--- /dev/null
+++ b/wpimath/src/main/native/include/frc/system/NumericalIntegration.h
@@ -0,0 +1,209 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+#pragma once
+
+#include <frc/StateSpaceUtil.h>
+
+#include <algorithm>
+#include <array>
+
+#include "Eigen/Core"
+#include "units/time.h"
+
+namespace frc {
+
+/**
+ * Returns x after one 4th order Runge-Kutta step of dx/dt = f(x) over dt.
+ *
+ * @param f  The function to integrate. It must take one argument x.
+ * @param x  The initial value of x.
+ * @param dt The time over which to integrate.
+ */
+template <typename F, typename T>
+T RK4(F&& f, T x, units::second_t dt) {
+  const auto h = dt.value();
+  // Slopes at the start, at the midpoint (twice), and at the end of the step.
+  T k1 = f(x);
+  T k2 = f(x + h * 0.5 * k1);
+  T k3 = f(x + h * 0.5 * k2);
+  T k4 = f(x + h * k3);
+  // Weighted average of the four slopes (weights 1, 2, 2, 1, divided by 6).
+  return x + h / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4);
+}
+
+/**
+ * Returns x after one 4th order Runge-Kutta step of dx/dt = f(x, u) over dt.
+ *
+ * @param f  The function to integrate. It must take two arguments x and u.
+ * @param x  The initial value of x.
+ * @param u  The value u held constant over the integration period.
+ * @param dt The time over which to integrate.
+ */
+template <typename F, typename T, typename U>
+T RK4(F&& f, T x, U u, units::second_t dt) {
+  const auto h = dt.value();
+  // Slopes at the start, at the midpoint (twice), and at the end of the step.
+  T k1 = f(x, u);
+  T k2 = f(x + h * 0.5 * k1, u);
+  T k3 = f(x + h * 0.5 * k2, u);
+  T k4 = f(x + h * k3, u);
+  // Weighted average of the four slopes (weights 1, 2, 2, 1, divided by 6).
+  return x + h / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4);
+}
+
+/**
+ * Performs adaptive RKF45 integration of dx/dt = f(x, u) for dt, as described
+ * in https://en.wikipedia.org/wiki/Runge%E2%80%93Kutta%E2%80%93Fehlberg_method
+ *
+ * @param f        The function to integrate. It must take arguments x and u.
+ * @param x        The initial value of x.
+ * @param u        The value u held constant over the integration period.
+ * @param dt       The time over which to integrate.
+ * @param maxError The maximum acceptable truncation error. Usually a small
+ *                 number like 1e-6.
+ * @return The value of x after integrating over the full interval dt.
+ */
+template <typename F, typename T, typename U>
+T RKF45(F&& f, T x, U u, units::second_t dt, double maxError = 1e-6) {
+  // See
+  // https://en.wikipedia.org/wiki/Runge%E2%80%93Kutta%E2%80%93Fehlberg_method
+  // for the Butcher tableau the following arrays came from.
+  constexpr int kDim = 6;
+
+  // clang-format off
+  constexpr double A[kDim - 1][kDim - 1]{
+      {      1.0 / 4.0},
+      {     3.0 / 32.0,       9.0 / 32.0},
+      {1932.0 / 2197.0, -7200.0 / 2197.0,  7296.0 / 2197.0},
+      {  439.0 / 216.0,             -8.0,   3680.0 / 513.0, -845.0 / 4104.0},
+      {    -8.0 / 27.0,              2.0, -3544.0 / 2565.0, 1859.0 / 4104.0, -11.0 / 40.0}};
+  // clang-format on
+
+  constexpr std::array<double, kDim> b1{16.0 / 135.0,     0.0,
+                                        6656.0 / 12825.0, 28561.0 / 56430.0,
+                                        -9.0 / 50.0,      2.0 / 55.0};
+  constexpr std::array<double, kDim> b2{
+      25.0 / 216.0, 0.0, 1408.0 / 2565.0, 2197.0 / 4104.0, -1.0 / 5.0, 0.0};
+
+  T newX;
+  double truncationError;
+  double stepTaken = 0.0;  // Step size that produced the accepted newX
+  double dtElapsed = 0.0;
+  double h = dt.value();
+
+  // Loop until we've gotten to our desired dt
+  while (dtElapsed < dt.value()) {
+    do {
+      // Only allow us to advance up to the dt remaining
+      h = std::min(h, dt.value() - dtElapsed);
+
+      // Notice how the derivative in the Wikipedia notation is dy/dx.
+      // That means their y is our x and their x is our t
+      // clang-format off
+      T k1 = f(x, u);
+      T k2 = f(x + h * (A[0][0] * k1), u);
+      T k3 = f(x + h * (A[1][0] * k1 + A[1][1] * k2), u);
+      T k4 = f(x + h * (A[2][0] * k1 + A[2][1] * k2 + A[2][2] * k3), u);
+      T k5 = f(x + h * (A[3][0] * k1 + A[3][1] * k2 + A[3][2] * k3 + A[3][3] * k4), u);
+      T k6 = f(x + h * (A[4][0] * k1 + A[4][1] * k2 + A[4][2] * k3 + A[4][3] * k4 + A[4][4] * k5), u);
+      // clang-format on
+
+      newX = x + h * (b1[0] * k1 + b1[1] * k2 + b1[2] * k3 + b1[3] * k4 +
+                      b1[4] * k5 + b1[5] * k6);
+      truncationError = (h * ((b1[0] - b2[0]) * k1 + (b1[1] - b2[1]) * k2 +
+                              (b1[2] - b2[2]) * k3 + (b1[3] - b2[3]) * k4 +
+                              (b1[4] - b2[4]) * k5 + (b1[5] - b2[5]) * k6))
+                            .norm();
+      // Remember the step used for newX before rescaling h
+      stepTaken = h;
+      h *= 0.9 * std::pow(maxError / truncationError, 1.0 / 5.0);
+    } while (truncationError > maxError);
+    dtElapsed += stepTaken;
+    x = newX;
+  }
+
+  return x;
+}
+
+/**
+ * Performs adaptive Dormand-Prince integration of dx/dt = f(x, u) for dt.
+ *
+ * @param f        The function to integrate. It must take arguments x and u.
+ * @param x        The initial value of x.
+ * @param u        The value u held constant over the integration period.
+ * @param dt       The time over which to integrate.
+ * @param maxError The maximum acceptable truncation error. Usually a small
+ *                 number like 1e-6.
+ * @return The value of x after integrating over the full interval dt.
+ */
+template <typename F, typename T, typename U>
+T RKDP(F&& f, T x, U u, units::second_t dt, double maxError = 1e-6) {
+  // See https://en.wikipedia.org/wiki/Dormand%E2%80%93Prince_method for the
+  // Butcher tableau the following arrays came from.
+
+  constexpr int kDim = 7;
+
+  // clang-format off
+  constexpr double A[kDim - 1][kDim - 1]{
+      {       1.0 / 5.0},
+      {      3.0 / 40.0,        9.0 / 40.0},
+      {     44.0 / 45.0,      -56.0 / 15.0,       32.0 / 9.0},
+      {19372.0 / 6561.0, -25360.0 / 2187.0, 64448.0 / 6561.0, -212.0 / 729.0},
+      { 9017.0 / 3168.0,     -355.0 / 33.0, 46732.0 / 5247.0,   49.0 / 176.0, -5103.0 / 18656.0},
+      {    35.0 / 384.0,               0.0,   500.0 / 1113.0,  125.0 / 192.0,  -2187.0 / 6784.0, 11.0 / 84.0}};
+  // clang-format on
+
+  constexpr std::array<double, kDim> b1{
+      35.0 / 384.0, 0.0, 500.0 / 1113.0, 125.0 / 192.0, -2187.0 / 6784.0,
+      11.0 / 84.0,  0.0};
+  constexpr std::array<double, kDim> b2{5179.0 / 57600.0,    0.0,
+                                        7571.0 / 16695.0,    393.0 / 640.0,
+                                        -92097.0 / 339200.0, 187.0 / 2100.0,
+                                        1.0 / 40.0};
+
+  T newX;
+  double truncationError;
+  double stepTaken = 0.0;  // Step size that produced the accepted newX
+  double dtElapsed = 0.0;
+  double h = dt.value();
+
+  // Loop until we've gotten to our desired dt
+  while (dtElapsed < dt.value()) {
+    do {
+      // Only allow us to advance up to the dt remaining
+      h = std::min(h, dt.value() - dtElapsed);
+
+      // clang-format off
+      T k1 = f(x, u);
+      T k2 = f(x + h * (A[0][0] * k1), u);
+      T k3 = f(x + h * (A[1][0] * k1 + A[1][1] * k2), u);
+      T k4 = f(x + h * (A[2][0] * k1 + A[2][1] * k2 + A[2][2] * k3), u);
+      T k5 = f(x + h * (A[3][0] * k1 + A[3][1] * k2 + A[3][2] * k3 + A[3][3] * k4), u);
+      T k6 = f(x + h * (A[4][0] * k1 + A[4][1] * k2 + A[4][2] * k3 + A[4][3] * k4 + A[4][4] * k5), u);
+      // clang-format on
+
+      // Since the final row of A and the array b1 have the same coefficients
+      // and k7 has no effect on newX, we can reuse the calculation.
+      newX = x + h * (A[5][0] * k1 + A[5][1] * k2 + A[5][2] * k3 +
+                      A[5][3] * k4 + A[5][4] * k5 + A[5][5] * k6);
+      T k7 = f(newX, u);
+
+      truncationError = (h * ((b1[0] - b2[0]) * k1 + (b1[1] - b2[1]) * k2 +
+                              (b1[2] - b2[2]) * k3 + (b1[3] - b2[3]) * k4 +
+                              (b1[4] - b2[4]) * k5 + (b1[5] - b2[5]) * k6 +
+                              (b1[6] - b2[6]) * k7))
+                            .norm();
+      // Remember the step used for newX before rescaling h
+      stepTaken = h;
+      h *= 0.9 * std::pow(maxError / truncationError, 1.0 / 5.0);
+    } while (truncationError > maxError);
+    dtElapsed += stepTaken;
+    x = newX;
+  }
+
+  return x;
+}
+
+} // namespace frc
diff --git a/wpimath/src/main/native/include/frc/system/NumericalJacobian.h b/wpimath/src/main/native/include/frc/system/NumericalJacobian.h
index cbd6943..5f1bc78 100644
--- a/wpimath/src/main/native/include/frc/system/NumericalJacobian.h
+++ b/wpimath/src/main/native/include/frc/system/NumericalJacobian.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
@@ -20,16 +17,16 @@
* @param x Vector argument.
*/
template <int Rows, int Cols, typename F>
-auto NumericalJacobian(F&& f, const Eigen::Matrix<double, Cols, 1>& x) {
+auto NumericalJacobian(F&& f, const Eigen::Vector<double, Cols>& x) {
constexpr double kEpsilon = 1e-5;
Eigen::Matrix<double, Rows, Cols> result;
result.setZero();
// It's more expensive, but +- epsilon will be more accurate
for (int i = 0; i < Cols; ++i) {
- Eigen::Matrix<double, Cols, 1> dX_plus = x;
+ Eigen::Vector<double, Cols> dX_plus = x;
dX_plus(i) += kEpsilon;
- Eigen::Matrix<double, Cols, 1> dX_minus = x;
+ Eigen::Vector<double, Cols> dX_minus = x;
dX_minus(i) -= kEpsilon;
result.col(i) = (f(dX_plus) - f(dX_minus)) / (kEpsilon * 2.0);
}
@@ -44,17 +41,19 @@
* @tparam States Number of rows in x.
* @tparam Inputs Number of rows in u.
* @tparam F Function object type.
- * @tparam Args... Remaining arguments to f(x, u, ...).
+ * @tparam Args... Types of remaining arguments to f(x, u, ...).
* @param f Vector-valued function from which to compute Jacobian.
* @param x State vector.
* @param u Input vector.
+ * @param args Remaining arguments to f(x, u, ...).
*/
template <int Rows, int States, int Inputs, typename F, typename... Args>
-auto NumericalJacobianX(F&& f, const Eigen::Matrix<double, States, 1>& x,
- const Eigen::Matrix<double, Inputs, 1>& u,
+auto NumericalJacobianX(F&& f, const Eigen::Vector<double, States>& x,
+ const Eigen::Vector<double, Inputs>& u,
Args&&... args) {
return NumericalJacobian<Rows, States>(
- [&](Eigen::Matrix<double, States, 1> x) { return f(x, u, args...); }, x);
+ [&](const Eigen::Vector<double, States>& x) { return f(x, u, args...); },
+ x);
}
/**
@@ -64,17 +63,19 @@
* @tparam States Number of rows in x.
* @tparam Inputs Number of rows in u.
* @tparam F Function object type.
- * @tparam Args... Remaining arguments to f(x, u, ...).
+ * @tparam Args... Types of remaining arguments to f(x, u, ...).
* @param f Vector-valued function from which to compute Jacobian.
* @param x State vector.
* @param u Input vector.
+ * @param args Remaining arguments to f(x, u, ...).
*/
template <int Rows, int States, int Inputs, typename F, typename... Args>
-auto NumericalJacobianU(F&& f, const Eigen::Matrix<double, States, 1>& x,
- const Eigen::Matrix<double, Inputs, 1>& u,
+auto NumericalJacobianU(F&& f, const Eigen::Vector<double, States>& x,
+ const Eigen::Vector<double, Inputs>& u,
Args&&... args) {
return NumericalJacobian<Rows, Inputs>(
- [&](Eigen::Matrix<double, Inputs, 1> u) { return f(x, u, args...); }, u);
+ [&](const Eigen::Vector<double, Inputs>& u) { return f(x, u, args...); },
+ u);
}
} // namespace frc
diff --git a/wpimath/src/main/native/include/frc/system/RungeKutta.h b/wpimath/src/main/native/include/frc/system/RungeKutta.h
deleted file mode 100644
index a83cafc..0000000
--- a/wpimath/src/main/native/include/frc/system/RungeKutta.h
+++ /dev/null
@@ -1,69 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-#pragma once
-
-#include "Eigen/Core"
-#include "units/time.h"
-
-namespace frc {
-
-/**
- * Performs 4th order Runge-Kutta integration of dx/dt = f(x) for dt.
- *
- * @param f The function to integrate. It must take one argument x.
- * @param x The initial value of x.
- * @param dt The time over which to integrate.
- */
-template <typename F, typename T>
-T RungeKutta(F&& f, T x, units::second_t dt) {
- const auto halfDt = 0.5 * dt;
- T k1 = f(x);
- T k2 = f(x + k1 * halfDt.to<double>());
- T k3 = f(x + k2 * halfDt.to<double>());
- T k4 = f(x + k3 * dt.to<double>());
- return x + dt.to<double>() / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4);
-}
-
-/**
- * Performs 4th order Runge-Kutta integration of dx/dt = f(x, u) for dt.
- *
- * @param f The function to integrate. It must take two arguments x and u.
- * @param x The initial value of x.
- * @param u The value u held constant over the integration period.
- * @param dt The time over which to integrate.
- */
-template <typename F, typename T, typename U>
-T RungeKutta(F&& f, T x, U u, units::second_t dt) {
- const auto halfDt = 0.5 * dt;
- T k1 = f(x, u);
- T k2 = f(x + k1 * halfDt.to<double>(), u);
- T k3 = f(x + k2 * halfDt.to<double>(), u);
- T k4 = f(x + k3 * dt.to<double>(), u);
- return x + dt.to<double>() / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4);
-}
-
-/**
- * Performs 4th order Runge-Kutta integration of dx/dt = f(t, x) for dt.
- *
- * @param f The function to integrate. It must take two arguments x and t.
- * @param x The initial value of x.
- * @param t The initial value of t.
- * @param dt The time over which to integrate.
- */
-template <typename F, typename T>
-T RungeKuttaTimeVarying(F&& f, T x, units::second_t t, units::second_t dt) {
- const auto halfDt = 0.5 * dt;
- T k1 = f(t, x);
- T k2 = f(t + halfDt, x + k1 / 2.0);
- T k3 = f(t + halfDt, x + k2 / 2.0);
- T k4 = f(t + dt, x + k3);
-
- return x + dt.to<double>() / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4);
-}
-
-} // namespace frc
diff --git a/wpimath/src/main/native/include/frc/system/plant/DCMotor.h b/wpimath/src/main/native/include/frc/system/plant/DCMotor.h
index d2a2ba8..a519b0e 100644
--- a/wpimath/src/main/native/include/frc/system/plant/DCMotor.h
+++ b/wpimath/src/main/native/include/frc/system/plant/DCMotor.h
@@ -1,12 +1,11 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
+#include <wpi/SymbolExports.h>
+
#include "units/angular_velocity.h"
#include "units/current.h"
#include "units/impedance.h"
@@ -18,7 +17,7 @@
/**
* Holds the constants for a DC motor.
*/
-class DCMotor {
+class WPILIB_DLLEXPORT DCMotor {
public:
using radians_per_second_per_volt_t =
units::unit_t<units::compound_unit<units::radians_per_second,
@@ -46,7 +45,7 @@
* Constructs a DC motor.
*
* @param nominalVoltage Voltage at which the motor constants were measured.
- * @param stallTorque Current draw when stalled.
+ * @param stallTorque Torque when stalled.
* @param stallCurrent Current draw when stalled.
* @param freeCurrent Current draw under no load.
* @param freeSpeed Angular velocity under no load.
@@ -58,12 +57,12 @@
units::radians_per_second_t freeSpeed, int numMotors = 1)
: nominalVoltage(nominalVoltage),
stallTorque(stallTorque * numMotors),
- stallCurrent(stallCurrent),
- freeCurrent(freeCurrent),
+ stallCurrent(stallCurrent * numMotors),
+ freeCurrent(freeCurrent * numMotors),
freeSpeed(freeSpeed),
- R(nominalVoltage / stallCurrent),
- Kv(freeSpeed / (nominalVoltage - R * freeCurrent)),
- Kt(stallTorque * numMotors / stallCurrent) {}
+ R(nominalVoltage / this->stallCurrent),
+ Kv(freeSpeed / (nominalVoltage - R * this->freeCurrent)),
+ Kt(this->stallTorque / this->stallCurrent) {}
/**
* Returns current drawn by motor with given speed and input voltage.
@@ -152,6 +151,14 @@
static constexpr DCMotor Falcon500(int numMotors = 1) {
return DCMotor(12_V, 4.69_Nm, 257_A, 1.5_A, 6380_rpm, numMotors);
}
+
+ /**
+ * Return a gearbox of Romi/TI_RSLK MAX motors.
+ */
+ static constexpr DCMotor RomiBuiltIn(int numMotors = 1) {
+ // From https://www.pololu.com/product/1520/specs
+ return DCMotor(4.5_V, 0.1765_Nm, 1.25_A, 0.13_A, 150_rpm, numMotors);
+ }
};
} // namespace frc
diff --git a/wpimath/src/main/native/include/frc/system/plant/LinearSystemId.h b/wpimath/src/main/native/include/frc/system/plant/LinearSystemId.h
index b712460..c0f4506 100644
--- a/wpimath/src/main/native/include/frc/system/plant/LinearSystemId.h
+++ b/wpimath/src/main/native/include/frc/system/plant/LinearSystemId.h
@@ -1,24 +1,25 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
-#include "frc/StateSpaceUtil.h"
+#include <stdexcept>
+
+#include <wpi/SymbolExports.h>
+
#include "frc/system/LinearSystem.h"
#include "frc/system/plant/DCMotor.h"
#include "units/acceleration.h"
#include "units/angular_acceleration.h"
#include "units/angular_velocity.h"
+#include "units/length.h"
#include "units/moment_of_inertia.h"
#include "units/velocity.h"
#include "units/voltage.h"
namespace frc {
-class LinearSystemId {
+class WPILIB_DLLEXPORT LinearSystemId {
public:
template <typename Distance>
using Velocity_t = units::unit_t<
@@ -40,19 +41,30 @@
* @param m Carriage mass.
* @param r Pulley radius.
* @param G Gear ratio from motor to carriage.
+ * @throws std::domain_error if m <= 0, r <= 0, or G <= 0.
*/
static LinearSystem<2, 1, 1> ElevatorSystem(DCMotor motor,
units::kilogram_t m,
units::meter_t r, double G) {
- auto A = frc::MakeMatrix<2, 2>(
- 0.0, 1.0, 0.0,
- (-std::pow(G, 2) * motor.Kt /
- (motor.R * units::math::pow<2>(r) * m * motor.Kv))
- .to<double>());
- auto B = frc::MakeMatrix<2, 1>(
- 0.0, (G * motor.Kt / (motor.R * r * m)).to<double>());
- auto C = frc::MakeMatrix<1, 2>(1.0, 0.0);
- auto D = frc::MakeMatrix<1, 1>(0.0);
+ if (m <= 0_kg) {
+ throw std::domain_error("m must be greater than zero.");
+ }
+ if (r <= 0_m) {
+ throw std::domain_error("r must be greater than zero.");
+ }
+ if (G <= 0.0) {
+ throw std::domain_error("G must be greater than zero.");
+ }
+
+ Eigen::Matrix<double, 2, 2> A{
+ {0.0, 1.0},
+ {0.0, (-std::pow(G, 2) * motor.Kt /
+ (motor.R * units::math::pow<2>(r) * m * motor.Kv))
+ .value()}};
+ Eigen::Matrix<double, 2, 1> B{0.0,
+ (G * motor.Kt / (motor.R * r * m)).value()};
+ Eigen::Matrix<double, 1, 2> C{1.0, 0.0};
+ Eigen::Matrix<double, 1, 1> D{0.0};
return LinearSystem<2, 1, 1>(A, B, C, D);
}
@@ -67,16 +79,23 @@
* @param motor Instance of DCMotor.
* @param J Moment of inertia.
* @param G Gear ratio from motor to carriage.
+ * @throws std::domain_error if J <= 0 or G <= 0.
*/
static LinearSystem<2, 1, 1> SingleJointedArmSystem(
DCMotor motor, units::kilogram_square_meter_t J, double G) {
- auto A = frc::MakeMatrix<2, 2>(
- 0.0, 1.0, 0.0,
- (-std::pow(G, 2) * motor.Kt / (motor.Kv * motor.R * J)).to<double>());
- auto B =
- frc::MakeMatrix<2, 1>(0.0, (G * motor.Kt / (motor.R * J)).to<double>());
- auto C = frc::MakeMatrix<1, 2>(1.0, 0.0);
- auto D = frc::MakeMatrix<1, 1>(0.0);
+ if (J <= 0_kg_sq_m) {
+ throw std::domain_error("J must be greater than zero.");
+ }
+ if (G <= 0.0) {
+ throw std::domain_error("G must be greater than zero.");
+ }
+
+ Eigen::Matrix<double, 2, 2> A{
+ {0.0, 1.0},
+ {0.0, (-std::pow(G, 2) * motor.Kt / (motor.Kv * motor.R * J)).value()}};
+ Eigen::Matrix<double, 2, 1> B{0.0, (G * motor.Kt / (motor.R * J)).value()};
+ Eigen::Matrix<double, 1, 2> C{1.0, 0.0};
+ Eigen::Matrix<double, 1, 1> D{0.0};
return LinearSystem<2, 1, 1>(A, B, C, D);
}
@@ -100,6 +119,7 @@
*
* @param kV The velocity gain, in volt seconds per distance.
* @param kA The acceleration gain, in volt seconds^2 per distance.
+ * @throws std::domain_error if kV <= 0 or kA <= 0.
*/
template <typename Distance, typename = std::enable_if_t<
std::is_same_v<units::meter, Distance> ||
@@ -107,11 +127,17 @@
static LinearSystem<1, 1, 1> IdentifyVelocitySystem(
decltype(1_V / Velocity_t<Distance>(1)) kV,
decltype(1_V / Acceleration_t<Distance>(1)) kA) {
- auto A = frc::MakeMatrix<1, 1>(-kV.template to<double>() /
- kA.template to<double>());
- auto B = frc::MakeMatrix<1, 1>(1.0 / kA.template to<double>());
- auto C = frc::MakeMatrix<1, 1>(1.0);
- auto D = frc::MakeMatrix<1, 1>(0.0);
+ if (kV <= decltype(kV){0}) {
+ throw std::domain_error("Kv must be greater than zero.");
+ }
+ if (kA <= decltype(kA){0}) {
+ throw std::domain_error("Ka must be greater than zero.");
+ }
+
+ Eigen::Matrix<double, 1, 1> A{-kV.value() / kA.value()};
+ Eigen::Matrix<double, 1, 1> B{1.0 / kA.value()};
+ Eigen::Matrix<double, 1, 1> C{1.0};
+ Eigen::Matrix<double, 1, 1> D{0.0};
return LinearSystem<1, 1, 1>(A, B, C, D);
}
@@ -135,6 +161,7 @@
*
* @param kV The velocity gain, in volt seconds per distance.
* @param kA The acceleration gain, in volt seconds^2 per distance.
+ * @throws std::domain_error if kV <= 0 or kA <= 0.
*/
template <typename Distance, typename = std::enable_if_t<
std::is_same_v<units::meter, Distance> ||
@@ -142,11 +169,17 @@
static LinearSystem<2, 1, 1> IdentifyPositionSystem(
decltype(1_V / Velocity_t<Distance>(1)) kV,
decltype(1_V / Acceleration_t<Distance>(1)) kA) {
- auto A = frc::MakeMatrix<2, 2>(
- 0.0, 1.0, 0.0, -kV.template to<double>() / kA.template to<double>());
- auto B = frc::MakeMatrix<2, 1>(0.0, 1.0 / kA.template to<double>());
- auto C = frc::MakeMatrix<1, 2>(1.0, 0.0);
- auto D = frc::MakeMatrix<1, 1>(0.0);
+ if (kV <= decltype(kV){0}) {
+ throw std::domain_error("Kv must be greater than zero.");
+ }
+ if (kA <= decltype(kA){0}) {
+ throw std::domain_error("Ka must be greater than zero.");
+ }
+
+ Eigen::Matrix<double, 2, 2> A{{0.0, 1.0}, {0.0, -kV.value() / kA.value()}};
+ Eigen::Matrix<double, 2, 1> B{0.0, 1.0 / kA.value()};
+ Eigen::Matrix<double, 1, 2> C{1.0, 0.0};
+ Eigen::Matrix<double, 1, 1> D{0.0};
return LinearSystem<2, 1, 1>(A, B, C, D);
}
@@ -159,31 +192,101 @@
* Inputs: [[left voltage], [right voltage]]
* Outputs: [[left velocity], [right velocity]]
*
- * @param kVlinear The linear velocity gain, in volt seconds per distance.
- * @param kAlinear The linear acceleration gain, in volt seconds^2 per
- * distance.
- * @param kVangular The angular velocity gain, in volt seconds per angle.
- * @param kAangular The angular acceleration gain, in volt seconds^2 per
- * angle.
+ * @param kVlinear The linear velocity gain in volts per (meter per second).
+ * @param kAlinear The linear acceleration gain in volts per (meter per
+ * second squared).
+ * @param kVangular The angular velocity gain in volts per (meter per second).
+ * @param kAangular The angular acceleration gain in volts per (meter per
+ * second squared).
+ * @throws std::domain_error if kVlinear <= 0, kAlinear <= 0,
+ * kVangular <= 0, or kAangular <= 0.
+ */
+ static LinearSystem<2, 2, 2> IdentifyDrivetrainSystem(
+ decltype(1_V / 1_mps) kVlinear, decltype(1_V / 1_mps_sq) kAlinear,
+ decltype(1_V / 1_mps) kVangular, decltype(1_V / 1_mps_sq) kAangular) {
+ // All four gains must be strictly positive; the kA gains are also used as
+ // divisors below.
+ if (kVlinear <= decltype(kVlinear){0}) {
+ throw std::domain_error("Kv,linear must be greater than zero.");
+ }
+ if (kAlinear <= decltype(kAlinear){0}) {
+ throw std::domain_error("Ka,linear must be greater than zero.");
+ }
+ if (kVangular <= decltype(kVangular){0}) {
+ throw std::domain_error("Kv,angular must be greater than zero.");
+ }
+ if (kAangular <= decltype(kAangular){0}) {
+ throw std::domain_error("Ka,angular must be greater than zero.");
+ }
+
+ // A1/B1 add the linear and angular terms and A2/B2 subtract them. They
+ // become the diagonal and off-diagonal entries, respectively, of symmetric
+ // 2x2 matrices (scaled by 0.5) over the [left, right] velocity states.
+ double A1 = -(kVlinear.value() / kAlinear.value() +
+ kVangular.value() / kAangular.value());
+ double A2 = -(kVlinear.value() / kAlinear.value() -
+ kVangular.value() / kAangular.value());
+ double B1 = 1.0 / kAlinear.value() + 1.0 / kAangular.value();
+ double B2 = 1.0 / kAlinear.value() - 1.0 / kAangular.value();
+
+ Eigen::Matrix<double, 2, 2> A =
+ 0.5 * Eigen::Matrix<double, 2, 2>{{A1, A2}, {A2, A1}};
+ Eigen::Matrix<double, 2, 2> B =
+ 0.5 * Eigen::Matrix<double, 2, 2>{{B1, B2}, {B2, B1}};
+ // The outputs are the states themselves (C = I) with no feedthrough (D = 0).
+ Eigen::Matrix<double, 2, 2> C{{1.0, 0.0}, {0.0, 1.0}};
+ Eigen::Matrix<double, 2, 2> D{{0.0, 0.0}, {0.0, 0.0}};
+
+ return LinearSystem<2, 2, 2>(A, B, C, D);
+ }
+
+ /**
+ * Constructs the state-space model for a 2 DOF drivetrain velocity system
+ * from system identification data.
+ *
+ * States: [[left velocity], [right velocity]]
+ * Inputs: [[left voltage], [right voltage]]
+ * Outputs: [[left velocity], [right velocity]]
+ *
+ * @param kVlinear The linear velocity gain in volts per (meter per second).
+ * @param kAlinear The linear acceleration gain in volts per (meter per
+ * second squared).
+ * @param kVangular The angular velocity gain in volts per (radian per
+ * second).
+ * @param kAangular The angular acceleration gain in volts per (radian per
+ * second squared).
+ * @param trackwidth The distance between the left and right wheels.
+ * @throws std::domain_error if kVlinear <= 0, kAlinear <= 0,
+ * kVangular <= 0, kAangular <= 0, or trackwidth <= 0.
*/
static LinearSystem<2, 2, 2> IdentifyDrivetrainSystem(
decltype(1_V / 1_mps) kVlinear, decltype(1_V / 1_mps_sq) kAlinear,
decltype(1_V / 1_rad_per_s) kVangular,
- decltype(1_V / 1_rad_per_s_sq) kAangular) {
- double c = 0.5 / (kAlinear.to<double>() * kAangular.to<double>());
- double A1 = c * (-kAlinear.to<double>() * kVangular.to<double>() -
- kVlinear.to<double>() * kAangular.to<double>());
- double A2 = c * (kAlinear.to<double>() * kVangular.to<double>() -
- kVlinear.to<double>() * kAangular.to<double>());
- double B1 = c * (kAlinear.to<double>() + kAangular.to<double>());
- double B2 = c * (kAangular.to<double>() - kAlinear.to<double>());
+ decltype(1_V / 1_rad_per_s_sq) kAangular, units::meter_t trackwidth) {
+ if (kVlinear <= decltype(kVlinear){0}) {
+ throw std::domain_error("Kv,linear must be greater than zero.");
+ }
+ if (kAlinear <= decltype(kAlinear){0}) {
+ throw std::domain_error("Ka,linear must be greater than zero.");
+ }
+ if (kVangular <= decltype(kVangular){0}) {
+ throw std::domain_error("Kv,angular must be greater than zero.");
+ }
+ if (kAangular <= decltype(kAangular){0}) {
+ throw std::domain_error("Ka,angular must be greater than zero.");
+ }
+ if (trackwidth <= 0_m) {
+ throw std::domain_error("trackwidth must be greater than zero.");
+ }
- auto A = frc::MakeMatrix<2, 2>(A1, A2, A2, A1);
- auto B = frc::MakeMatrix<2, 2>(B1, B2, B2, B1);
- auto C = frc::MakeMatrix<2, 2>(1.0, 0.0, 0.0, 1.0);
- auto D = frc::MakeMatrix<2, 2>(0.0, 0.0, 0.0, 0.0);
-
- return LinearSystem<2, 2, 2>(A, B, C, D);
+ // We want to find a factor to include in Kv,angular that will convert
+ // `u = Kv,angular omega` to `u = Kv,angular v`.
+ //
+ // v = omega r
+ // omega = v/r
+ // omega = 1/r v
+ // omega = 1/(trackwidth/2) v
+ // omega = 2/trackwidth v
+ //
+ // So multiplying by 2/trackwidth converts the angular gains from V/(rad/s)
+ // to V/(m/s); the trailing `* 1_rad` cancels the leftover radian unit.
+ return IdentifyDrivetrainSystem(kVlinear, kAlinear,
+ kVangular * 2.0 / trackwidth * 1_rad,
+ kAangular * 2.0 / trackwidth * 1_rad);
}
/**
@@ -196,15 +299,23 @@
* @param motor Instance of DCMotor.
* @param J Moment of inertia.
* @param G Gear ratio from motor to carriage.
+ * @throws std::domain_error if J <= 0 or G <= 0.
*/
static LinearSystem<1, 1, 1> FlywheelSystem(DCMotor motor,
units::kilogram_square_meter_t J,
double G) {
- auto A = frc::MakeMatrix<1, 1>(
- (-std::pow(G, 2) * motor.Kt / (motor.Kv * motor.R * J)).to<double>());
- auto B = frc::MakeMatrix<1, 1>((G * motor.Kt / (motor.R * J)).to<double>());
- auto C = frc::MakeMatrix<1, 1>(1.0);
- auto D = frc::MakeMatrix<1, 1>(0.0);
+ // Reject non-positive inertia and gearing; J is also a divisor below.
+ if (J <= 0_kg_sq_m) {
+ throw std::domain_error("J must be greater than zero.");
+ }
+ if (G <= 0.0) {
+ throw std::domain_error("G must be greater than zero.");
+ }
+
+ // A = -G^2 Kt / (Kv R J) and B = G Kt / (R J); .value() drops the units so
+ // the scalars can populate the Eigen matrices.
+ Eigen::Matrix<double, 1, 1> A{
+ (-std::pow(G, 2) * motor.Kt / (motor.Kv * motor.R * J)).value()};
+ Eigen::Matrix<double, 1, 1> B{(G * motor.Kt / (motor.R * J)).value()};
+ // The output is the state itself (C = 1) with no feedthrough (D = 0).
+ Eigen::Matrix<double, 1, 1> C{1.0};
+ Eigen::Matrix<double, 1, 1> D{0.0};
return LinearSystem<1, 1, 1>(A, B, C, D);
}
@@ -220,28 +331,46 @@
* @param m Drivetrain mass.
* @param r Wheel radius.
* @param rb Robot radius.
- * @param G Gear ratio from motor to wheel.
* @param J Moment of inertia.
+ * @param G Gear ratio from motor to wheel.
+ * @throws std::domain_error if m <= 0, r <= 0, rb <= 0, J <= 0, or
+ * G <= 0.
*/
static LinearSystem<2, 2, 2> DrivetrainVelocitySystem(
const DCMotor& motor, units::kilogram_t m, units::meter_t r,
units::meter_t rb, units::kilogram_square_meter_t J, double G) {
+ // Every physical parameter must be strictly positive; m, r, and J are also
+ // divisors below.
+ if (m <= 0_kg) {
+ throw std::domain_error("m must be greater than zero.");
+ }
+ if (r <= 0_m) {
+ throw std::domain_error("r must be greater than zero.");
+ }
+ if (rb <= 0_m) {
+ throw std::domain_error("rb must be greater than zero.");
+ }
+ if (J <= 0_kg_sq_m) {
+ throw std::domain_error("J must be greater than zero.");
+ }
+ if (G <= 0.0) {
+ throw std::domain_error("G must be greater than zero.");
+ }
+
auto C1 = -std::pow(G, 2) * motor.Kt /
(motor.Kv * motor.R * units::math::pow<2>(r));
auto C2 = G * motor.Kt / (motor.R * r);
- auto A = frc::MakeMatrix<2, 2>(
- ((1 / m + units::math::pow<2>(rb) / J) * C1).to<double>(),
- ((1 / m - units::math::pow<2>(rb) / J) * C1).to<double>(),
- ((1 / m - units::math::pow<2>(rb) / J) * C1).to<double>(),
- ((1 / m + units::math::pow<2>(rb) / J) * C1).to<double>());
- auto B = frc::MakeMatrix<2, 2>(
- ((1 / m + units::math::pow<2>(rb) / J) * C2).to<double>(),
- ((1 / m - units::math::pow<2>(rb) / J) * C2).to<double>(),
- ((1 / m - units::math::pow<2>(rb) / J) * C2).to<double>(),
- ((1 / m + units::math::pow<2>(rb) / J) * C2).to<double>());
- auto C = frc::MakeMatrix<2, 2>(1.0, 0.0, 0.0, 1.0);
- auto D = frc::MakeMatrix<2, 2>(0.0, 0.0, 0.0, 0.0);
+ // Diagonal entries scale C1 (for A) or C2 (for B) by (1/m + rb^2/J);
+ // off-diagonal entries scale them by (1/m - rb^2/J).
+ Eigen::Matrix<double, 2, 2> A{
+ {((1 / m + units::math::pow<2>(rb) / J) * C1).value(),
+ ((1 / m - units::math::pow<2>(rb) / J) * C1).value()},
+ {((1 / m - units::math::pow<2>(rb) / J) * C1).value(),
+ ((1 / m + units::math::pow<2>(rb) / J) * C1).value()}};
+ Eigen::Matrix<double, 2, 2> B{
+ {((1 / m + units::math::pow<2>(rb) / J) * C2).value(),
+ ((1 / m - units::math::pow<2>(rb) / J) * C2).value()},
+ {((1 / m - units::math::pow<2>(rb) / J) * C2).value(),
+ ((1 / m + units::math::pow<2>(rb) / J) * C2).value()}};
+ // The outputs are the states themselves (C = I) with no feedthrough (D = 0).
+ Eigen::Matrix<double, 2, 2> C{{1.0, 0.0}, {0.0, 1.0}};
+ Eigen::Matrix<double, 2, 2> D{{0.0, 0.0}, {0.0, 0.0}};
return LinearSystem<2, 2, 2>(A, B, C, D);
}
diff --git a/wpimath/src/main/native/include/frc/trajectory/Trajectory.h b/wpimath/src/main/native/include/frc/trajectory/Trajectory.h
index 023b2c2..2fad345 100644
--- a/wpimath/src/main/native/include/frc/trajectory/Trajectory.h
+++ b/wpimath/src/main/native/include/frc/trajectory/Trajectory.h
@@ -1,14 +1,13 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
#include <vector>
+#include <wpi/SymbolExports.h>
+
#include "frc/geometry/Pose2d.h"
#include "frc/geometry/Transform2d.h"
#include "units/acceleration.h"
@@ -26,12 +25,12 @@
* various States that represent the pose, curvature, time elapsed, velocity,
* and acceleration at that point.
*/
-class Trajectory {
+class WPILIB_DLLEXPORT Trajectory {
public:
/**
* Represents one point on the trajectory.
*/
- struct State {
+ struct WPILIB_DLLEXPORT State {
// The time elapsed since the beginning of the trajectory.
units::second_t t = 0_s;
@@ -123,6 +122,16 @@
Trajectory RelativeTo(const Pose2d& pose);
/**
+ * Concatenates another trajectory to the current trajectory. The user is
+ * responsible for making sure that the end pose of this trajectory and the
+ * start pose of the other trajectory match (if that is the desired behavior).
+ *
+ * @param other The trajectory to concatenate.
+ * @return The concatenated trajectory.
+ */
+ Trajectory operator+(const Trajectory& other) const;
+
+ /**
* Returns the initial pose of the trajectory.
*
* @return The initial pose of the trajectory.
@@ -164,8 +173,10 @@
}
};
+WPILIB_DLLEXPORT
void to_json(wpi::json& json, const Trajectory::State& state);
+WPILIB_DLLEXPORT
void from_json(const wpi::json& json, Trajectory::State& state);
} // namespace frc
diff --git a/wpimath/src/main/native/include/frc/trajectory/TrajectoryConfig.h b/wpimath/src/main/native/include/frc/trajectory/TrajectoryConfig.h
index 5bd7977..b1a0b52 100644
--- a/wpimath/src/main/native/include/frc/trajectory/TrajectoryConfig.h
+++ b/wpimath/src/main/native/include/frc/trajectory/TrajectoryConfig.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
@@ -11,6 +8,8 @@
#include <utility>
#include <vector>
+#include <wpi/SymbolExports.h>
+
#include "frc/kinematics/DifferentialDriveKinematics.h"
#include "frc/kinematics/MecanumDriveKinematics.h"
#include "frc/kinematics/SwerveDriveKinematics.h"
@@ -32,7 +31,7 @@
* have been defaulted to reasonable values (0, 0, {}, false). These values can
* be changed via the SetXXX methods.
*/
-class TrajectoryConfig {
+class WPILIB_DLLEXPORT TrajectoryConfig {
public:
/**
* Constructs a config object.
diff --git a/wpimath/src/main/native/include/frc/trajectory/TrajectoryGenerator.h b/wpimath/src/main/native/include/frc/trajectory/TrajectoryGenerator.h
index f1747fd..84ec0e0 100644
--- a/wpimath/src/main/native/include/frc/trajectory/TrajectoryGenerator.h
+++ b/wpimath/src/main/native/include/frc/trajectory/TrajectoryGenerator.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
@@ -12,6 +9,8 @@
#include <utility>
#include <vector>
+#include <wpi/SymbolExports.h>
+
#include "frc/spline/SplineParameterizer.h"
#include "frc/trajectory/Trajectory.h"
#include "frc/trajectory/TrajectoryConfig.h"
@@ -22,7 +21,7 @@
/**
* Helper class used to generate trajectories with various constraints.
*/
-class TrajectoryGenerator {
+class WPILIB_DLLEXPORT TrajectoryGenerator {
public:
using PoseWithCurvature = std::pair<Pose2d, units::curvature_t>;
diff --git a/wpimath/src/main/native/include/frc/trajectory/TrajectoryParameterizer.h b/wpimath/src/main/native/include/frc/trajectory/TrajectoryParameterizer.h
index 378f007..eea1c07 100644
--- a/wpimath/src/main/native/include/frc/trajectory/TrajectoryParameterizer.h
+++ b/wpimath/src/main/native/include/frc/trajectory/TrajectoryParameterizer.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
/*
* MIT License
@@ -35,6 +32,8 @@
#include <utility>
#include <vector>
+#include <wpi/SymbolExports.h>
+
#include "frc/trajectory/Trajectory.h"
#include "frc/trajectory/constraint/TrajectoryConstraint.h"
@@ -42,7 +41,7 @@
/**
* Class used to parameterize a trajectory by time.
*/
-class TrajectoryParameterizer {
+class WPILIB_DLLEXPORT TrajectoryParameterizer {
public:
using PoseWithCurvature = std::pair<Pose2d, units::curvature_t>;
diff --git a/wpimath/src/main/native/include/frc/trajectory/TrajectoryUtil.h b/wpimath/src/main/native/include/frc/trajectory/TrajectoryUtil.h
index f05cd14..6e52a09 100644
--- a/wpimath/src/main/native/include/frc/trajectory/TrajectoryUtil.h
+++ b/wpimath/src/main/native/include/frc/trajectory/TrajectoryUtil.h
@@ -1,21 +1,18 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
#include <string>
+#include <string_view>
-#include <wpi/StringRef.h>
-#include <wpi/Twine.h>
+#include <wpi/SymbolExports.h>
#include "frc/trajectory/Trajectory.h"
namespace frc {
-class TrajectoryUtil {
+class WPILIB_DLLEXPORT TrajectoryUtil {
public:
TrajectoryUtil() = delete;
@@ -24,11 +21,9 @@
*
* @param trajectory the trajectory to export
* @param path the path of the file to export to
- *
- * @return The interpolated state.
*/
static void ToPathweaverJson(const Trajectory& trajectory,
- const wpi::Twine& path);
+ std::string_view path);
/**
* Imports a Trajectory from a PathWeaver-style JSON file.
*
@@ -36,24 +31,24 @@
*
* @return The trajectory represented by the file.
*/
- static Trajectory FromPathweaverJson(const wpi::Twine& path);
+ static Trajectory FromPathweaverJson(std::string_view path);
/**
- * Deserializes a Trajectory from PathWeaver-style JSON.
-
- * @param json the string containing the serialized JSON
-
- * @return the trajectory represented by the JSON
- */
+ * Serializes a Trajectory to PathWeaver-style JSON.
+ *
+ * @param trajectory the trajectory to serialize
+ *
+ * @return the string containing the serialized JSON
+ */
static std::string SerializeTrajectory(const Trajectory& trajectory);
/**
- * Serializes a Trajectory to PathWeaver-style JSON.
+ * Deserializes a Trajectory from PathWeaver-style JSON.
-
- * @param trajectory the trajectory to export
-
- * @return the string containing the serialized JSON
+ *
+ * @param jsonStr the string containing the serialized JSON
+ *
+ * @return the trajectory represented by the JSON
*/
- static Trajectory DeserializeTrajectory(wpi::StringRef json_str);
+ static Trajectory DeserializeTrajectory(std::string_view jsonStr);
};
} // namespace frc
diff --git a/wpimath/src/main/native/include/frc/trajectory/TrapezoidProfile.h b/wpimath/src/main/native/include/frc/trajectory/TrapezoidProfile.h
index 5f9f1b8..0e623eb 100644
--- a/wpimath/src/main/native/include/frc/trajectory/TrapezoidProfile.h
+++ b/wpimath/src/main/native/include/frc/trajectory/TrapezoidProfile.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
diff --git a/wpimath/src/main/native/include/frc/trajectory/TrapezoidProfile.inc b/wpimath/src/main/native/include/frc/trajectory/TrapezoidProfile.inc
index 8718cc0..47a598e 100644
--- a/wpimath/src/main/native/include/frc/trajectory/TrapezoidProfile.inc
+++ b/wpimath/src/main/native/include/frc/trajectory/TrapezoidProfile.inc
@@ -1,14 +1,12 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
#include <algorithm>
+#include "frc/trajectory/TrapezoidProfile.h"
#include "units/math.h"
namespace frc {
diff --git a/wpimath/src/main/native/include/frc/trajectory/constraint/CentripetalAccelerationConstraint.h b/wpimath/src/main/native/include/frc/trajectory/constraint/CentripetalAccelerationConstraint.h
index 4f897ba..5ef15a4 100644
--- a/wpimath/src/main/native/include/frc/trajectory/constraint/CentripetalAccelerationConstraint.h
+++ b/wpimath/src/main/native/include/frc/trajectory/constraint/CentripetalAccelerationConstraint.h
@@ -1,12 +1,11 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
+#include <wpi/SymbolExports.h>
+
#include "frc/trajectory/constraint/TrajectoryConstraint.h"
#include "units/acceleration.h"
#include "units/curvature.h"
@@ -23,7 +22,8 @@
* robot to slow down around tight turns, making it easier to track trajectories
* with sharp turns.
*/
-class CentripetalAccelerationConstraint : public TrajectoryConstraint {
+class WPILIB_DLLEXPORT CentripetalAccelerationConstraint
+ : public TrajectoryConstraint {
public:
explicit CentripetalAccelerationConstraint(
units::meters_per_second_squared_t maxCentripetalAcceleration);
diff --git a/wpimath/src/main/native/include/frc/trajectory/constraint/DifferentialDriveKinematicsConstraint.h b/wpimath/src/main/native/include/frc/trajectory/constraint/DifferentialDriveKinematicsConstraint.h
index f23c1e2..ad643bf 100644
--- a/wpimath/src/main/native/include/frc/trajectory/constraint/DifferentialDriveKinematicsConstraint.h
+++ b/wpimath/src/main/native/include/frc/trajectory/constraint/DifferentialDriveKinematicsConstraint.h
@@ -1,12 +1,11 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
+#include <wpi/SymbolExports.h>
+
#include "frc/kinematics/DifferentialDriveKinematics.h"
#include "frc/trajectory/constraint/TrajectoryConstraint.h"
#include "units/velocity.h"
@@ -18,7 +17,8 @@
* commanded velocities for both sides of the drivetrain stay below a certain
* limit.
*/
-class DifferentialDriveKinematicsConstraint : public TrajectoryConstraint {
+class WPILIB_DLLEXPORT DifferentialDriveKinematicsConstraint
+ : public TrajectoryConstraint {
public:
DifferentialDriveKinematicsConstraint(
const DifferentialDriveKinematics& kinematics,
diff --git a/wpimath/src/main/native/include/frc/trajectory/constraint/DifferentialDriveVoltageConstraint.h b/wpimath/src/main/native/include/frc/trajectory/constraint/DifferentialDriveVoltageConstraint.h
index 23c690d..06d0e50 100644
--- a/wpimath/src/main/native/include/frc/trajectory/constraint/DifferentialDriveVoltageConstraint.h
+++ b/wpimath/src/main/native/include/frc/trajectory/constraint/DifferentialDriveVoltageConstraint.h
@@ -1,12 +1,11 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
+#include <wpi/SymbolExports.h>
+
#include "frc/controller/SimpleMotorFeedforward.h"
#include "frc/kinematics/DifferentialDriveKinematics.h"
#include "frc/trajectory/constraint/TrajectoryConstraint.h"
@@ -20,7 +19,8 @@
* acceleration of any wheel of the robot while following the trajectory is
* never higher than what can be achieved with the given maximum voltage.
*/
-class DifferentialDriveVoltageConstraint : public TrajectoryConstraint {
+class WPILIB_DLLEXPORT DifferentialDriveVoltageConstraint
+ : public TrajectoryConstraint {
public:
/**
* Creates a new DifferentialDriveVoltageConstraint.
diff --git a/wpimath/src/main/native/include/frc/trajectory/constraint/EllipticalRegionConstraint.h b/wpimath/src/main/native/include/frc/trajectory/constraint/EllipticalRegionConstraint.h
index 78bc569..e2ef37b 100644
--- a/wpimath/src/main/native/include/frc/trajectory/constraint/EllipticalRegionConstraint.h
+++ b/wpimath/src/main/native/include/frc/trajectory/constraint/EllipticalRegionConstraint.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
diff --git a/wpimath/src/main/native/include/frc/trajectory/constraint/MaxVelocityConstraint.h b/wpimath/src/main/native/include/frc/trajectory/constraint/MaxVelocityConstraint.h
index 7a30881..b7375d5 100644
--- a/wpimath/src/main/native/include/frc/trajectory/constraint/MaxVelocityConstraint.h
+++ b/wpimath/src/main/native/include/frc/trajectory/constraint/MaxVelocityConstraint.h
@@ -1,12 +1,11 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
+#include <wpi/SymbolExports.h>
+
#include "frc/trajectory/constraint/TrajectoryConstraint.h"
#include "units/math.h"
#include "units/velocity.h"
@@ -17,26 +16,21 @@
* with the EllipticalRegionConstraint or RectangularRegionConstraint to enforce
* a max velocity within a region.
*/
-class MaxVelocityConstraint : public TrajectoryConstraint {
+class WPILIB_DLLEXPORT MaxVelocityConstraint : public TrajectoryConstraint {
public:
/**
* Constructs a new MaxVelocityConstraint.
*
* @param maxVelocity The max velocity.
*/
- explicit MaxVelocityConstraint(units::meters_per_second_t maxVelocity)
- : m_maxVelocity(units::math::abs(maxVelocity)) {}
+ explicit MaxVelocityConstraint(units::meters_per_second_t maxVelocity);
units::meters_per_second_t MaxVelocity(
const Pose2d& pose, units::curvature_t curvature,
- units::meters_per_second_t velocity) const override {
- return m_maxVelocity;
- }
+ units::meters_per_second_t velocity) const override;
MinMax MinMaxAcceleration(const Pose2d& pose, units::curvature_t curvature,
- units::meters_per_second_t speed) const override {
- return {};
- }
+ units::meters_per_second_t speed) const override;
private:
units::meters_per_second_t m_maxVelocity;
diff --git a/wpimath/src/main/native/include/frc/trajectory/constraint/MecanumDriveKinematicsConstraint.h b/wpimath/src/main/native/include/frc/trajectory/constraint/MecanumDriveKinematicsConstraint.h
index 0166f56..816d8ef 100644
--- a/wpimath/src/main/native/include/frc/trajectory/constraint/MecanumDriveKinematicsConstraint.h
+++ b/wpimath/src/main/native/include/frc/trajectory/constraint/MecanumDriveKinematicsConstraint.h
@@ -1,14 +1,13 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
#include <cmath>
+#include <wpi/SymbolExports.h>
+
#include "frc/kinematics/MecanumDriveKinematics.h"
#include "frc/trajectory/constraint/TrajectoryConstraint.h"
#include "units/velocity.h"
@@ -20,7 +19,8 @@
* commanded velocities for wheels of the drivetrain stay below a certain
* limit.
*/
-class MecanumDriveKinematicsConstraint : public TrajectoryConstraint {
+class WPILIB_DLLEXPORT MecanumDriveKinematicsConstraint
+ : public TrajectoryConstraint {
public:
MecanumDriveKinematicsConstraint(const MecanumDriveKinematics& kinematics,
units::meters_per_second_t maxSpeed);
diff --git a/wpimath/src/main/native/include/frc/trajectory/constraint/RectangularRegionConstraint.h b/wpimath/src/main/native/include/frc/trajectory/constraint/RectangularRegionConstraint.h
index 203b237..c5bc559 100644
--- a/wpimath/src/main/native/include/frc/trajectory/constraint/RectangularRegionConstraint.h
+++ b/wpimath/src/main/native/include/frc/trajectory/constraint/RectangularRegionConstraint.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
diff --git a/wpimath/src/main/native/include/frc/trajectory/constraint/SwerveDriveKinematicsConstraint.h b/wpimath/src/main/native/include/frc/trajectory/constraint/SwerveDriveKinematicsConstraint.h
index 0f43e29..67e9fc9 100644
--- a/wpimath/src/main/native/include/frc/trajectory/constraint/SwerveDriveKinematicsConstraint.h
+++ b/wpimath/src/main/native/include/frc/trajectory/constraint/SwerveDriveKinematicsConstraint.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
@@ -15,13 +12,12 @@
namespace frc {
-template <size_t NumModules>
-
/**
* A class that enforces constraints on the swerve drive kinematics.
* This can be used to ensure that the trajectory is constructed so that the
* commanded velocities of the wheels stay below a certain limit.
*/
+template <size_t NumModules>
class SwerveDriveKinematicsConstraint : public TrajectoryConstraint {
public:
SwerveDriveKinematicsConstraint(
diff --git a/wpimath/src/main/native/include/frc/trajectory/constraint/SwerveDriveKinematicsConstraint.inc b/wpimath/src/main/native/include/frc/trajectory/constraint/SwerveDriveKinematicsConstraint.inc
index 1af8511..1a1e4b8 100644
--- a/wpimath/src/main/native/include/frc/trajectory/constraint/SwerveDriveKinematicsConstraint.inc
+++ b/wpimath/src/main/native/include/frc/trajectory/constraint/SwerveDriveKinematicsConstraint.inc
@@ -1,23 +1,15 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
+#include "frc/trajectory/constraint/SwerveDriveKinematicsConstraint.h"
#include "units/math.h"
namespace frc {
template <size_t NumModules>
-
-/**
- * A class that enforces constraints on the swerve drive kinematics.
- * This can be used to ensure that the trajectory is constructed so that the
- * commanded velocities of the wheels stay below a certain limit.
- */
SwerveDriveKinematicsConstraint<NumModules>::SwerveDriveKinematicsConstraint(
const frc::SwerveDriveKinematics<NumModules>& kinematics,
units::meters_per_second_t maxSpeed)
diff --git a/wpimath/src/main/native/include/frc/trajectory/constraint/TrajectoryConstraint.h b/wpimath/src/main/native/include/frc/trajectory/constraint/TrajectoryConstraint.h
index b5548c5..47ca820 100644
--- a/wpimath/src/main/native/include/frc/trajectory/constraint/TrajectoryConstraint.h
+++ b/wpimath/src/main/native/include/frc/trajectory/constraint/TrajectoryConstraint.h
@@ -1,14 +1,13 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
#include <limits>
+#include <wpi/SymbolExports.h>
+
#include "frc/geometry/Pose2d.h"
#include "frc/spline/Spline.h"
#include "units/acceleration.h"
@@ -20,7 +19,7 @@
* An interface for defining user-defined velocity and acceleration constraints
* while generating trajectories.
*/
-class TrajectoryConstraint {
+class WPILIB_DLLEXPORT TrajectoryConstraint {
public:
TrajectoryConstraint() = default;
diff --git a/wpimath/src/main/native/include/units/acceleration.h b/wpimath/src/main/native/include/units/acceleration.h
index 5427160..a0d12b0 100644
--- a/wpimath/src/main/native/include/units/acceleration.h
+++ b/wpimath/src/main/native/include/units/acceleration.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
// Copyright (c) 2016 Nic Holthaus
//
diff --git a/wpimath/src/main/native/include/units/angle.h b/wpimath/src/main/native/include/units/angle.h
index a0f802f..876bd60 100644
--- a/wpimath/src/main/native/include/units/angle.h
+++ b/wpimath/src/main/native/include/units/angle.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
// Copyright (c) 2016 Nic Holthaus
//
diff --git a/wpimath/src/main/native/include/units/angular_acceleration.h b/wpimath/src/main/native/include/units/angular_acceleration.h
index 4b1af0f..6a411c4 100644
--- a/wpimath/src/main/native/include/units/angular_acceleration.h
+++ b/wpimath/src/main/native/include/units/angular_acceleration.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
diff --git a/wpimath/src/main/native/include/units/angular_velocity.h b/wpimath/src/main/native/include/units/angular_velocity.h
index 580d021..16f39e1 100644
--- a/wpimath/src/main/native/include/units/angular_velocity.h
+++ b/wpimath/src/main/native/include/units/angular_velocity.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
// Copyright (c) 2016 Nic Holthaus
//
diff --git a/wpimath/src/main/native/include/units/area.h b/wpimath/src/main/native/include/units/area.h
index 1bdd3e3..e4d82d9 100644
--- a/wpimath/src/main/native/include/units/area.h
+++ b/wpimath/src/main/native/include/units/area.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
// Copyright (c) 2016 Nic Holthaus
//
diff --git a/wpimath/src/main/native/include/units/base.h b/wpimath/src/main/native/include/units/base.h
index 579ec88..f2d45cf 100644
--- a/wpimath/src/main/native/include/units/base.h
+++ b/wpimath/src/main/native/include/units/base.h
@@ -74,36 +74,42 @@
#include <cmath>
#include <limits>
-#if !defined(UNIT_LIB_DISABLE_IOSTREAM)
+#if defined(UNIT_LIB_ENABLE_IOSTREAM)
#include <iostream>
- #include <string>
#include <locale>
-
- //------------------------------
- // STRING FORMATTER
- //------------------------------
-
- namespace units
- {
- namespace detail
- {
- template <typename T> std::string to_string(const T& t)
- {
- std::string str{ std::to_string(t) };
- int offset{ 1 };
-
- // remove trailing decimal points for integer value units. Locale aware!
- struct lconv * lc;
- lc = localeconv();
- char decimalPoint = *lc->decimal_point;
- if (str.find_last_not_of('0') == str.find(decimalPoint)) { offset = 0; }
- str.erase(str.find_last_not_of('0') + offset, std::string::npos);
- return str;
- }
- }
- }
+ #include <string>
+#else
+ #include <locale>
+ #include <string>
+ #include <fmt/format.h>
#endif
+#include <wpi/SymbolExports.h>
+
+//------------------------------
+// STRING FORMATTER
+//------------------------------
+
+namespace units
+{
+ namespace detail
+ {
+ template <typename T> std::string to_string(const T& t)
+ {
+ std::string str{ std::to_string(t) };
+ int offset{ 1 };
+
+ // remove trailing decimal points for integer value units. Locale aware!
+ struct lconv * lc;
+ lc = localeconv();
+ char decimalPoint = *lc->decimal_point;
+ if (str.find_last_not_of('0') == str.find(decimalPoint)) { offset = 0; }
+ str.erase(str.find_last_not_of('0') + offset, std::string::npos);
+ return str;
+ }
+ }
+}
+
namespace units
{
template<typename T> inline constexpr const char* name(const T&);
@@ -172,10 +178,33 @@
* @param namespaceName namespace in which the new units will be encapsulated.
* @param nameSingular singular version of the unit name, e.g. 'meter'
* @param abbrev - abbreviated unit name, e.g. 'm'
- * @note When UNIT_LIB_DISABLE_IOSTREAM is defined, the macro does not generate any code
+ * @note When UNIT_LIB_ENABLE_IOSTREAM isn't defined, the macro does not generate any code
*/
-#if defined(UNIT_LIB_DISABLE_IOSTREAM)
- #define UNIT_ADD_IO(namespaceName, nameSingular, abbrev)
+#if !defined(UNIT_LIB_ENABLE_IOSTREAM)
+ #define UNIT_ADD_IO(namespaceName, nameSingular, abbrev)\
+ }\
+ template <>\
+ struct fmt::formatter<units::namespaceName::nameSingular ## _t> \
+ : fmt::formatter<double> \
+ {\
+ template <typename FormatContext>\
+ auto format(const units::namespaceName::nameSingular ## _t& obj,\
+ FormatContext& ctx) -> decltype(ctx.out()) \
+ {\
+ auto out = ctx.out();\
+ out = fmt::formatter<double>::format(obj(), ctx);\
+ return fmt::format_to(out, " " #abbrev);\
+ }\
+ };\
+ namespace units\
+ {\
+ namespace namespaceName\
+ {\
+ inline std::string to_string(const nameSingular ## _t& obj)\
+ {\
+ return units::detail::to_string(obj()) + std::string(" "#abbrev);\
+ }\
+ }
#else
#define UNIT_ADD_IO(namespaceName, nameSingular, abbrev)\
namespace namespaceName\
@@ -2180,7 +2209,7 @@
return UnitType(value);
}
-#if !defined(UNIT_LIB_DISABLE_IOSTREAM)
+#if defined(UNIT_LIB_ENABLE_IOSTREAM)
template<class Units, typename T, template<typename> class NonLinearScale>
inline std::ostream& operator<<(std::ostream& os, const unit_t<Units, T, NonLinearScale>& obj) noexcept
{
@@ -2815,11 +2844,31 @@
namespace dimensionless
{
typedef unit_t<scalar, UNIT_LIB_DEFAULT_TYPE, decibel_scale> dB_t;
-#if !defined(UNIT_LIB_DISABLE_IOSTREAM)
+#if defined(UNIT_LIB_ENABLE_IOSTREAM)
inline std::ostream& operator<<(std::ostream& os, const dB_t& obj) { os << obj() << " dB"; return os; }
-#endif
typedef dB_t dBi_t;
}
+#else
+}
+}
+template <>
+struct fmt::formatter<units::dimensionless::dB_t> : fmt::formatter<double>
+{
+ template <typename FormatContext>
+ auto format(const units::dimensionless::dB_t& obj,
+ FormatContext& ctx) -> decltype(ctx.out())
+ {
+ auto out = ctx.out();
+ out = fmt::formatter<double>::format(obj(), ctx);
+ return fmt::format_to(out, " dB");
+ }
+};
+
+namespace units {
+namespace dimensionless {
+ typedef dB_t dBi_t;
+ }
+#endif
//------------------------------
// DECIBEL ARITHMETIC
@@ -3365,3 +3414,5 @@
namespace units::literals {}
using namespace units::literals;
#endif // UNIT_HAS_LITERAL_SUPPORT
+
+#include "frc/fmt/Units.h"
diff --git a/wpimath/src/main/native/include/units/capacitance.h b/wpimath/src/main/native/include/units/capacitance.h
index feaf88c..e9e22f6 100644
--- a/wpimath/src/main/native/include/units/capacitance.h
+++ b/wpimath/src/main/native/include/units/capacitance.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
// Copyright (c) 2016 Nic Holthaus
//
diff --git a/wpimath/src/main/native/include/units/charge.h b/wpimath/src/main/native/include/units/charge.h
index 42064b0..841f3a4 100644
--- a/wpimath/src/main/native/include/units/charge.h
+++ b/wpimath/src/main/native/include/units/charge.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
// Copyright (c) 2016 Nic Holthaus
//
diff --git a/wpimath/src/main/native/include/units/concentration.h b/wpimath/src/main/native/include/units/concentration.h
index b276f82..3128ff6 100644
--- a/wpimath/src/main/native/include/units/concentration.h
+++ b/wpimath/src/main/native/include/units/concentration.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
// Copyright (c) 2016 Nic Holthaus
//
diff --git a/wpimath/src/main/native/include/units/conductance.h b/wpimath/src/main/native/include/units/conductance.h
index d0508c4..d2abff1 100644
--- a/wpimath/src/main/native/include/units/conductance.h
+++ b/wpimath/src/main/native/include/units/conductance.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
// Copyright (c) 2016 Nic Holthaus
//
diff --git a/wpimath/src/main/native/include/units/constants.h b/wpimath/src/main/native/include/units/constants.h
index 7d5c49f..efadaf7 100644
--- a/wpimath/src/main/native/include/units/constants.h
+++ b/wpimath/src/main/native/include/units/constants.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
// Copyright (c) 2016 Nic Holthaus
//
@@ -46,13 +43,7 @@
#include "units/time.h"
#include "units/velocity.h"
-namespace units {
-/**
- * @brief namespace for physical constants like PI and Avogadro's Number.
- * @sa See unit_t for more information on unit type containers.
- */
-#if !defined(DISABLE_PREDEFINED_UNITS)
-namespace constants {
+namespace units::constants {
/**
* @name Unit Containers
* @anchor constantContainers
@@ -108,6 +99,4 @@
(15 * math::cpow<3>(h) * math::cpow<2>(c) *
math::cpow<4>(N_A))); ///< Stefan-Boltzmann constant.
/** @} */
-} // namespace constants
-#endif
-} // namespace units
+} // namespace units::constants
diff --git a/wpimath/src/main/native/include/units/current.h b/wpimath/src/main/native/include/units/current.h
index 54a408c..b187bb3 100644
--- a/wpimath/src/main/native/include/units/current.h
+++ b/wpimath/src/main/native/include/units/current.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
// Copyright (c) 2016 Nic Holthaus
//
diff --git a/wpimath/src/main/native/include/units/curvature.h b/wpimath/src/main/native/include/units/curvature.h
index 233ad61..062b09a 100644
--- a/wpimath/src/main/native/include/units/curvature.h
+++ b/wpimath/src/main/native/include/units/curvature.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
diff --git a/wpimath/src/main/native/include/units/data.h b/wpimath/src/main/native/include/units/data.h
index 386c0c2..90691a2 100644
--- a/wpimath/src/main/native/include/units/data.h
+++ b/wpimath/src/main/native/include/units/data.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
// Copyright (c) 2016 Nic Holthaus
//
diff --git a/wpimath/src/main/native/include/units/data_transfer_rate.h b/wpimath/src/main/native/include/units/data_transfer_rate.h
index 67de063..29fb028 100644
--- a/wpimath/src/main/native/include/units/data_transfer_rate.h
+++ b/wpimath/src/main/native/include/units/data_transfer_rate.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
// Copyright (c) 2016 Nic Holthaus
//
diff --git a/wpimath/src/main/native/include/units/density.h b/wpimath/src/main/native/include/units/density.h
index 2509f49..8616517 100644
--- a/wpimath/src/main/native/include/units/density.h
+++ b/wpimath/src/main/native/include/units/density.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
// Copyright (c) 2016 Nic Holthaus
//
diff --git a/wpimath/src/main/native/include/units/dimensionless.h b/wpimath/src/main/native/include/units/dimensionless.h
index 64f75ba..26c118b 100644
--- a/wpimath/src/main/native/include/units/dimensionless.h
+++ b/wpimath/src/main/native/include/units/dimensionless.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
// Copyright (c) 2016 Nic Holthaus
//
diff --git a/wpimath/src/main/native/include/units/energy.h b/wpimath/src/main/native/include/units/energy.h
index c206e5d..36996e2 100644
--- a/wpimath/src/main/native/include/units/energy.h
+++ b/wpimath/src/main/native/include/units/energy.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
// Copyright (c) 2016 Nic Holthaus
//
diff --git a/wpimath/src/main/native/include/units/force.h b/wpimath/src/main/native/include/units/force.h
index 9813958..2c2769f 100644
--- a/wpimath/src/main/native/include/units/force.h
+++ b/wpimath/src/main/native/include/units/force.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
// Copyright (c) 2016 Nic Holthaus
//
diff --git a/wpimath/src/main/native/include/units/frequency.h b/wpimath/src/main/native/include/units/frequency.h
index f030329..f1795d5 100644
--- a/wpimath/src/main/native/include/units/frequency.h
+++ b/wpimath/src/main/native/include/units/frequency.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
// Copyright (c) 2016 Nic Holthaus
//
diff --git a/wpimath/src/main/native/include/units/illuminance.h b/wpimath/src/main/native/include/units/illuminance.h
index 976f6b5..f653ec6 100644
--- a/wpimath/src/main/native/include/units/illuminance.h
+++ b/wpimath/src/main/native/include/units/illuminance.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
// Copyright (c) 2016 Nic Holthaus
//
diff --git a/wpimath/src/main/native/include/units/impedance.h b/wpimath/src/main/native/include/units/impedance.h
index b4b92ad..abe0375 100644
--- a/wpimath/src/main/native/include/units/impedance.h
+++ b/wpimath/src/main/native/include/units/impedance.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
// Copyright (c) 2016 Nic Holthaus
//
diff --git a/wpimath/src/main/native/include/units/inductance.h b/wpimath/src/main/native/include/units/inductance.h
index 6a5be7f..1e6d9f6 100644
--- a/wpimath/src/main/native/include/units/inductance.h
+++ b/wpimath/src/main/native/include/units/inductance.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
// Copyright (c) 2016 Nic Holthaus
//
diff --git a/wpimath/src/main/native/include/units/length.h b/wpimath/src/main/native/include/units/length.h
index 637797d..8b75c7c 100644
--- a/wpimath/src/main/native/include/units/length.h
+++ b/wpimath/src/main/native/include/units/length.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
// Copyright (c) 2016 Nic Holthaus
//
diff --git a/wpimath/src/main/native/include/units/luminous_flux.h b/wpimath/src/main/native/include/units/luminous_flux.h
index ca7a079..31ca391 100644
--- a/wpimath/src/main/native/include/units/luminous_flux.h
+++ b/wpimath/src/main/native/include/units/luminous_flux.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
// Copyright (c) 2016 Nic Holthaus
//
diff --git a/wpimath/src/main/native/include/units/luminous_intensity.h b/wpimath/src/main/native/include/units/luminous_intensity.h
index f907d2e..7d48dfe 100644
--- a/wpimath/src/main/native/include/units/luminous_intensity.h
+++ b/wpimath/src/main/native/include/units/luminous_intensity.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
// Copyright (c) 2016 Nic Holthaus
//
diff --git a/wpimath/src/main/native/include/units/magnetic_field_strength.h b/wpimath/src/main/native/include/units/magnetic_field_strength.h
index e7a7086..5e953e9 100644
--- a/wpimath/src/main/native/include/units/magnetic_field_strength.h
+++ b/wpimath/src/main/native/include/units/magnetic_field_strength.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
// Copyright (c) 2016 Nic Holthaus
//
diff --git a/wpimath/src/main/native/include/units/magnetic_flux.h b/wpimath/src/main/native/include/units/magnetic_flux.h
index 739b9e7..6516172 100644
--- a/wpimath/src/main/native/include/units/magnetic_flux.h
+++ b/wpimath/src/main/native/include/units/magnetic_flux.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
// Copyright (c) 2016 Nic Holthaus
//
diff --git a/wpimath/src/main/native/include/units/mass.h b/wpimath/src/main/native/include/units/mass.h
index 21fa3b5..f81e68a 100644
--- a/wpimath/src/main/native/include/units/mass.h
+++ b/wpimath/src/main/native/include/units/mass.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
// Copyright (c) 2016 Nic Holthaus
//
diff --git a/wpimath/src/main/native/include/units/math.h b/wpimath/src/main/native/include/units/math.h
index ccb3a62..995335b 100644
--- a/wpimath/src/main/native/include/units/math.h
+++ b/wpimath/src/main/native/include/units/math.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
// Copyright (c) 2016 Nic Holthaus
//
@@ -31,13 +28,10 @@
#include <cmath>
-#include <wpi/math>
-
#include "units/angle.h"
#include "units/base.h"
#include "units/dimensionless.h"
-namespace units {
//----------------------------------
// UNIT-ENABLED CMATH FUNCTIONS
//----------------------------------
@@ -48,7 +42,7 @@
* rounding functions, etc.
* @sa See `unit_t` for more information on unit type containers.
*/
-namespace math {
+namespace units::math {
//----------------------------------
// TRIGONOMETRIC FUNCTIONS
//----------------------------------
@@ -755,25 +749,4 @@
"Unit types are not compatible.");
return resultType(std::fma(x(), y(), resultType(z)()));
}
-
-/**
- * Constrains theta to within the range (-pi, pi].
- *
- * @param theta Angle to normalize.
- */
-constexpr units::radian_t NormalizeAngle(units::radian_t theta) {
- units::radian_t pi(wpi::math::pi);
-
- // Constrain theta to within (-3pi, pi)
- const int n_pi_pos = (theta + pi) / 2.0 / pi;
- theta = theta - units::radian_t{n_pi_pos * 2.0 * wpi::math::pi};
-
- // Cut off the bottom half of the above range to constrain within
- // (-pi, pi]
- const int n_pi_neg = (theta - pi) / 2.0 / pi;
- theta = theta - units::radian_t{n_pi_neg * 2.0 * wpi::math::pi};
-
- return theta;
-}
-} // namespace math
-} // namespace units
+} // namespace units::math
diff --git a/wpimath/src/main/native/include/units/moment_of_inertia.h b/wpimath/src/main/native/include/units/moment_of_inertia.h
index 938a635..9d30852 100644
--- a/wpimath/src/main/native/include/units/moment_of_inertia.h
+++ b/wpimath/src/main/native/include/units/moment_of_inertia.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
diff --git a/wpimath/src/main/native/include/units/power.h b/wpimath/src/main/native/include/units/power.h
index d1a9504..b4c5f13 100644
--- a/wpimath/src/main/native/include/units/power.h
+++ b/wpimath/src/main/native/include/units/power.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
// Copyright (c) 2016 Nic Holthaus
//
diff --git a/wpimath/src/main/native/include/units/pressure.h b/wpimath/src/main/native/include/units/pressure.h
index c14bae1..63c0e37 100644
--- a/wpimath/src/main/native/include/units/pressure.h
+++ b/wpimath/src/main/native/include/units/pressure.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
// Copyright (c) 2016 Nic Holthaus
//
diff --git a/wpimath/src/main/native/include/units/radiation.h b/wpimath/src/main/native/include/units/radiation.h
index b631336..84f8eed 100644
--- a/wpimath/src/main/native/include/units/radiation.h
+++ b/wpimath/src/main/native/include/units/radiation.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
// Copyright (c) 2016 Nic Holthaus
//
@@ -38,6 +35,7 @@
#include "units/base.h"
#include "units/energy.h"
#include "units/frequency.h"
+#include "units/mass.h"
namespace units {
/**
diff --git a/wpimath/src/main/native/include/units/solid_angle.h b/wpimath/src/main/native/include/units/solid_angle.h
index 0e38f55..2e0182b 100644
--- a/wpimath/src/main/native/include/units/solid_angle.h
+++ b/wpimath/src/main/native/include/units/solid_angle.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
// Copyright (c) 2016 Nic Holthaus
//
diff --git a/wpimath/src/main/native/include/units/substance.h b/wpimath/src/main/native/include/units/substance.h
index c774497..8691818 100644
--- a/wpimath/src/main/native/include/units/substance.h
+++ b/wpimath/src/main/native/include/units/substance.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
// Copyright (c) 2016 Nic Holthaus
//
diff --git a/wpimath/src/main/native/include/units/temperature.h b/wpimath/src/main/native/include/units/temperature.h
index 25a9b98..24f22a0 100644
--- a/wpimath/src/main/native/include/units/temperature.h
+++ b/wpimath/src/main/native/include/units/temperature.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
// Copyright (c) 2016 Nic Holthaus
//
diff --git a/wpimath/src/main/native/include/units/time.h b/wpimath/src/main/native/include/units/time.h
index 13e66c4..6366123 100644
--- a/wpimath/src/main/native/include/units/time.h
+++ b/wpimath/src/main/native/include/units/time.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
// Copyright (c) 2016 Nic Holthaus
//
diff --git a/wpimath/src/main/native/include/units/torque.h b/wpimath/src/main/native/include/units/torque.h
index 58f4ca3..42ab326 100644
--- a/wpimath/src/main/native/include/units/torque.h
+++ b/wpimath/src/main/native/include/units/torque.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
// Copyright (c) 2016 Nic Holthaus
//
diff --git a/wpimath/src/main/native/include/units/velocity.h b/wpimath/src/main/native/include/units/velocity.h
index 5a0ebcb..d63d1e6 100644
--- a/wpimath/src/main/native/include/units/velocity.h
+++ b/wpimath/src/main/native/include/units/velocity.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
// Copyright (c) 2016 Nic Holthaus
//
diff --git a/wpimath/src/main/native/include/units/voltage.h b/wpimath/src/main/native/include/units/voltage.h
index 917c52a..605baed 100644
--- a/wpimath/src/main/native/include/units/voltage.h
+++ b/wpimath/src/main/native/include/units/voltage.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
// Copyright (c) 2016 Nic Holthaus
//
diff --git a/wpimath/src/main/native/include/units/volume.h b/wpimath/src/main/native/include/units/volume.h
index f13fdef..c361a8f 100644
--- a/wpimath/src/main/native/include/units/volume.h
+++ b/wpimath/src/main/native/include/units/volume.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
// Copyright (c) 2016 Nic Holthaus
//
diff --git a/wpimath/src/main/native/include/wpimath/MathShared.h b/wpimath/src/main/native/include/wpimath/MathShared.h
index ea23a88..f4a1795 100644
--- a/wpimath/src/main/native/include/wpimath/MathShared.h
+++ b/wpimath/src/main/native/include/wpimath/MathShared.h
@@ -1,15 +1,13 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
#include <memory>
-#include <wpi/Twine.h>
+#include <fmt/format.h>
+#include <wpi/SymbolExports.h>
namespace wpi::math {
@@ -21,24 +19,52 @@
kFilter_Linear,
kOdometry_DifferentialDrive,
kOdometry_SwerveDrive,
- kOdometry_MecanumDrive
+ kOdometry_MecanumDrive,
+ kController_PIDController2,
+ kController_ProfiledPIDController,
};
-class MathShared {
+class WPILIB_DLLEXPORT MathShared {
public:
virtual ~MathShared() = default;
- virtual void ReportError(const wpi::Twine& error) = 0;
+ virtual void ReportErrorV(fmt::string_view format, fmt::format_args args) = 0;
+ virtual void ReportWarningV(fmt::string_view format,
+ fmt::format_args args) = 0;
virtual void ReportUsage(MathUsageId id, int count) = 0;
+
+ template <typename S, typename... Args>
+ inline void ReportError(const S& format, Args&&... args) {
+ ReportErrorV(format, fmt::make_args_checked<Args...>(format, args...));
+ }
+
+ template <typename S, typename... Args>
+ inline void ReportWarning(const S& format, Args&&... args) {
+ ReportWarningV(format, fmt::make_args_checked<Args...>(format, args...));
+ }
};
-class MathSharedStore {
+class WPILIB_DLLEXPORT MathSharedStore {
public:
static MathShared& GetMathShared();
static void SetMathShared(std::unique_ptr<MathShared> shared);
- static void ReportError(const wpi::Twine& error) {
- GetMathShared().ReportError(error);
+ static void ReportErrorV(fmt::string_view format, fmt::format_args args) {
+ GetMathShared().ReportErrorV(format, args);
+ }
+
+ template <typename S, typename... Args>
+ static inline void ReportError(const S& format, Args&&... args) {
+ ReportErrorV(format, fmt::make_args_checked<Args...>(format, args...));
+ }
+
+ static void ReportWarningV(fmt::string_view format, fmt::format_args args) {
+ GetMathShared().ReportWarningV(format, args);
+ }
+
+ template <typename S, typename... Args>
+ static inline void ReportWarning(const S& format, Args&&... args) {
+ ReportWarningV(format, fmt::make_args_checked<Args...>(format, args...));
}
static void ReportUsage(MathUsageId id, int count) {
diff --git a/wpimath/src/test/java/edu/wpi/first/math/DrakeTest.java b/wpimath/src/test/java/edu/wpi/first/math/DrakeTest.java
index 2697c6c..4140fae 100644
--- a/wpimath/src/test/java/edu/wpi/first/math/DrakeTest.java
+++ b/wpimath/src/test/java/edu/wpi/first/math/DrakeTest.java
@@ -1,19 +1,15 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
package edu.wpi.first.math;
-import org.ejml.simple.SimpleMatrix;
-import org.junit.jupiter.api.Test;
-
-
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
+import org.ejml.simple.SimpleMatrix;
+import org.junit.jupiter.api.Test;
+
@SuppressWarnings({"ParameterName", "LocalVariableName"})
public class DrakeTest {
public static void assertMatrixEqual(SimpleMatrix A, SimpleMatrix B) {
@@ -24,8 +20,8 @@
}
}
- private boolean solveDAREandVerify(SimpleMatrix A, SimpleMatrix B, SimpleMatrix Q,
- SimpleMatrix R) {
+ private boolean solveDAREandVerify(
+ SimpleMatrix A, SimpleMatrix B, SimpleMatrix Q, SimpleMatrix R) {
var X = Drake.discreteAlgebraicRiccatiEquation(A, B, Q, R);
// expect that x is the same as it's transpose
@@ -33,11 +29,19 @@
assertMatrixEqual(X, X.transpose());
// Verify that this is a solution to the DARE.
- SimpleMatrix Y = A.transpose().mult(X).mult(A)
+ SimpleMatrix Y =
+ A.transpose()
+ .mult(X)
+ .mult(A)
.minus(X)
- .minus(A.transpose().mult(X).mult(B)
- .mult(((B.transpose().mult(X).mult(B)).plus(R))
- .invert()).mult(B.transpose()).mult(X).mult(A))
+ .minus(
+ A.transpose()
+ .mult(X)
+ .mult(B)
+ .mult(((B.transpose().mult(X).mult(B)).plus(R)).invert())
+ .mult(B.transpose())
+ .mult(X)
+ .mult(A))
.plus(Q);
assertMatrixEqual(Y, new SimpleMatrix(Y.numRows(), Y.numCols()));
@@ -50,20 +54,21 @@
int m1 = 1;
// we know from Scipy that this should be [[0.05048525 0.10097051 0.20194102 0.40388203]]
- SimpleMatrix A1 = new SimpleMatrix(n1, n1, true, new double[]{0.5, 1, 0, 0, 0, 0, 1,
- 0, 0, 0, 0, 1, 0, 0, 0, 0}).transpose();
- SimpleMatrix B1 = new SimpleMatrix(n1, m1, true, new double[]{0, 0, 0, 1});
- SimpleMatrix Q1 = new SimpleMatrix(n1, n1, true, new double[]{1, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0});
- SimpleMatrix R1 = new SimpleMatrix(m1, m1, true, new double[]{0.25});
+ SimpleMatrix A1 =
+ new SimpleMatrix(
+ n1, n1, true, new double[] {0.5, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0})
+ .transpose();
+ SimpleMatrix B1 = new SimpleMatrix(n1, m1, true, new double[] {0, 0, 0, 1});
+ SimpleMatrix Q1 =
+ new SimpleMatrix(
+ n1, n1, true, new double[] {1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
+ SimpleMatrix R1 = new SimpleMatrix(m1, m1, true, new double[] {0.25});
assertTrue(solveDAREandVerify(A1, B1, Q1, R1));
- SimpleMatrix A2 = new SimpleMatrix(2, 2, true, new double[]{1, 1, 0, 1});
- SimpleMatrix B2 = new SimpleMatrix(2, 1, true, new double[]{0, 1});
- SimpleMatrix Q2 = new SimpleMatrix(2, 2, true, new double[]{1, 0, 0, 0});
- SimpleMatrix R2 = new SimpleMatrix(1, 1, true, new double[]{0.3});
+ SimpleMatrix A2 = new SimpleMatrix(2, 2, true, new double[] {1, 1, 0, 1});
+ SimpleMatrix B2 = new SimpleMatrix(2, 1, true, new double[] {0, 1});
+ SimpleMatrix Q2 = new SimpleMatrix(2, 2, true, new double[] {1, 0, 0, 0});
+ SimpleMatrix R2 = new SimpleMatrix(1, 1, true, new double[] {0.3});
assertTrue(solveDAREandVerify(A2, B2, Q2, R2));
-
}
}
diff --git a/wpimath/src/test/java/edu/wpi/first/math/MathUtilTest.java b/wpimath/src/test/java/edu/wpi/first/math/MathUtilTest.java
new file mode 100644
index 0000000..bb116ce
--- /dev/null
+++ b/wpimath/src/test/java/edu/wpi/first/math/MathUtilTest.java
@@ -0,0 +1,70 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+import org.junit.jupiter.api.Test;
+
+class MathUtilTest {
+ @Test
+ void testApplyDeadband() {
+ // < 0
+ assertEquals(-1.0, MathUtil.applyDeadband(-1.0, 0.02));
+ assertEquals((-0.03 + 0.02) / (1.0 - 0.02), MathUtil.applyDeadband(-0.03, 0.02));
+ assertEquals(0.0, MathUtil.applyDeadband(-0.02, 0.02));
+ assertEquals(0.0, MathUtil.applyDeadband(-0.01, 0.02));
+
+ // == 0
+ assertEquals(0.0, MathUtil.applyDeadband(0.0, 0.02));
+
+ // > 0
+ assertEquals(0.0, MathUtil.applyDeadband(0.01, 0.02));
+ assertEquals(0.0, MathUtil.applyDeadband(0.02, 0.02));
+ assertEquals((0.03 - 0.02) / (1.0 - 0.02), MathUtil.applyDeadband(0.03, 0.02));
+ assertEquals(1.0, MathUtil.applyDeadband(1.0, 0.02));
+ }
+
+ @Test
+ void testInputModulus() {
+ // These tests check error wrapping. That is, the result of wrapping the
+ // result of an angle reference minus the measurement.
+
+ // Test symmetric range
+ assertEquals(-20.0, MathUtil.inputModulus(170.0 - (-170.0), -180.0, 180.0));
+ assertEquals(-20.0, MathUtil.inputModulus(170.0 + 360.0 - (-170.0), -180.0, 180.0));
+ assertEquals(-20.0, MathUtil.inputModulus(170.0 - (-170.0 + 360.0), -180.0, 180.0));
+ assertEquals(20.0, MathUtil.inputModulus(-170.0 - 170.0, -180.0, 180.0));
+ assertEquals(20.0, MathUtil.inputModulus(-170.0 + 360.0 - 170.0, -180.0, 180.0));
+ assertEquals(20.0, MathUtil.inputModulus(-170.0 - (170.0 + 360.0), -180.0, 180.0));
+
+ // Test range start at zero
+ assertEquals(340.0, MathUtil.inputModulus(170.0 - 190.0, 0.0, 360.0));
+ assertEquals(340.0, MathUtil.inputModulus(170.0 + 360.0 - 190.0, 0.0, 360.0));
+ assertEquals(340.0, MathUtil.inputModulus(170.0 - (190.0 + 360), 0.0, 360.0));
+
+ // Test asymmetric range that doesn't start at zero
+ assertEquals(-20.0, MathUtil.inputModulus(170.0 - (-170.0), -170.0, 190.0));
+
+ // Test range with both positive endpoints
+ assertEquals(2.0, MathUtil.inputModulus(0.0, 1.0, 3.0));
+ assertEquals(3.0, MathUtil.inputModulus(1.0, 1.0, 3.0));
+ assertEquals(2.0, MathUtil.inputModulus(2.0, 1.0, 3.0));
+ assertEquals(3.0, MathUtil.inputModulus(3.0, 1.0, 3.0));
+ assertEquals(2.0, MathUtil.inputModulus(4.0, 1.0, 3.0));
+ }
+
+ @Test
+ void testAngleModulus() {
+ assertEquals(MathUtil.angleModulus(Math.toRadians(-2000)), Math.toRadians(160), 1e-6);
+ assertEquals(MathUtil.angleModulus(Math.toRadians(358)), Math.toRadians(-2), 1e-6);
+ assertEquals(MathUtil.angleModulus(Math.toRadians(360)), 0, 1e-6);
+
+ assertEquals(MathUtil.angleModulus(5 * Math.PI), Math.PI);
+ assertEquals(MathUtil.angleModulus(-5 * Math.PI), Math.PI);
+ assertEquals(MathUtil.angleModulus(Math.PI / 2), Math.PI / 2);
+ assertEquals(MathUtil.angleModulus(-Math.PI / 2), -Math.PI / 2);
+ }
+}
diff --git a/wpimath/src/test/java/edu/wpi/first/math/MatrixTest.java b/wpimath/src/test/java/edu/wpi/first/math/MatrixTest.java
new file mode 100644
index 0000000..45df8fe
--- /dev/null
+++ b/wpimath/src/test/java/edu/wpi/first/math/MatrixTest.java
@@ -0,0 +1,134 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import edu.wpi.first.math.numbers.N1;
+import edu.wpi.first.math.numbers.N2;
+import edu.wpi.first.math.numbers.N3;
+import edu.wpi.first.math.numbers.N4;
+import org.ejml.data.SingularMatrixException;
+import org.junit.jupiter.api.Test;
+
+public class MatrixTest {
+ @Test
+ void testMatrixMultiplication() {
+ var mat1 = Matrix.mat(Nat.N2(), Nat.N2()).fill(2.0, 1.0, 0.0, 1.0);
+ var mat2 = Matrix.mat(Nat.N2(), Nat.N2()).fill(3.0, 0.0, 0.0, 2.5);
+
+ Matrix<N2, N2> result = mat1.times(mat2);
+
+ assertEquals(result, Matrix.mat(Nat.N2(), Nat.N2()).fill(6.0, 2.5, 0.0, 2.5));
+
+ var mat3 = Matrix.mat(Nat.N2(), Nat.N3()).fill(1.0, 3.0, 0.5, 2.0, 4.3, 1.2);
+ var mat4 =
+ Matrix.mat(Nat.N3(), Nat.N4())
+ .fill(3.0, 1.5, 2.0, 4.5, 2.3, 1.0, 1.6, 3.1, 5.2, 2.1, 2.0, 1.0);
+
+ Matrix<N2, N4> result2 = mat3.times(mat4);
+
+ assertTrue(
+ Matrix.mat(Nat.N2(), Nat.N4())
+ .fill(12.5, 5.55, 7.8, 14.3, 22.13, 9.82, 13.28, 23.53)
+ .isEqual(result2, 1E-9));
+ }
+
+ @Test
+ void testMatrixVectorMultiplication() {
+ var mat = Matrix.mat(Nat.N2(), Nat.N2()).fill(1.0, 1.0, 0.0, 1.0);
+
+ var vec = VecBuilder.fill(3.0, 2.0);
+
+ Matrix<N2, N1> result = mat.times(vec);
+ assertEquals(VecBuilder.fill(5.0, 2.0), result);
+ }
+
+ @Test
+ void testTranspose() {
+ Matrix<N3, N1> vec = VecBuilder.fill(1.0, 2.0, 3.0);
+
+ Matrix<N1, N3> transpose = vec.transpose();
+
+ assertEquals(Matrix.mat(Nat.N1(), Nat.N3()).fill(1.0, 2.0, 3.0), transpose);
+ }
+
+ @Test
+ void testSolve() {
+ var mat1 = Matrix.mat(Nat.N2(), Nat.N2()).fill(1.0, 2.0, 3.0, 4.0);
+ var vec1 = VecBuilder.fill(1.0, 2.0);
+
+ var solve1 = mat1.solve(vec1);
+
+ assertEquals(VecBuilder.fill(0.0, 0.5), solve1);
+
+ var mat2 = Matrix.mat(Nat.N3(), Nat.N2()).fill(1.0, 2.0, 3.0, 4.0, 5.0, 6.0);
+ var vec2 = VecBuilder.fill(1.0, 2.0, 3.0);
+
+ var solve2 = mat2.solve(vec2);
+
+ assertEquals(VecBuilder.fill(0.0, 0.5), solve2);
+ }
+
+ @Test
+ void testInverse() {
+ var mat = Matrix.mat(Nat.N3(), Nat.N3()).fill(1.0, 3.0, 2.0, 5.0, 2.0, 1.5, 0.0, 1.3, 2.5);
+
+ var inv = mat.inv();
+
+ assertTrue(Matrix.eye(Nat.N3()).isEqual(mat.times(inv), 1E-9));
+
+ assertTrue(Matrix.eye(Nat.N3()).isEqual(inv.times(mat), 1E-9));
+ }
+
+ @Test
+ void testUninvertableMatrix() {
+ var singularMatrix = Matrix.mat(Nat.N2(), Nat.N2()).fill(2.0, 1.0, 2.0, 1.0);
+
+ assertThrows(SingularMatrixException.class, singularMatrix::inv);
+ }
+
+ @Test
+ void testMatrixScalarArithmetic() {
+ var mat = Matrix.mat(Nat.N2(), Nat.N2()).fill(1.0, 2.0, 3.0, 4.0);
+
+ assertEquals(Matrix.mat(Nat.N2(), Nat.N2()).fill(3.0, 4.0, 5.0, 6.0), mat.plus(2.0));
+
+ assertEquals(Matrix.mat(Nat.N2(), Nat.N2()).fill(0.0, 1.0, 2.0, 3.0), mat.minus(1.0));
+
+ assertEquals(Matrix.mat(Nat.N2(), Nat.N2()).fill(2.0, 4.0, 6.0, 8.0), mat.times(2.0));
+
+ assertTrue(Matrix.mat(Nat.N2(), Nat.N2()).fill(0.5, 1.0, 1.5, 2.0).isEqual(mat.div(2.0), 1E-3));
+ }
+
+ @Test
+ void testMatrixMatrixArithmetic() {
+ var mat1 = Matrix.mat(Nat.N2(), Nat.N2()).fill(1.0, 2.0, 3.0, 4.0);
+
+ var mat2 = Matrix.mat(Nat.N2(), Nat.N2()).fill(5.0, 6.0, 7.0, 8.0);
+
+ assertEquals(Matrix.mat(Nat.N2(), Nat.N2()).fill(-4.0, -4.0, -4.0, -4.0), mat1.minus(mat2));
+
+ assertEquals(Matrix.mat(Nat.N2(), Nat.N2()).fill(6.0, 8.0, 10.0, 12.0), mat1.plus(mat2));
+ }
+
+ @Test
+ void testMatrixExponential() {
+ var matrix = Matrix.eye(Nat.N2());
+ var result = matrix.exp();
+
+ assertTrue(result.isEqual(Matrix.mat(Nat.N2(), Nat.N2()).fill(Math.E, 0, 0, Math.E), 1E-9));
+
+ matrix = Matrix.mat(Nat.N2(), Nat.N2()).fill(1, 2, 3, 4);
+ result = matrix.times(0.01).exp();
+
+ assertTrue(
+ result.isEqual(
+ Matrix.mat(Nat.N2(), Nat.N2()).fill(1.01035625, 0.02050912, 0.03076368, 1.04111993),
+ 1E-8));
+ }
+}
diff --git a/wpimath/src/test/java/edu/wpi/first/math/StateSpaceUtilTest.java b/wpimath/src/test/java/edu/wpi/first/math/StateSpaceUtilTest.java
new file mode 100644
index 0000000..456506c
--- /dev/null
+++ b/wpimath/src/test/java/edu/wpi/first/math/StateSpaceUtilTest.java
@@ -0,0 +1,185 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import edu.wpi.first.math.geometry.Pose2d;
+import edu.wpi.first.math.geometry.Rotation2d;
+import edu.wpi.first.math.numbers.N1;
+import edu.wpi.first.math.numbers.N2;
+import java.util.ArrayList;
+import java.util.List;
+import org.ejml.dense.row.MatrixFeatures_DDRM;
+import org.ejml.simple.SimpleMatrix;
+import org.junit.jupiter.api.Test;
+
+public class StateSpaceUtilTest {
+ @Test
+ public void testCostArray() {
+ var mat = StateSpaceUtil.makeCostMatrix(VecBuilder.fill(1.0, 2.0, 3.0));
+ // Expect diag(1 / 1^2, 1 / 2^2, 1 / 3^2); each of the nine entries is checked once.
+ assertEquals(1.0, mat.get(0, 0), 1e-3);
+ assertEquals(0.0, mat.get(0, 1), 1e-3);
+ assertEquals(0.0, mat.get(0, 2), 1e-3);
+ assertEquals(0.0, mat.get(1, 0), 1e-3);
+ assertEquals(1.0 / 4.0, mat.get(1, 1), 1e-3);
+ assertEquals(0.0, mat.get(1, 2), 1e-3);
+ assertEquals(0.0, mat.get(2, 0), 1e-3);
+ assertEquals(0.0, mat.get(2, 1), 1e-3);
+ assertEquals(1.0 / 9.0, mat.get(2, 2), 1e-3);
+ }
+
+ @Test
+ public void testCovArray() {
+ var mat = StateSpaceUtil.makeCovarianceMatrix(Nat.N3(), VecBuilder.fill(1.0, 2.0, 3.0));
+ // Expect diag(1^2, 2^2, 3^2); each of the nine entries is checked once.
+ assertEquals(1.0, mat.get(0, 0), 1e-3);
+ assertEquals(0.0, mat.get(0, 1), 1e-3);
+ assertEquals(0.0, mat.get(0, 2), 1e-3);
+ assertEquals(0.0, mat.get(1, 0), 1e-3);
+ assertEquals(4.0, mat.get(1, 1), 1e-3);
+ assertEquals(0.0, mat.get(1, 2), 1e-3);
+ assertEquals(0.0, mat.get(2, 0), 1e-3);
+ assertEquals(0.0, mat.get(2, 1), 1e-3);
+ assertEquals(9.0, mat.get(2, 2), 1e-3);
+ }
+
+ @Test
+ @SuppressWarnings("LocalVariableName")
+ public void testIsStabilizable() {
+ Matrix<N2, N2> A;
+ Matrix<N2, N1> B = VecBuilder.fill(0, 1);
+
+ // First eigenvalue is uncontrollable and unstable.
+ // Second eigenvalue is controllable and stable.
+ A = Matrix.mat(Nat.N2(), Nat.N2()).fill(1.2, 0, 0, 0.5);
+ assertFalse(StateSpaceUtil.isStabilizable(A, B));
+
+ // First eigenvalue is uncontrollable and marginally stable.
+ // Second eigenvalue is controllable and stable.
+ A = Matrix.mat(Nat.N2(), Nat.N2()).fill(1, 0, 0, 0.5);
+ assertFalse(StateSpaceUtil.isStabilizable(A, B));
+
+ // First eigenvalue is uncontrollable and stable.
+ // Second eigenvalue is controllable and stable.
+ A = Matrix.mat(Nat.N2(), Nat.N2()).fill(0.2, 0, 0, 0.5);
+ assertTrue(StateSpaceUtil.isStabilizable(A, B));
+
+ // First eigenvalue is uncontrollable and stable.
+ // Second eigenvalue is controllable and unstable.
+ A = Matrix.mat(Nat.N2(), Nat.N2()).fill(0.2, 0, 0, 1.2);
+ assertTrue(StateSpaceUtil.isStabilizable(A, B));
+ }
+
+ @Test
+ @SuppressWarnings("LocalVariableName")
+ public void testIsDetectable() {
+ Matrix<N2, N2> A;
+ Matrix<N1, N2> C = Matrix.mat(Nat.N1(), Nat.N2()).fill(0, 1);
+
+ // First eigenvalue is unobservable and unstable.
+ // Second eigenvalue is observable and stable.
+ A = Matrix.mat(Nat.N2(), Nat.N2()).fill(1.2, 0, 0, 0.5);
+ assertFalse(StateSpaceUtil.isDetectable(A, C));
+
+ // First eigenvalue is unobservable and marginally stable.
+ // Second eigenvalue is observable and stable.
+ A = Matrix.mat(Nat.N2(), Nat.N2()).fill(1, 0, 0, 0.5);
+ assertFalse(StateSpaceUtil.isDetectable(A, C));
+
+ // First eigenvalue is unobservable and stable.
+ // Second eigenvalue is observable and stable.
+ A = Matrix.mat(Nat.N2(), Nat.N2()).fill(0.2, 0, 0, 0.5);
+ assertTrue(StateSpaceUtil.isDetectable(A, C));
+
+ // First eigenvalue is unobservable and stable.
+ // Second eigenvalue is observable and unstable.
+ A = Matrix.mat(Nat.N2(), Nat.N2()).fill(0.2, 0, 0, 1.2);
+ assertTrue(StateSpaceUtil.isDetectable(A, C));
+ }
+
+ @Test
+ public void testMakeWhiteNoiseVector() {
+ var firstData = new ArrayList<Double>();
+ var secondData = new ArrayList<Double>();
+ for (int i = 0; i < 1000; i++) {
+ var noiseVec = StateSpaceUtil.makeWhiteNoiseVector(VecBuilder.fill(1.0, 2.0));
+ firstData.add(noiseVec.get(0, 0));
+ secondData.add(noiseVec.get(1, 0));
+ }
+ assertEquals(1.0, calculateStandardDeviation(firstData), 0.2);
+ assertEquals(2.0, calculateStandardDeviation(secondData), 0.2);
+ }
+
+ private double calculateStandardDeviation(List<Double> numArray) {
+ double sum = 0.0;
+ double standardDeviation = 0.0;
+ int length = numArray.size();
+
+ for (double num : numArray) {
+ sum += num;
+ }
+
+ double mean = sum / length;
+ // Population form: the squared deviations are divided by N, not N - 1.
+ for (double num : numArray) {
+ standardDeviation += Math.pow(num - mean, 2);
+ }
+
+ return Math.sqrt(standardDeviation / length);
+ }
+
+ @Test
+ public void testMatrixExp() {
+ Matrix<N2, N2> wrappedMatrix = Matrix.eye(Nat.N2());
+ var wrappedResult = wrappedMatrix.exp();
+ // exp(I) = e * I.
+ assertTrue(
+ wrappedResult.isEqual(Matrix.mat(Nat.N2(), Nat.N2()).fill(Math.E, 0, 0, Math.E), 1E-9));
+ // exp(0.01 * [1, 2; 3, 4]) against precomputed reference values.
+ var matrix = Matrix.mat(Nat.N2(), Nat.N2()).fill(1, 2, 3, 4);
+ wrappedResult = matrix.times(0.01).exp();
+
+ assertTrue(
+ wrappedResult.isEqual(
+ Matrix.mat(Nat.N2(), Nat.N2()).fill(1.01035625, 0.02050912, 0.03076368, 1.04111993),
+ 1E-8));
+ }
+
+ @Test
+ public void testSimpleMatrixExp() {
+ SimpleMatrix matrix = SimpleMatrixUtils.eye(2);
+ var result = SimpleMatrixUtils.exp(matrix);
+ // Same expectations as testMatrixExp, exercised through the EJML SimpleMatrix helpers.
+ assertTrue(
+ MatrixFeatures_DDRM.isIdentical(
+ result.getDDRM(),
+ new SimpleMatrix(2, 2, true, new double[] {Math.E, 0, 0, Math.E}).getDDRM(),
+ 1E-9));
+
+ matrix = new SimpleMatrix(2, 2, true, new double[] {1, 2, 3, 4});
+ result = SimpleMatrixUtils.exp(matrix.scale(0.01));
+
+ assertTrue(
+ MatrixFeatures_DDRM.isIdentical(
+ result.getDDRM(),
+ new SimpleMatrix(
+ 2, 2, true, new double[] {1.01035625, 0.02050912, 0.03076368, 1.04111993})
+ .getDDRM(),
+ 1E-8));
+ }
+
+ @Test
+ public void testPoseToVector() {
+ Pose2d pose = new Pose2d(1, 2, new Rotation2d(3));
+ var vector = StateSpaceUtil.poseToVector(pose);
+ assertEquals(pose.getTranslation().getX(), vector.get(0, 0), 1e-6);
+ assertEquals(pose.getTranslation().getY(), vector.get(1, 0), 1e-6);
+ assertEquals(pose.getRotation().getRadians(), vector.get(2, 0), 1e-6);
+ }
+}
diff --git a/wpimath/src/test/java/edu/wpi/first/math/WPIMathJNITest.java b/wpimath/src/test/java/edu/wpi/first/math/WPIMathJNITest.java
index 10ec44c..6a1bf2e 100644
--- a/wpimath/src/test/java/edu/wpi/first/math/WPIMathJNITest.java
+++ b/wpimath/src/test/java/edu/wpi/first/math/WPIMathJNITest.java
@@ -1,16 +1,13 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
package edu.wpi.first.math;
-import org.junit.jupiter.api.Test;
-
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
+import org.junit.jupiter.api.Test;
+
public class WPIMathJNITest {
@Test
public void testLink() {
diff --git a/wpimath/src/test/java/edu/wpi/first/math/controller/ControlAffinePlantInversionFeedforwardTest.java b/wpimath/src/test/java/edu/wpi/first/math/controller/ControlAffinePlantInversionFeedforwardTest.java
new file mode 100644
index 0000000..f64608e
--- /dev/null
+++ b/wpimath/src/test/java/edu/wpi/first/math/controller/ControlAffinePlantInversionFeedforwardTest.java
@@ -0,0 +1,51 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.controller;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+import edu.wpi.first.math.Matrix;
+import edu.wpi.first.math.Nat;
+import edu.wpi.first.math.VecBuilder;
+import edu.wpi.first.math.numbers.N1;
+import edu.wpi.first.math.numbers.N2;
+import org.junit.jupiter.api.Test;
+
+class ControlAffinePlantInversionFeedforwardTest {
+ @SuppressWarnings("LocalVariableName")
+ @Test
+ void testCalculate() {
+ ControlAffinePlantInversionFeedforward<N2, N1> feedforward =
+ new ControlAffinePlantInversionFeedforward<N2, N1>(
+ Nat.N2(), Nat.N1(), this::getDynamics, 0.02);
+ // Expected: (3 - 2) / 0.02 = 50, minus the uncontrolled term x = 2, gives 48.
+ assertEquals(
+ 48.0, feedforward.calculate(VecBuilder.fill(2, 2), VecBuilder.fill(3, 3)).get(0, 0), 1e-6);
+ }
+
+ @SuppressWarnings("LocalVariableName")
+ @Test
+ void testCalculateState() {
+ ControlAffinePlantInversionFeedforward<N2, N1> feedforward =
+ new ControlAffinePlantInversionFeedforward<N2, N1>(
+ Nat.N2(), Nat.N1(), this::getStateDynamics, VecBuilder.fill(0, 1), 0.02);
+ // Same plant as testCalculate, supplied as f(x) plus a constant B = [0; 1].
+ assertEquals(
+ 48.0, feedforward.calculate(VecBuilder.fill(2, 2), VecBuilder.fill(3, 3)).get(0, 0), 1e-6);
+ }
+
+ @SuppressWarnings("ParameterName")
+ protected Matrix<N2, N1> getDynamics(Matrix<N2, N1> x, Matrix<N1, N1> u) {
+ return Matrix.mat(Nat.N2(), Nat.N2())
+ .fill(1.000, 0, 0, 1.000)
+ .times(x)
+ .plus(VecBuilder.fill(0, 1).times(u));
+ }
+
+ @SuppressWarnings("ParameterName")
+ protected Matrix<N2, N1> getStateDynamics(Matrix<N2, N1> x) {
+ return Matrix.mat(Nat.N2(), Nat.N2()).fill(1.000, 0, 0, 1.000).times(x);
+ }
+}
diff --git a/wpimath/src/test/java/edu/wpi/first/math/controller/HolonomicDriveControllerTest.java b/wpimath/src/test/java/edu/wpi/first/math/controller/HolonomicDriveControllerTest.java
new file mode 100644
index 0000000..db3f6dc
--- /dev/null
+++ b/wpimath/src/test/java/edu/wpi/first/math/controller/HolonomicDriveControllerTest.java
@@ -0,0 +1,90 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.controller;
+
+import static org.junit.jupiter.api.Assertions.assertAll;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+import edu.wpi.first.math.MathUtil;
+import edu.wpi.first.math.geometry.Pose2d;
+import edu.wpi.first.math.geometry.Rotation2d;
+import edu.wpi.first.math.geometry.Twist2d;
+import edu.wpi.first.math.kinematics.ChassisSpeeds;
+import edu.wpi.first.math.trajectory.Trajectory;
+import edu.wpi.first.math.trajectory.TrajectoryConfig;
+import edu.wpi.first.math.trajectory.TrajectoryGenerator;
+import edu.wpi.first.math.trajectory.TrapezoidProfile;
+import java.util.ArrayList;
+import java.util.List;
+import org.junit.jupiter.api.Test;
+
+class HolonomicDriveControllerTest {
+ private static final double kTolerance = 1 / 12.0;
+ private static final double kAngularTolerance = Math.toRadians(2);
+
+ @Test
+ void testReachesReference() {
+ HolonomicDriveController controller =
+ new HolonomicDriveController(
+ new PIDController(1.0, 0.0, 0.0),
+ new PIDController(1.0, 0.0, 0.0),
+ new ProfiledPIDController(
+ 1.0, 0.0, 0.0, new TrapezoidProfile.Constraints(2.0 * Math.PI, Math.PI)));
+ Pose2d robotPose = new Pose2d(2.7, 23.0, Rotation2d.fromDegrees(0.0));
+ // The robot starts slightly off the trajectory's first waypoint (2.75, 22.521).
+ List<Pose2d> waypoints = new ArrayList<>();
+ waypoints.add(new Pose2d(2.75, 22.521, new Rotation2d(0)));
+ waypoints.add(new Pose2d(24.73, 19.68, new Rotation2d(5.8)));
+
+ TrajectoryConfig config = new TrajectoryConfig(8.0, 4.0);
+ Trajectory trajectory = TrajectoryGenerator.generateTrajectory(waypoints, config);
+ // Ideal closed-loop sim: apply each commanded ChassisSpeeds for kDt as a Twist2d.
+ final double kDt = 0.02;
+ final double kTotalTime = trajectory.getTotalTimeSeconds();
+
+ for (int i = 0; i < (kTotalTime / kDt); i++) {
+ Trajectory.State state = trajectory.sample(kDt * i);
+ ChassisSpeeds output = controller.calculate(robotPose, state, new Rotation2d());
+
+ robotPose =
+ robotPose.exp(
+ new Twist2d(
+ output.vxMetersPerSecond * kDt,
+ output.vyMetersPerSecond * kDt,
+ output.omegaRadiansPerSecond * kDt));
+ }
+ // Compare the simulated pose against the trajectory's final pose.
+ final List<Trajectory.State> states = trajectory.getStates();
+ final Pose2d endPose = states.get(states.size() - 1).poseMeters;
+
+ // Java lambdas require local variables referenced from a lambda expression
+ // must be final or effectively final.
+ final Pose2d finalRobotPose = robotPose;
+ // The heading reference passed to calculate() was 0 rad, so 0 is the expected end heading.
+ assertAll(
+ () -> assertEquals(endPose.getX(), finalRobotPose.getX(), kTolerance),
+ () -> assertEquals(endPose.getY(), finalRobotPose.getY(), kTolerance),
+ () ->
+ assertEquals(
+ 0.0,
+ MathUtil.angleModulus(finalRobotPose.getRotation().getRadians()),
+ kAngularTolerance));
+ }
+
+ @Test
+ void testDoesNotRotateUnnecessarily() {
+ var controller =
+ new HolonomicDriveController(
+ new PIDController(1, 0, 0),
+ new PIDController(1, 0, 0),
+ new ProfiledPIDController(1, 0, 0, new TrapezoidProfile.Constraints(4, 2)));
+ // Current and desired headings are both 1.57 rad, so no rotation should be commanded.
+ ChassisSpeeds speeds =
+ controller.calculate(
+ new Pose2d(0, 0, new Rotation2d(1.57)), new Pose2d(), 0, new Rotation2d(1.57));
+
+ assertEquals(0.0, speeds.omegaRadiansPerSecond);
+ }
+}
diff --git a/wpimath/src/test/java/edu/wpi/first/math/controller/LinearPlantInversionFeedforwardTest.java b/wpimath/src/test/java/edu/wpi/first/math/controller/LinearPlantInversionFeedforwardTest.java
new file mode 100644
index 0000000..98b0e6c
--- /dev/null
+++ b/wpimath/src/test/java/edu/wpi/first/math/controller/LinearPlantInversionFeedforwardTest.java
@@ -0,0 +1,31 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.controller;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+import edu.wpi.first.math.Matrix;
+import edu.wpi.first.math.Nat;
+import edu.wpi.first.math.VecBuilder;
+import edu.wpi.first.math.numbers.N1;
+import edu.wpi.first.math.numbers.N2;
+import org.junit.jupiter.api.Test;
+
+class LinearPlantInversionFeedforwardTest {
+ @SuppressWarnings("LocalVariableName")
+ @Test
+ void testCalculate() {
+ Matrix<N2, N2> A = Matrix.mat(Nat.N2(), Nat.N2()).fill(1, 0, 0, 1);
+ Matrix<N2, N1> B = VecBuilder.fill(0, 1);
+
+ LinearPlantInversionFeedforward<N2, N1, N1> feedforward =
+ new LinearPlantInversionFeedforward<N2, N1, N1>(A, B, 0.02);
+
+ assertEquals(
+ 47.502599,
+ feedforward.calculate(VecBuilder.fill(2, 2), VecBuilder.fill(3, 3)).get(0, 0),
+ 0.002);
+ }
+}
diff --git a/wpimath/src/test/java/edu/wpi/first/math/controller/LinearQuadraticRegulatorTest.java b/wpimath/src/test/java/edu/wpi/first/math/controller/LinearQuadraticRegulatorTest.java
new file mode 100644
index 0000000..5d9c2b8
--- /dev/null
+++ b/wpimath/src/test/java/edu/wpi/first/math/controller/LinearQuadraticRegulatorTest.java
@@ -0,0 +1,92 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.controller;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+import edu.wpi.first.math.VecBuilder;
+import edu.wpi.first.math.system.plant.DCMotor;
+import edu.wpi.first.math.system.plant.LinearSystemId;
+import org.junit.jupiter.api.Test;
+
+public class LinearQuadraticRegulatorTest {
+ @Test
+ @SuppressWarnings("LocalVariableName")
+ public void testLQROnElevator() {
+ var motors = DCMotor.getVex775Pro(2);
+
+ var m = 5.0;
+ var r = 0.0181864;
+ var G = 1.0;
+
+ var plant = LinearSystemId.createElevatorSystem(motors, m, r, G);
+
+ var qElms = VecBuilder.fill(0.02, 0.4);
+ var rElms = VecBuilder.fill(12.0);
+ var dt = 0.00505;
+
+ var controller = new LinearQuadraticRegulator<>(plant, qElms, rElms, dt);
+
+ var k = controller.getK();
+
+ assertEquals(522.153, k.get(0, 0), 0.1);
+ assertEquals(38.2, k.get(0, 1), 0.1);
+ }
+
+ @Test
+ public void testFourMotorElevator() {
+ var dt = 0.020;
+
+ var plant =
+ LinearSystemId.createElevatorSystem(
+ DCMotor.getVex775Pro(4), 8.0, 0.75 * 25.4 / 1000.0, 14.67);
+
+ var regulator =
+ new LinearQuadraticRegulator<>(plant, VecBuilder.fill(0.1, 0.2), VecBuilder.fill(12.0), dt);
+
+ assertEquals(10.381, regulator.getK().get(0, 0), 1e-2);
+ assertEquals(0.6929, regulator.getK().get(0, 1), 1e-2);
+ }
+
+ @Test
+ @SuppressWarnings("LocalVariableName")
+ public void testLQROnArm() {
+ var motors = DCMotor.getVex775Pro(2);
+
+ var m = 4.0;
+ var r = 0.4;
+ var G = 100.0;
+
+ var plant = LinearSystemId.createSingleJointedArmSystem(motors, 1d / 3d * m * r * r, G);
+
+ var qElms = VecBuilder.fill(0.01745, 0.08726);
+ var rElms = VecBuilder.fill(12.0);
+ var dt = 0.00505;
+
+ var controller = new LinearQuadraticRegulator<>(plant, qElms, rElms, dt);
+
+ var k = controller.getK();
+
+ assertEquals(19.16, k.get(0, 0), 0.1);
+ assertEquals(3.32, k.get(0, 1), 0.1);
+ }
+
+ @Test
+ public void testLatencyCompensate() {
+ var dt = 0.02;
+
+ var plant =
+ LinearSystemId.createElevatorSystem(
+ DCMotor.getVex775Pro(4), 8.0, 0.75 * 25.4 / 1000.0, 14.67);
+
+ var regulator =
+ new LinearQuadraticRegulator<>(plant, VecBuilder.fill(0.1, 0.2), VecBuilder.fill(12.0), dt);
+
+ regulator.latencyCompensate(plant, dt, 0.01);
+
+ assertEquals(8.97115941, regulator.getK().get(0, 0), 1e-3);
+ assertEquals(0.07904881, regulator.getK().get(0, 1), 1e-3);
+ }
+}
diff --git a/wpimath/src/test/java/edu/wpi/first/math/controller/LinearSystemLoopTest.java b/wpimath/src/test/java/edu/wpi/first/math/controller/LinearSystemLoopTest.java
new file mode 100644
index 0000000..597ae7f
--- /dev/null
+++ b/wpimath/src/test/java/edu/wpi/first/math/controller/LinearSystemLoopTest.java
@@ -0,0 +1,128 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.controller;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import edu.wpi.first.math.Matrix;
+import edu.wpi.first.math.Nat;
+import edu.wpi.first.math.VecBuilder;
+import edu.wpi.first.math.estimator.KalmanFilter;
+import edu.wpi.first.math.numbers.N1;
+import edu.wpi.first.math.numbers.N2;
+import edu.wpi.first.math.system.LinearSystem;
+import edu.wpi.first.math.system.LinearSystemLoop;
+import edu.wpi.first.math.system.plant.DCMotor;
+import edu.wpi.first.math.system.plant.LinearSystemId;
+import edu.wpi.first.math.trajectory.TrapezoidProfile;
+import java.util.Random;
+import org.junit.jupiter.api.Test;
+
+public class LinearSystemLoopTest {
+ public static final double kDt = 0.00505;
+ private static final double kPositionStddev = 0.0001;
+ private static final Random random = new Random();
+
+ LinearSystem<N2, N1, N1> m_plant =
+ LinearSystemId.createElevatorSystem(DCMotor.getVex775Pro(2), 5, 0.0181864, 1.0);
+
+ KalmanFilter<N2, N1, N1> m_observer =
+ new KalmanFilter<>(
+ Nat.N2(), Nat.N1(), m_plant, VecBuilder.fill(0.05, 1.0), VecBuilder.fill(0.0001), kDt);
+
+ LinearQuadraticRegulator<N2, N1, N1> m_controller =
+ new LinearQuadraticRegulator<>(
+ m_plant, VecBuilder.fill(0.02, 0.4), VecBuilder.fill(12.0), kDt);
+
+ private final LinearSystemLoop<N2, N1, N1> m_loop =
+ new LinearSystemLoop<>(m_plant, m_controller, m_observer, 12, kDt);
+
+ @SuppressWarnings("LocalVariableName")
+ private static void updateTwoState(
+ LinearSystem<N2, N1, N1> plant, LinearSystemLoop<N2, N1, N1> loop, double noise) {
+ Matrix<N1, N1> y = plant.calculateY(loop.getXHat(), loop.getU()).plus(VecBuilder.fill(noise));
+ // Feed the noisy synthetic measurement back, then advance the loop one kDt step.
+ loop.correct(y);
+ loop.predict(kDt);
+ }
+
+ @Test
+ @SuppressWarnings("LocalVariableName")
+ public void testStateSpaceEnabled() {
+ m_loop.reset(VecBuilder.fill(0, 0));
+ Matrix<N2, N1> references = VecBuilder.fill(2.0, 0.0);
+ var constraints = new TrapezoidProfile.Constraints(4, 3);
+ m_loop.setNextR(references);
+ // Each step, re-plan a trapezoid profile from the current estimate to the goal.
+ TrapezoidProfile profile;
+ TrapezoidProfile.State state;
+ for (int i = 0; i < 1000; i++) {
+ profile =
+ new TrapezoidProfile(
+ constraints,
+ new TrapezoidProfile.State(m_loop.getXHat(0), m_loop.getXHat(1)),
+ new TrapezoidProfile.State(references.get(0, 0), references.get(1, 0)));
+ state = profile.calculate(kDt);
+ m_loop.setNextR(VecBuilder.fill(state.position, state.velocity));
+
+ updateTwoState(m_plant, m_loop, random.nextGaussian() * kPositionStddev);
+ var u = m_loop.getU(0);
+
+ assertTrue(u >= -12.1 && u <= 12.1, "U out of bounds! Got " + u);
+ }
+
+ assertEquals(2.0, m_loop.getXHat(0), 0.05);
+ assertEquals(0.0, m_loop.getXHat(1), 0.5);
+ }
+
+ @Test
+ @SuppressWarnings("LocalVariableName")
+ public void testFlywheelEnabled() {
+ LinearSystem<N1, N1, N1> plant =
+ LinearSystemId.createFlywheelSystem(DCMotor.getNEO(2), 0.00289, 1.0);
+ KalmanFilter<N1, N1, N1> observer =
+ new KalmanFilter<>(
+ Nat.N1(), Nat.N1(), plant, VecBuilder.fill(1.0), VecBuilder.fill(kPositionStddev), kDt);
+
+ var qElms = VecBuilder.fill(9.0);
+ var rElms = VecBuilder.fill(12.0);
+
+ var controller = new LinearQuadraticRegulator<>(plant, qElms, rElms, kDt);
+
+ var feedforward = new LinearPlantInversionFeedforward<>(plant, kDt);
+
+ var loop = new LinearSystemLoop<>(controller, feedforward, observer, 12);
+ // Start from rest and target 3000 RPM, i.e. 3000 / 60 * 2 * pi rad/s.
+ loop.reset(VecBuilder.fill(0.0));
+ var references = VecBuilder.fill(3000 / 60d * 2 * Math.PI);
+
+ loop.setNextR(references);
+
+ var time = 0.0;
+ // Simulate 10 s of closed-loop operation; the loop condition alone bounds it.
+ while (time < 10) {
+ // Re-send the constant reference each step, then feed back a noisy synthetic
+ // measurement of the model's output.
+
+ loop.setNextR(references);
+
+ Matrix<N1, N1> y =
+ plant
+ .calculateY(loop.getXHat(), loop.getU())
+ .plus(VecBuilder.fill(random.nextGaussian() * kPositionStddev));
+
+ loop.correct(y);
+ loop.predict(kDt);
+ var u = loop.getU(0);
+
+ assertTrue(u >= -12.1 && u <= 12.1, "U out of bounds! Got " + u);
+
+ time += kDt;
+ }
+
+ assertEquals(0.0, loop.getError(0), 0.1);
+ }
+}
diff --git a/wpimath/src/test/java/edu/wpi/first/math/controller/PIDInputOutputTest.java b/wpimath/src/test/java/edu/wpi/first/math/controller/PIDInputOutputTest.java
new file mode 100644
index 0000000..1fe4cb1
--- /dev/null
+++ b/wpimath/src/test/java/edu/wpi/first/math/controller/PIDInputOutputTest.java
@@ -0,0 +1,58 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.controller;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+class PIDInputOutputTest {
+ private PIDController m_controller;
+
+ @BeforeEach
+ void setUp() {
+ m_controller = new PIDController(0, 0, 0);
+ }
+
+ @Test
+ void continuousInputTest() {
+ m_controller.setP(1);
+ m_controller.enableContinuousInput(-180, 180);
+ assertEquals(m_controller.calculate(-179, 179), -2, 1e-5);
+
+ m_controller.enableContinuousInput(0, 360);
+ assertEquals(m_controller.calculate(1, 359), -2, 1e-5);
+ }
+
+ @Test
+ void proportionalGainOutputTest() {
+ m_controller.setP(4);
+
+ assertEquals(-0.1, m_controller.calculate(0.025, 0), 1e-5);
+ }
+
+ @Test
+ void integralGainOutputTest() {
+ m_controller.setI(4);
+
+ double out = 0;
+
+ for (int i = 0; i < 5; i++) {
+ out = m_controller.calculate(0.025, 0);
+ }
+
+ assertEquals(-0.5 * m_controller.getPeriod(), out, 1e-5);
+ }
+
+ @Test
+ void derivativeGainOutputTest() {
+ m_controller.setD(4);
+
+ m_controller.calculate(0, 0);
+
+ assertEquals(-0.01 / m_controller.getPeriod(), m_controller.calculate(0.0025, 0), 1e-5);
+ }
+}
diff --git a/wpimath/src/test/java/edu/wpi/first/math/controller/PIDToleranceTest.java b/wpimath/src/test/java/edu/wpi/first/math/controller/PIDToleranceTest.java
new file mode 100644
index 0000000..b525f49
--- /dev/null
+++ b/wpimath/src/test/java/edu/wpi/first/math/controller/PIDToleranceTest.java
@@ -0,0 +1,66 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.controller;
+
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import org.junit.jupiter.api.Test;
+
+class PIDToleranceTest {
+ private static final double kSetpoint = 50.0;
+ private static final double kTolerance = 10.0;
+ private static final double kRange = 200;
+
+ @Test
+ void initialToleranceTest() {
+ try (var controller = new PIDController(0.05, 0.0, 0.0)) {
+ controller.enableContinuousInput(-kRange / 2, kRange / 2);
+ // Before any setpoint or measurement is supplied, atSetpoint() must be true.
+ assertTrue(controller.atSetpoint());
+ }
+ }
+
+ @Test
+ void absoluteToleranceTest() {
+ try (var controller = new PIDController(0.05, 0.0, 0.0)) {
+ controller.enableContinuousInput(-kRange / 2, kRange / 2);
+
+ assertTrue(
+ controller.atSetpoint(),
+ "Error was not in tolerance when it should have been. Error was "
+ + controller.getPositionError());
+
+ controller.setTolerance(kTolerance);
+ controller.setSetpoint(kSetpoint);
+ // After setSetpoint(50) with a +/-10 tolerance, atSetpoint() must be false.
+ assertFalse(
+ controller.atSetpoint(),
+ "Error was in tolerance when it should not have been. Error was "
+ + controller.getPositionError());
+
+ controller.calculate(0.0);
+ // Measurement 0 is 50 away from the setpoint: outside the band.
+ assertFalse(
+ controller.atSetpoint(),
+ "Error was in tolerance when it should not have been. Error was "
+ + controller.getPositionError());
+
+ controller.calculate(kSetpoint + kTolerance / 2);
+ // Measurement 55 is within 10 of the setpoint.
+ assertTrue(
+ controller.atSetpoint(),
+ "Error was not in tolerance when it should have been. Error was "
+ + controller.getPositionError());
+
+ controller.calculate(kSetpoint + 10 * kTolerance);
+ // Measurement 150 (= 50 + 10 * 10) must fall outside the tolerance band again.
+ assertFalse(
+ controller.atSetpoint(),
+ "Error was in tolerance when it should not have been. Error was "
+ + controller.getPositionError());
+ }
+ }
+}
diff --git a/wpimath/src/test/java/edu/wpi/first/math/controller/ProfiledPIDControllerTest.java b/wpimath/src/test/java/edu/wpi/first/math/controller/ProfiledPIDControllerTest.java
new file mode 100644
index 0000000..d87943a
--- /dev/null
+++ b/wpimath/src/test/java/edu/wpi/first/math/controller/ProfiledPIDControllerTest.java
@@ -0,0 +1,22 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.controller;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+import edu.wpi.first.math.trajectory.TrapezoidProfile;
+import org.junit.jupiter.api.Test;
+
+class ProfiledPIDControllerTest {
+ @Test
+ void testStartFromNonZeroPosition() {
+ ProfiledPIDController controller =
+ new ProfiledPIDController(1.0, 0.0, 0.0, new TrapezoidProfile.Constraints(1.0, 1.0));
+
+ controller.reset(20);
+
+ assertEquals(0.0, controller.calculate(20), 0.05);
+ }
+}
diff --git a/wpimath/src/test/java/edu/wpi/first/math/controller/ProfiledPIDInputOutputTest.java b/wpimath/src/test/java/edu/wpi/first/math/controller/ProfiledPIDInputOutputTest.java
new file mode 100644
index 0000000..e0a4945
--- /dev/null
+++ b/wpimath/src/test/java/edu/wpi/first/math/controller/ProfiledPIDInputOutputTest.java
@@ -0,0 +1,114 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.controller;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import edu.wpi.first.math.trajectory.TrapezoidProfile;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+class ProfiledPIDInputOutputTest {
+ private ProfiledPIDController m_controller;
+
+ @BeforeEach
+ void setUp() {
+ m_controller = new ProfiledPIDController(0, 0, 0, new TrapezoidProfile.Constraints(360, 180));
+ }
+
+ @Test
+ void continuousInputTest1() {
+ m_controller.setP(1);
+ m_controller.enableContinuousInput(-180, 180);
+ // Start/measurement -179 and goal 179 are 2 degrees apart across the wrap point.
+ final double kSetpoint = -179.0;
+ final double kMeasurement = -179.0;
+ final double kGoal = 179.0;
+ // Heading down through -180 is the short way, so the output must be negative.
+ m_controller.reset(kSetpoint);
+ assertTrue(m_controller.calculate(kMeasurement, kGoal) < 0.0);
+
+ // Error must be less than half the input range at all times
+ assertTrue(Math.abs(m_controller.getSetpoint().position - kMeasurement) < 180.0);
+ }
+
+ @Test
+ void continuousInputTest2() {
+ m_controller.setP(1);
+ m_controller.enableContinuousInput(-Math.PI, Math.PI);
+ // kSetpoint (about -3.48) lies below -pi, outside the continuous input range.
+ final double kSetpoint = -3.4826633343199735;
+ final double kMeasurement = -3.1352207333939606;
+ final double kGoal = -3.534162788601621;
+
+ m_controller.reset(kSetpoint);
+ assertTrue(m_controller.calculate(kMeasurement, kGoal) < 0.0);
+
+ // Error must be less than half the input range at all times
+ assertTrue(Math.abs(m_controller.getSetpoint().position - kMeasurement) < Math.PI);
+ }
+
+ @Test
+ void continuousInputTest3() {
+ m_controller.setP(1);
+ m_controller.enableContinuousInput(-Math.PI, Math.PI);
+ // kSetpoint lies below -pi while the measurement and goal are near +pi.
+ final double kSetpoint = -3.5176604690006377;
+ final double kMeasurement = 3.1191729343822456;
+ final double kGoal = 2.709680418117445;
+
+ m_controller.reset(kSetpoint);
+ assertTrue(m_controller.calculate(kMeasurement, kGoal) < 0.0);
+
+ // Error must be less than half the input range at all times
+ assertTrue(Math.abs(m_controller.getSetpoint().position - kMeasurement) < Math.PI);
+ }
+
+ @Test
+ void continuousInputTest4() {
+ m_controller.setP(1);
+ m_controller.enableContinuousInput(0, 2.0 * Math.PI);
+ // All values lie inside [0, 2 * pi); the goal is below the measurement.
+ final double kSetpoint = 2.78;
+ final double kMeasurement = 3.12;
+ final double kGoal = 2.71;
+
+ m_controller.reset(kSetpoint);
+ assertTrue(m_controller.calculate(kMeasurement, kGoal) < 0.0);
+
+ // Error must stay below half the input range; checked here with a tighter pi / 2 bound
+ assertTrue(Math.abs(m_controller.getSetpoint().position - kMeasurement) < Math.PI / 2.0);
+ }
+
+ @Test
+ void proportionalGainOutputTest() {
+ m_controller.setP(4);
+ // Error 0 - 0.025 = -0.025, times kP = 4, gives -0.1.
+ assertEquals(-0.1, m_controller.calculate(0.025, 0), 1e-5);
+ }
+
+ @Test
+ void integralGainOutputTest() {
+ m_controller.setI(4);
+
+ double out = 0;
+ // Five periods of error -0.025 sum to -0.125 * period; times kI = 4 is -0.5 * period.
+ for (int i = 0; i < 5; i++) {
+ out = m_controller.calculate(0.025, 0);
+ }
+
+ assertEquals(-0.5 * m_controller.getPeriod(), out, 1e-5);
+ }
+
+ @Test
+ void derivativeGainOutputTest() {
+ m_controller.setD(4);
+
+ m_controller.calculate(0, 0);
+ // Error changes by -0.0025 over one period; times kD = 4 gives -0.01 / period.
+ assertEquals(-0.01 / m_controller.getPeriod(), m_controller.calculate(0.0025, 0), 1e-5);
+ }
+}
diff --git a/wpimath/src/test/java/edu/wpi/first/math/controller/RamseteControllerTest.java b/wpimath/src/test/java/edu/wpi/first/math/controller/RamseteControllerTest.java
new file mode 100644
index 0000000..813bf42
--- /dev/null
+++ b/wpimath/src/test/java/edu/wpi/first/math/controller/RamseteControllerTest.java
@@ -0,0 +1,61 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.controller;
+
+import static org.junit.jupiter.api.Assertions.assertAll;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+import edu.wpi.first.math.MathUtil;
+import edu.wpi.first.math.geometry.Pose2d;
+import edu.wpi.first.math.geometry.Rotation2d;
+import edu.wpi.first.math.geometry.Twist2d;
+import edu.wpi.first.math.trajectory.TrajectoryConfig;
+import edu.wpi.first.math.trajectory.TrajectoryGenerator;
+import java.util.ArrayList;
+import org.junit.jupiter.api.Test;
+
+class RamseteControllerTest {
+ // Final-pose position tolerance, and heading tolerance in radians.
+ private static final double kTolerance = 1 / 12.0;
+ private static final double kAngularTolerance = Math.toRadians(2);
+
+ /** Integrates RAMSETE's commanded motion along a trajectory and checks it ends at the goal. */
+ @Test
+ void testReachesReference() {
+ final var ramsete = new RamseteController(2.0, 0.7);
+ var pose = new Pose2d(2.7, 23.0, Rotation2d.fromDegrees(0.0));
+
+ final var points = new ArrayList<Pose2d>();
+ points.add(new Pose2d(2.75, 22.521, new Rotation2d(0)));
+ points.add(new Pose2d(24.73, 19.68, new Rotation2d(5.846)));
+ final var trajectory =
+ TrajectoryGenerator.generateTrajectory(points, new TrajectoryConfig(8.8, 0.1));
+
+ final double kDt = 0.02;
+ final double duration = trajectory.getTotalTimeSeconds();
+ for (int step = 0; step < (duration / kDt); ++step) {
+ final var chassisSpeeds = ramsete.calculate(pose, trajectory.sample(kDt * step));
+ // Integrate the commanded forward and angular motion over one timestep.
+ final double dx = chassisSpeeds.vxMetersPerSecond * kDt;
+ final double dtheta = chassisSpeeds.omegaRadiansPerSecond * kDt;
+ pose = pose.exp(new Twist2d(dx, 0, dtheta));
+ }
+
+ final var states = trajectory.getStates();
+ final Pose2d goal = states.get(states.size() - 1).poseMeters;
+
+ // Lambdas may only capture effectively-final locals, so snapshot the final pose.
+ final Pose2d actual = pose;
+ assertAll(
+ () -> assertEquals(goal.getX(), actual.getX(), kTolerance),
+ () -> assertEquals(goal.getY(), actual.getY(), kTolerance),
+ () -> {
+ final double headingError =
+ MathUtil.angleModulus(
+ goal.getRotation().getRadians() - actual.getRotation().getRadians());
+ assertEquals(0.0, headingError, kAngularTolerance);
+ });
+ }
+}
diff --git a/wpimath/src/test/java/edu/wpi/first/math/controller/SimpleMotorFeedforwardTest.java b/wpimath/src/test/java/edu/wpi/first/math/controller/SimpleMotorFeedforwardTest.java
new file mode 100644
index 0000000..83fced2
--- /dev/null
+++ b/wpimath/src/test/java/edu/wpi/first/math/controller/SimpleMotorFeedforwardTest.java
@@ -0,0 +1,46 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.controller;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+import edu.wpi.first.math.Matrix;
+import edu.wpi.first.math.Nat;
+import edu.wpi.first.math.VecBuilder;
+import edu.wpi.first.math.numbers.N1;
+import org.junit.jupiter.api.Test;
+
+class SimpleMotorFeedforwardTest {
+ @SuppressWarnings("LocalVariableName")
+ @Test
+ void testCalculate() {
+ double Ks = 0.5;
+ double Kv = 3.0;
+ double Ka = 0.6;
+ double dt = 0.02;
+
+ // Equivalent linear plant: dx/dt = (-Kv / Ka) x + (1 / Ka) u.
+ // Inverting it gives a feedforward that SimpleMotorFeedforward should agree with.
+ var A = Matrix.mat(Nat.N1(), Nat.N1()).fill(-Kv / Ka);
+ var B = Matrix.mat(Nat.N1(), Nat.N1()).fill(1.0 / Ka);
+
+ var plantInversion = new LinearPlantInversionFeedforward<N1, N1, N1>(A, B, dt);
+ var simpleMotor = new SimpleMotorFeedforward(Ks, Kv, Ka);
+
+ var r = VecBuilder.fill(2.0);
+ var nextR = VecBuilder.fill(3.0);
+
+ // The plant model has no static-friction term, so Ks is added to its output.
+ double expectedVoltage = plantInversion.calculate(r, nextR).get(0, 0) + Ks;
+
+ // Hard-coded reference for calculate(2.0, 3.0, dt), excluding the Ks term.
+ assertEquals(37.524995834325161 + Ks, simpleMotor.calculate(2.0, 3.0, dt), 0.002);
+ assertEquals(expectedVoltage, simpleMotor.calculate(2.0, 3.0, dt), 0.002);
+
+ // These won't match exactly. It's just an approximation to make sure they're
+ // in the same ballpark.
+ assertEquals(expectedVoltage, simpleMotor.calculate(2.0, 1.0 / dt), 2.0);
+ }
+}
diff --git a/wpimath/src/test/java/edu/wpi/first/math/estimator/AngleStatisticsTest.java b/wpimath/src/test/java/edu/wpi/first/math/estimator/AngleStatisticsTest.java
new file mode 100644
index 0000000..9fcf5e3
--- /dev/null
+++ b/wpimath/src/test/java/edu/wpi/first/math/estimator/AngleStatisticsTest.java
@@ -0,0 +1,44 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.estimator;
+
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import edu.wpi.first.math.Matrix;
+import edu.wpi.first.math.Nat;
+import edu.wpi.first.math.VecBuilder;
+import org.junit.jupiter.api.Test;
+
+public class AngleStatisticsTest {
+ /** Row 1 holds angles; averaging 359 deg and 3 deg must wrap to 1 deg. */
+ @Test
+ public void testMean() {
+ // Three states (rows) by two sigma points (columns).
+ var sigmaPoints =
+ Matrix.mat(Nat.N3(), Nat.N2()).fill(1, 1.2, Math.toRadians(359), Math.toRadians(3), 1, 2);
+ // Equal weights, so the weighted mean is the plain average.
+ var weights = new Matrix<>(Nat.N2(), Nat.N1());
+ weights.fill(1.0 / sigmaPoints.getNumCols());
+ var expectedMean = VecBuilder.fill(1.1, Math.toRadians(1), 1.5);
+ assertTrue(AngleStatistics.angleMean(sigmaPoints, weights, 1).isEqual(expectedMean, 1e-6));
+ }
+
+ /** The angle residual 1 deg - 359 deg must wrap to +2 deg. */
+ @Test
+ public void testResidual() {
+ var a = VecBuilder.fill(1, Math.toRadians(1), 2);
+ var b = VecBuilder.fill(1, Math.toRadians(359), 1);
+ var expectedResidual = VecBuilder.fill(0, Math.toRadians(2), 1);
+ assertTrue(AngleStatistics.angleResidual(a, b, 1).isEqual(expectedResidual, 1e-6));
+ }
+
+ /** The angle sum 1 deg + 359 deg must wrap to 0. */
+ @Test
+ public void testAdd() {
+ var a = VecBuilder.fill(1, Math.toRadians(1), 2);
+ var b = VecBuilder.fill(1, Math.toRadians(359), 1);
+ assertTrue(AngleStatistics.angleAdd(a, b, 1).isEqual(VecBuilder.fill(2, 0, 3), 1e-6));
+ }
+}
diff --git a/wpimath/src/test/java/edu/wpi/first/math/estimator/DifferentialDrivePoseEstimatorTest.java b/wpimath/src/test/java/edu/wpi/first/math/estimator/DifferentialDrivePoseEstimatorTest.java
new file mode 100644
index 0000000..6e4b261
--- /dev/null
+++ b/wpimath/src/test/java/edu/wpi/first/math/estimator/DifferentialDrivePoseEstimatorTest.java
@@ -0,0 +1,126 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.estimator;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+import edu.wpi.first.math.MatBuilder;
+import edu.wpi.first.math.Nat;
+import edu.wpi.first.math.geometry.Pose2d;
+import edu.wpi.first.math.geometry.Rotation2d;
+import edu.wpi.first.math.geometry.Translation2d;
+import edu.wpi.first.math.kinematics.ChassisSpeeds;
+import edu.wpi.first.math.kinematics.DifferentialDriveKinematics;
+import edu.wpi.first.math.kinematics.DifferentialDriveWheelSpeeds;
+import edu.wpi.first.math.trajectory.Trajectory;
+import edu.wpi.first.math.trajectory.TrajectoryConfig;
+import edu.wpi.first.math.trajectory.TrajectoryGenerator;
+import java.util.List;
+import java.util.Random;
+import org.junit.jupiter.api.Test;
+
+public class DifferentialDrivePoseEstimatorTest {
+ /**
+ * Follows a looping trajectory with noisy wheel speeds, a noisy heading, and delayed noisy
+ * vision poses, and checks that the estimator's translation error stays within bounds.
+ */
+ @SuppressWarnings("LocalVariableName")
+ @Test
+ public void testAccuracy() {
+ var estimator =
+ new DifferentialDrivePoseEstimator(
+ new Rotation2d(),
+ new Pose2d(),
+ new MatBuilder<>(Nat.N5(), Nat.N1()).fill(0.02, 0.02, 0.01, 0.02, 0.02),
+ new MatBuilder<>(Nat.N3(), Nat.N1()).fill(0.01, 0.01, 0.001),
+ new MatBuilder<>(Nat.N3(), Nat.N1()).fill(0.1, 0.1, 0.01));
+
+ var traj =
+ TrajectoryGenerator.generateTrajectory(
+ List.of(
+ new Pose2d(0, 0, Rotation2d.fromDegrees(45)),
+ new Pose2d(3, 0, Rotation2d.fromDegrees(-90)),
+ new Pose2d(0, 0, Rotation2d.fromDegrees(135)),
+ new Pose2d(-3, 0, Rotation2d.fromDegrees(-90)),
+ new Pose2d(0, 0, Rotation2d.fromDegrees(45))),
+ new TrajectoryConfig(10, 5));
+
+ var kinematics = new DifferentialDriveKinematics(1);
+ // Fixed seed keeps the simulated noise, and therefore the result, reproducible.
+ var rand = new Random(4915);
+
+ final double dt = 0.02;
+ double t = 0.0;
+
+ final double visionUpdateRate = 0.1;
+ Pose2d lastVisionPose = null;
+ double lastVisionUpdateTime = Double.NEGATIVE_INFINITY;
+
+ double distanceLeft = 0.0;
+ double distanceRight = 0.0;
+
+ double maxError = Double.NEGATIVE_INFINITY;
+ double errorSum = 0;
+ // Number of loop iterations actually executed; used for the mean error below.
+ int sampleCount = 0;
+ while (t <= traj.getTotalTimeSeconds()) {
+ Trajectory.State groundtruthState = traj.sample(t);
+ DifferentialDriveWheelSpeeds input =
+ kinematics.toWheelSpeeds(
+ new ChassisSpeeds(
+ groundtruthState.velocityMetersPerSecond,
+ 0.0,
+ // ds/dt * dtheta/ds = dtheta/dt
+ groundtruthState.velocityMetersPerSecond
+ * groundtruthState.curvatureRadPerMeter));
+
+ // At a jittered interval, submit the previously captured pose with its original
+ // timestamp (i.e. delayed), then capture a new noisy pose.
+ if (lastVisionUpdateTime + visionUpdateRate + rand.nextGaussian() * 0.4 < t) {
+ if (lastVisionPose != null) {
+ estimator.addVisionMeasurement(lastVisionPose, lastVisionUpdateTime);
+ }
+ var groundPose = groundtruthState.poseMeters;
+ lastVisionPose =
+ new Pose2d(
+ new Translation2d(
+ groundPose.getTranslation().getX() + rand.nextGaussian() * 0.1,
+ groundPose.getTranslation().getY() + rand.nextGaussian() * 0.1),
+ new Rotation2d(rand.nextGaussian() * 0.01).plus(groundPose.getRotation()));
+ lastVisionUpdateTime = t;
+ }
+
+ input.leftMetersPerSecond += rand.nextGaussian() * 0.01;
+ input.rightMetersPerSecond += rand.nextGaussian() * 0.01;
+
+ distanceLeft += input.leftMetersPerSecond * dt;
+ distanceRight += input.rightMetersPerSecond * dt;
+
+ var rotNoise = new Rotation2d(rand.nextGaussian() * 0.001);
+ var xHat =
+ estimator.updateWithTime(
+ t,
+ groundtruthState.poseMeters.getRotation().plus(rotNoise),
+ input,
+ distanceLeft,
+ distanceRight);
+
+ double error =
+ groundtruthState.poseMeters.getTranslation().getDistance(xHat.getTranslation());
+ if (error > maxError) {
+ maxError = error;
+ }
+ errorSum += error;
+ sampleCount++;
+
+ t += dt;
+ }
+
+ // Divide by the real sample count rather than the duration / dt estimate, which is off by
+ // roughly one sample.
+ assertEquals(0.0, errorSum / sampleCount, 0.035, "Incorrect mean error");
+ assertEquals(0.0, maxError, 0.055, "Incorrect max error");
+ }
+}
diff --git a/wpimath/src/test/java/edu/wpi/first/math/estimator/ExtendedKalmanFilterTest.java b/wpimath/src/test/java/edu/wpi/first/math/estimator/ExtendedKalmanFilterTest.java
new file mode 100644
index 0000000..18b095b
--- /dev/null
+++ b/wpimath/src/test/java/edu/wpi/first/math/estimator/ExtendedKalmanFilterTest.java
@@ -0,0 +1,189 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.estimator;
+
+import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+import edu.wpi.first.math.Matrix;
+import edu.wpi.first.math.Nat;
+import edu.wpi.first.math.StateSpaceUtil;
+import edu.wpi.first.math.VecBuilder;
+import edu.wpi.first.math.geometry.Pose2d;
+import edu.wpi.first.math.geometry.Rotation2d;
+import edu.wpi.first.math.numbers.N1;
+import edu.wpi.first.math.numbers.N2;
+import edu.wpi.first.math.numbers.N3;
+import edu.wpi.first.math.numbers.N5;
+import edu.wpi.first.math.system.NumericalIntegration;
+import edu.wpi.first.math.system.NumericalJacobian;
+import edu.wpi.first.math.system.plant.DCMotor;
+import edu.wpi.first.math.trajectory.TrajectoryConfig;
+import edu.wpi.first.math.trajectory.TrajectoryGenerator;
+import java.util.Arrays;
+import java.util.List;
+import org.junit.jupiter.api.Test;
+
+@SuppressWarnings("ParameterName")
+public class ExtendedKalmanFilterTest {
+ public static Matrix<N5, N1> getDynamics(final Matrix<N5, N1> x, final Matrix<N2, N1> u) {
+ final var motors = DCMotor.getCIM(2);
+ // Drivetrain physical constants.
+ final var gr = 7.08; // Gear ratio
+ final var rb = 0.8382 / 2.0; // Wheelbase radius (half the track width)
+ final var r = 0.0746125; // Wheel radius
+ final var m = 63.503; // Robot mass
+ final var J = 5.6; // Robot moment of inertia
+ // C1 (velocity) and C2 (voltage) gains come from the motor model; k1/k2 mix sides via m, J.
+ final var C1 =
+ -Math.pow(gr, 2) * motors.KtNMPerAmp / (motors.KvRadPerSecPerVolt * motors.rOhms * r * r);
+ final var C2 = gr * motors.KtNMPerAmp / (motors.rOhms * r);
+ final var k1 = 1.0 / m + rb * rb / J;
+ final var k2 = 1.0 / m - rb * rb / J;
+ // x = [x, y, heading, vl, vr]ᵀ and u = [Vl, Vr]ᵀ; the result is dx/dt.
+ final var vl = x.get(3, 0);
+ final var vr = x.get(4, 0);
+ final var Vl = u.get(0, 0);
+ final var Vr = u.get(1, 0);
+
+ final Matrix<N5, N1> result = new Matrix<>(Nat.N5(), Nat.N1());
+ final var v = 0.5 * (vl + vr);
+ result.set(0, 0, v * Math.cos(x.get(2, 0)));
+ result.set(1, 0, v * Math.sin(x.get(2, 0)));
+ result.set(2, 0, (vr - vl) / (2.0 * rb));
+ result.set(3, 0, k1 * ((C1 * vl) + (C2 * Vl)) + k2 * ((C1 * vr) + (C2 * Vr)));
+ result.set(4, 0, k2 * ((C1 * vl) + (C2 * Vl)) + k1 * ((C1 * vr) + (C2 * Vr)));
+ return result;
+ }
+
+ public static Matrix<N3, N1> getLocalMeasurementModel(Matrix<N5, N1> x, Matrix<N2, N1> u) {
+ return VecBuilder.fill(x.get(2, 0), x.get(3, 0), x.get(4, 0)); // [heading, vl, vr]
+ }
+
+ public static Matrix<N5, N1> getGlobalMeasurementModel(Matrix<N5, N1> x, Matrix<N2, N1> u) {
+ return VecBuilder.fill(x.get(0, 0), x.get(1, 0), x.get(2, 0), x.get(3, 0), x.get(4, 0)); // x
+ }
+
+ @SuppressWarnings("LocalVariableName")
+ @Test
+ public void testInit() {
+ double dtSeconds = 0.00505;
+ // Smoke test: predict, correct with the local model, then with a 5-output global model.
+ assertDoesNotThrow(
+ () -> {
+ ExtendedKalmanFilter<N5, N2, N3> observer =
+ new ExtendedKalmanFilter<>(
+ Nat.N5(),
+ Nat.N2(),
+ Nat.N3(),
+ ExtendedKalmanFilterTest::getDynamics,
+ ExtendedKalmanFilterTest::getLocalMeasurementModel,
+ VecBuilder.fill(0.5, 0.5, 10.0, 1.0, 1.0),
+ VecBuilder.fill(0.0001, 0.01, 0.01),
+ dtSeconds);
+
+ Matrix<N2, N1> u = VecBuilder.fill(12.0, 12.0);
+ observer.predict(u, dtSeconds);
+
+ var localY = getLocalMeasurementModel(observer.getXhat(), u);
+ observer.correct(u, localY);
+
+ var globalY = getGlobalMeasurementModel(observer.getXhat(), u);
+ var R = StateSpaceUtil.makeCostMatrix(VecBuilder.fill(0.01, 0.01, 0.0001, 0.5, 0.5));
+ observer.correct(
+ Nat.N5(), u, globalY, ExtendedKalmanFilterTest::getGlobalMeasurementModel, R);
+ });
+ }
+
+ @SuppressWarnings("LocalVariableName")
+ @Test
+ public void testConvergence() {
+ double dtSeconds = 0.00505;
+ double rbMeters = 0.8382 / 2.0; // Half the track width
+
+ ExtendedKalmanFilter<N5, N2, N3> observer =
+ new ExtendedKalmanFilter<>(
+ Nat.N5(),
+ Nat.N2(),
+ Nat.N3(),
+ ExtendedKalmanFilterTest::getDynamics,
+ ExtendedKalmanFilterTest::getLocalMeasurementModel,
+ VecBuilder.fill(0.5, 0.5, 10.0, 1.0, 1.0),
+ VecBuilder.fill(0.001, 0.01, 0.01),
+ dtSeconds);
+ // NOTE(review): RamseteControllerTest uses new Rotation2d(5.846) (radians) here; confirm.
+ List<Pose2d> waypoints =
+ Arrays.asList(
+ new Pose2d(2.75, 22.521, new Rotation2d()),
+ new Pose2d(24.73, 19.68, Rotation2d.fromDegrees(5.846)));
+ var trajectory =
+ TrajectoryGenerator.generateTrajectory(waypoints, new TrajectoryConfig(8.8, 0.1));
+ // r is the previous reference state and nextR the current one; u is the applied input.
+ Matrix<N5, N1> r = new Matrix<>(Nat.N5(), Nat.N1());
+
+ Matrix<N5, N1> nextR = new Matrix<>(Nat.N5(), Nat.N1());
+ Matrix<N2, N1> u = new Matrix<>(Nat.N2(), Nat.N1());
+ // Input Jacobian df/du, linearized once at the zero state.
+ var B =
+ NumericalJacobian.numericalJacobianU(
+ Nat.N5(),
+ Nat.N2(),
+ ExtendedKalmanFilterTest::getDynamics,
+ new Matrix<>(Nat.N5(), Nat.N1()),
+ u);
+
+ observer.setXhat(
+ VecBuilder.fill(
+ trajectory.getInitialPose().getTranslation().getX(),
+ trajectory.getInitialPose().getTranslation().getY(),
+ trajectory.getInitialPose().getRotation().getRadians(),
+ 0.0,
+ 0.0));
+ // Ground truth is propagated with RK4 and observed only through noisy local measurements.
+ var groundTruthX = observer.getXhat();
+
+ double totalTime = trajectory.getTotalTimeSeconds();
+ for (int i = 0; i < (totalTime / dtSeconds); i++) {
+ var ref = trajectory.sample(dtSeconds * i);
+ double vl = ref.velocityMetersPerSecond * (1 - (ref.curvatureRadPerMeter * rbMeters));
+ double vr = ref.velocityMetersPerSecond * (1 + (ref.curvatureRadPerMeter * rbMeters));
+ // Reference state: trajectory pose plus left/right wheel speeds from its curvature.
+ nextR.set(0, 0, ref.poseMeters.getTranslation().getX());
+ nextR.set(1, 0, ref.poseMeters.getTranslation().getY());
+ nextR.set(2, 0, ref.poseMeters.getRotation().getRadians());
+ nextR.set(3, 0, vl);
+ nextR.set(4, 0, vr);
+
+ var localY = getLocalMeasurementModel(groundTruthX, u);
+ var whiteNoiseStdDevs = VecBuilder.fill(0.0001, 0.5, 0.5);
+ observer.correct(u, localY.plus(StateSpaceUtil.makeWhiteNoiseVector(whiteNoiseStdDevs)));
+ // Feedforward input: solve B u = (nextR - r) / dt - f(r, 0) for u.
+ Matrix<N5, N1> rdot = nextR.minus(r).div(dtSeconds);
+ u = new Matrix<>(B.solve(rdot.minus(getDynamics(r, new Matrix<>(Nat.N2(), Nat.N1())))));
+
+ observer.predict(u, dtSeconds);
+
+ groundTruthX =
+ NumericalIntegration.rk4(
+ ExtendedKalmanFilterTest::getDynamics, groundTruthX, u, dtSeconds);
+
+ r = nextR;
+ }
+ // Finish with one local and one global correction before checking the estimate.
+ var localY = getLocalMeasurementModel(observer.getXhat(), u);
+ observer.correct(u, localY);
+
+ var globalY = getGlobalMeasurementModel(observer.getXhat(), u);
+ var R = StateSpaceUtil.makeCostMatrix(VecBuilder.fill(0.01, 0.01, 0.0001, 0.5, 0.5));
+ observer.correct(Nat.N5(), u, globalY, ExtendedKalmanFilterTest::getGlobalMeasurementModel, R);
+ // The estimate should end near the final trajectory pose with wheel speeds near zero.
+ var finalPosition = trajectory.sample(trajectory.getTotalTimeSeconds());
+ assertEquals(finalPosition.poseMeters.getTranslation().getX(), observer.getXhat(0), 1.0);
+ assertEquals(finalPosition.poseMeters.getTranslation().getY(), observer.getXhat(1), 1.0);
+ assertEquals(finalPosition.poseMeters.getRotation().getRadians(), observer.getXhat(2), 1.0);
+ assertEquals(0.0, observer.getXhat(3), 1.0);
+ assertEquals(0.0, observer.getXhat(4), 1.0);
+ }
+}
diff --git a/wpimath/src/test/java/edu/wpi/first/math/estimator/KalmanFilterTest.java b/wpimath/src/test/java/edu/wpi/first/math/estimator/KalmanFilterTest.java
new file mode 100644
index 0000000..0a18497
--- /dev/null
+++ b/wpimath/src/test/java/edu/wpi/first/math/estimator/KalmanFilterTest.java
@@ -0,0 +1,191 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.estimator;
+
+import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+import edu.wpi.first.math.Matrix;
+import edu.wpi.first.math.Nat;
+import edu.wpi.first.math.VecBuilder;
+import edu.wpi.first.math.geometry.Pose2d;
+import edu.wpi.first.math.geometry.Rotation2d;
+import edu.wpi.first.math.numbers.N1;
+import edu.wpi.first.math.numbers.N2;
+import edu.wpi.first.math.numbers.N3;
+import edu.wpi.first.math.numbers.N6;
+import edu.wpi.first.math.system.LinearSystem;
+import edu.wpi.first.math.system.plant.DCMotor;
+import edu.wpi.first.math.system.plant.LinearSystemId;
+import edu.wpi.first.math.trajectory.TrajectoryConfig;
+import edu.wpi.first.math.trajectory.TrajectoryGenerator;
+import java.util.List;
+import java.util.Random;
+import org.junit.jupiter.api.Test;
+
+public class KalmanFilterTest {
+ private static LinearSystem<N2, N1, N1> elevatorPlant;
+
+ private static final double kDt = 0.00505;
+
+ static {
+ createElevator();
+ }
+
+ @SuppressWarnings("LocalVariableName")
+ public static void createElevator() {
+ var motors = DCMotor.getVex775Pro(2);
+
+ var m = 5.0;
+ var r = 0.0181864;
+ var G = 1.0;
+ elevatorPlant = LinearSystemId.createElevatorSystem(motors, m, r, G);
+ }
+
+ // A swerve drive system where the states are [x, y, theta, vx, vy, vTheta]ᵀ,
+ // Y is [x, y, theta]ᵀ and u is [ax, ay, alpha]ᵀ
+ LinearSystem<N6, N3, N3> m_swerveObserverSystem =
+ new LinearSystem<>(
+ Matrix.mat(Nat.N6(), Nat.N6())
+ .fill( // A
+ 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0),
+ Matrix.mat(Nat.N6(), Nat.N3())
+ .fill( // B
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1),
+ Matrix.mat(Nat.N3(), Nat.N6())
+ .fill( // C
+ 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0),
+ new Matrix<>(Nat.N3(), Nat.N3())); // D
+
+ @Test
+ @SuppressWarnings("LocalVariableName")
+ public void testElevatorKalmanFilter() {
+ var Q = VecBuilder.fill(0.05, 1.0);
+ var R = VecBuilder.fill(0.0001);
+
+ assertDoesNotThrow(() -> new KalmanFilter<>(Nat.N2(), Nat.N1(), elevatorPlant, Q, R, kDt));
+ }
+
+ @Test
+ public void testSwerveKFStationary() {
+ var random = new Random();
+
+ var filter =
+ new KalmanFilter<>(
+ Nat.N6(),
+ Nat.N3(),
+ m_swerveObserverSystem,
+ VecBuilder.fill(0.1, 0.1, 0.1, 0.1, 0.1, 0.1), // state
+ // weights
+ VecBuilder.fill(2, 2, 2), // measurement weights
+ 0.020);
+ // Measurements are zero-mean, unit-variance noise around the origin.
+ Matrix<N3, N1> measurement;
+ for (int i = 0; i < 100; i++) {
+ // the robot is at [0, 0, 0] so we just park here
+ measurement =
+ VecBuilder.fill(random.nextGaussian(), random.nextGaussian(), random.nextGaussian());
+ filter.correct(VecBuilder.fill(0.0, 0.0, 0.0), measurement);
+
+ // we continue to not accelerate
+ filter.predict(VecBuilder.fill(0.0, 0.0, 0.0), 0.020);
+ }
+
+ assertEquals(0.0, filter.getXhat(0), 0.3);
+ assertEquals(0.0, filter.getXhat(1), 0.3);
+ }
+
+ @Test
+ public void testSwerveKFMovingWithoutAccelerating() {
+ var random = new Random();
+
+ var filter =
+ new KalmanFilter<>(
+ Nat.N6(),
+ Nat.N3(),
+ m_swerveObserverSystem,
+ VecBuilder.fill(0.1, 0.1, 0.1, 0.1, 0.1, 0.1), // state
+ // weights
+ VecBuilder.fill(4, 4, 4), // measurement weights
+ 0.020);
+
+ // start the x/y position estimate (states 0 and 1) away from the origin
+ filter.setXhat(0, 0.5);
+ filter.setXhat(1, 0.5);
+
+ for (int i = 0; i < 300; i++) {
+ // the robot is actually parked at [0, 0, 0]; only measurement noise is fed in
+ var measurement =
+ VecBuilder.fill(
+ random.nextGaussian() / 10d,
+ random.nextGaussian() / 10d,
+ random.nextGaussian() / 4d // std dev of [0.1, 0.1, 0.25]
+ );
+
+ filter.correct(VecBuilder.fill(0.0, 0.0, 0.0), measurement);
+
+ // we continue to not accelerate
+ filter.predict(VecBuilder.fill(0.0, 0.0, 0.0), 0.020);
+ }
+
+ assertEquals(0.0, filter.getXhat(0), 0.2);
+ assertEquals(0.0, filter.getXhat(1), 0.2);
+ }
+
+ @Test
+ @SuppressWarnings("LocalVariableName")
+ public void testSwerveKFMovingOverTrajectory() {
+ var random = new Random();
+
+ var filter =
+ new KalmanFilter<>(
+ Nat.N6(),
+ Nat.N3(),
+ m_swerveObserverSystem,
+ VecBuilder.fill(0.1, 0.1, 0.1, 0.1, 0.1, 0.1), // state
+ // weights
+ VecBuilder.fill(4, 4, 4), // measurement weights
+ 0.020);
+
+ var trajectory =
+ TrajectoryGenerator.generateTrajectory(
+ List.of(new Pose2d(0, 0, new Rotation2d()), new Pose2d(5, 5, new Rotation2d())),
+ new TrajectoryConfig(2, 2));
+ var time = 0.0;
+ var lastVelocity = VecBuilder.fill(0.0, 0.0, 0.0);
+
+ while (time <= trajectory.getTotalTimeSeconds()) {
+ var sample = trajectory.sample(time);
+ var measurement =
+ VecBuilder.fill(
+ sample.poseMeters.getTranslation().getX() + random.nextGaussian() / 5d,
+ sample.poseMeters.getTranslation().getY() + random.nextGaussian() / 5d,
+ sample.poseMeters.getRotation().getRadians() + random.nextGaussian() / 3d);
+
+ var velocity =
+ VecBuilder.fill(
+ sample.velocityMetersPerSecond * sample.poseMeters.getRotation().getCos(),
+ sample.velocityMetersPerSecond * sample.poseMeters.getRotation().getSin(),
+ sample.curvatureRadPerMeter * sample.velocityMetersPerSecond);
+ var u = (velocity.minus(lastVelocity)).div(0.020);
+ lastVelocity = velocity;
+
+ filter.correct(u, measurement);
+ filter.predict(u, 0.020);
+
+ time += 0.020;
+ }
+
+ assertEquals(
+ trajectory.sample(trajectory.getTotalTimeSeconds()).poseMeters.getTranslation().getX(),
+ filter.getXhat(0),
+ 0.2);
+ assertEquals(
+ trajectory.sample(trajectory.getTotalTimeSeconds()).poseMeters.getTranslation().getY(),
+ filter.getXhat(1),
+ 0.2);
+ }
+}
diff --git a/wpimath/src/test/java/edu/wpi/first/math/estimator/MecanumDrivePoseEstimatorTest.java b/wpimath/src/test/java/edu/wpi/first/math/estimator/MecanumDrivePoseEstimatorTest.java
new file mode 100644
index 0000000..38b0d20
--- /dev/null
+++ b/wpimath/src/test/java/edu/wpi/first/math/estimator/MecanumDrivePoseEstimatorTest.java
@@ -0,0 +1,128 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.estimator;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+import edu.wpi.first.math.VecBuilder;
+import edu.wpi.first.math.geometry.Pose2d;
+import edu.wpi.first.math.geometry.Rotation2d;
+import edu.wpi.first.math.geometry.Translation2d;
+import edu.wpi.first.math.kinematics.ChassisSpeeds;
+import edu.wpi.first.math.kinematics.MecanumDriveKinematics;
+import edu.wpi.first.math.trajectory.TrajectoryConfig;
+import edu.wpi.first.math.trajectory.TrajectoryGenerator;
+import java.util.List;
+import java.util.Random;
+import org.junit.jupiter.api.Test;
+
+public class MecanumDrivePoseEstimatorTest {
+ /**
+ * Follows a trajectory with noisy wheel speeds, a noisy heading, and delayed noisy vision
+ * poses, and checks that the estimator's translation error stays within mean and max bounds.
+ */
+ @Test
+ @SuppressWarnings("LocalVariableName")
+ public void testAccuracy() {
+ var kinematics =
+ new MecanumDriveKinematics(
+ new Translation2d(1, 1), new Translation2d(1, -1),
+ new Translation2d(-1, -1), new Translation2d(-1, 1));
+
+ var estimator =
+ new MecanumDrivePoseEstimator(
+ new Rotation2d(),
+ new Pose2d(),
+ kinematics,
+ VecBuilder.fill(0.1, 0.1, 0.1),
+ VecBuilder.fill(0.05),
+ VecBuilder.fill(0.1, 0.1, 0.1));
+
+ var trajectory =
+ TrajectoryGenerator.generateTrajectory(
+ List.of(
+ new Pose2d(),
+ new Pose2d(20, 20, Rotation2d.fromDegrees(0)),
+ new Pose2d(10, 10, Rotation2d.fromDegrees(180)),
+ new Pose2d(30, 30, Rotation2d.fromDegrees(0)),
+ new Pose2d(20, 20, Rotation2d.fromDegrees(180)),
+ new Pose2d(10, 10, Rotation2d.fromDegrees(0))),
+ new TrajectoryConfig(0.5, 2));
+
+ // Fixed seed keeps the simulated noise, and therefore the result, reproducible.
+ var rand = new Random(5190);
+
+ final double dt = 0.02;
+ double t = 0.0;
+
+ final double visionUpdateRate = 0.1;
+ Pose2d lastVisionPose = null;
+ double lastVisionUpdateTime = Double.NEGATIVE_INFINITY;
+
+ double maxError = Double.NEGATIVE_INFINITY;
+ double errorSum = 0;
+ // Number of loop iterations actually executed; used for the mean error below.
+ int sampleCount = 0;
+ while (t <= trajectory.getTotalTimeSeconds()) {
+ var groundTruthState = trajectory.sample(t);
+
+ // Once more than visionUpdateRate seconds have passed, submit the previously captured
+ // pose with its original timestamp (i.e. delayed), then capture a new noisy pose.
+ if (lastVisionUpdateTime + visionUpdateRate < t) {
+ if (lastVisionPose != null) {
+ estimator.addVisionMeasurement(lastVisionPose, lastVisionUpdateTime);
+ }
+
+ lastVisionPose =
+ new Pose2d(
+ new Translation2d(
+ groundTruthState.poseMeters.getTranslation().getX() + rand.nextGaussian() * 0.1,
+ groundTruthState.poseMeters.getTranslation().getY()
+ + rand.nextGaussian() * 0.1),
+ new Rotation2d(rand.nextGaussian() * 0.1)
+ .plus(groundTruthState.poseMeters.getRotation()));
+
+ lastVisionUpdateTime = t;
+ }
+
+ var wheelSpeeds =
+ kinematics.toWheelSpeeds(
+ new ChassisSpeeds(
+ groundTruthState.velocityMetersPerSecond,
+ 0,
+ groundTruthState.velocityMetersPerSecond
+ * groundTruthState.curvatureRadPerMeter));
+
+ wheelSpeeds.frontLeftMetersPerSecond += rand.nextGaussian() * 0.1;
+ wheelSpeeds.frontRightMetersPerSecond += rand.nextGaussian() * 0.1;
+ wheelSpeeds.rearLeftMetersPerSecond += rand.nextGaussian() * 0.1;
+ wheelSpeeds.rearRightMetersPerSecond += rand.nextGaussian() * 0.1;
+
+ var xHat =
+ estimator.updateWithTime(
+ t,
+ groundTruthState
+ .poseMeters
+ .getRotation()
+ .plus(new Rotation2d(rand.nextGaussian() * 0.05)),
+ wheelSpeeds);
+
+ double error =
+ groundTruthState.poseMeters.getTranslation().getDistance(xHat.getTranslation());
+ if (error > maxError) {
+ maxError = error;
+ }
+ errorSum += error;
+ sampleCount++;
+
+ t += dt;
+ }
+
+ // Divide by the real sample count rather than the duration / dt estimate, which is off by
+ // roughly one sample.
+ assertEquals(0.0, errorSum / sampleCount, 0.25, "Incorrect mean error");
+ assertEquals(0.0, maxError, 0.42, "Incorrect max error");
+ }
+}
diff --git a/wpimath/src/test/java/edu/wpi/first/math/estimator/MerweScaledSigmaPointsTest.java b/wpimath/src/test/java/edu/wpi/first/math/estimator/MerweScaledSigmaPointsTest.java
new file mode 100644
index 0000000..b6e32fc
--- /dev/null
+++ b/wpimath/src/test/java/edu/wpi/first/math/estimator/MerweScaledSigmaPointsTest.java
@@ -0,0 +1,43 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.estimator;
+
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import edu.wpi.first.math.Matrix;
+import edu.wpi.first.math.Nat;
+import edu.wpi.first.math.VecBuilder;
+import org.junit.jupiter.api.Test;
+
+public class MerweScaledSigmaPointsTest {
+ /** Identity covariance at the origin gives points offset by +/- 0.00173205 along each axis. */
+ @Test
+ public void testZeroMeanPoints() {
+ var generator = new MerweScaledSigmaPoints<>(Nat.N2());
+ var mean = VecBuilder.fill(0, 0);
+ var covariance = Matrix.mat(Nat.N2(), Nat.N2()).fill(1, 0, 0, 1);
+
+ // Column 0 is the mean; columns 1-2 add and columns 3-4 subtract the offset per axis.
+ var expected =
+ Matrix.mat(Nat.N2(), Nat.N5())
+ .fill(0.0, 0.00173205, 0.0, -0.00173205, 0.0, 0.0, 0.0, 0.00173205, 0.0, -0.00173205);
+
+ assertTrue(generator.sigmaPoints(mean, covariance).isEqual(expected, 1E-6));
+ }
+
+ /** A nonzero mean shifts the points; the larger y variance spreads them further in y. */
+ @Test
+ public void testNonzeroMeanPoints() {
+ var generator = new MerweScaledSigmaPoints<>(Nat.N2());
+ var mean = VecBuilder.fill(1, 2);
+ var covariance = Matrix.mat(Nat.N2(), Nat.N2()).fill(1, 0, 0, 10);
+
+ var expected =
+ Matrix.mat(Nat.N2(), Nat.N5())
+ .fill(1.0, 1.00173205, 1.0, 0.99826795, 1.0, 2.0, 2.0, 2.00547723, 2.0, 1.99452277);
+
+ assertTrue(generator.sigmaPoints(mean, covariance).isEqual(expected, 1E-6));
+ }
+}
diff --git a/wpimath/src/test/java/edu/wpi/first/math/estimator/SwerveDrivePoseEstimatorTest.java b/wpimath/src/test/java/edu/wpi/first/math/estimator/SwerveDrivePoseEstimatorTest.java
new file mode 100644
index 0000000..607e8de
--- /dev/null
+++ b/wpimath/src/test/java/edu/wpi/first/math/estimator/SwerveDrivePoseEstimatorTest.java
@@ -0,0 +1,127 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.estimator;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+import edu.wpi.first.math.VecBuilder;
+import edu.wpi.first.math.geometry.Pose2d;
+import edu.wpi.first.math.geometry.Rotation2d;
+import edu.wpi.first.math.geometry.Translation2d;
+import edu.wpi.first.math.kinematics.ChassisSpeeds;
+import edu.wpi.first.math.kinematics.SwerveDriveKinematics;
+import edu.wpi.first.math.trajectory.TrajectoryConfig;
+import edu.wpi.first.math.trajectory.TrajectoryGenerator;
+import java.util.List;
+import java.util.Random;
+import org.junit.jupiter.api.Test;
+
+public class SwerveDrivePoseEstimatorTest {
+ /**
+ * Follows a trajectory with noisy module states, a noisy heading, and delayed noisy vision
+ * poses, and checks that the estimator's translation error stays within mean and max bounds.
+ */
+ @Test
+ @SuppressWarnings("LocalVariableName")
+ public void testAccuracy() {
+ var kinematics =
+ new SwerveDriveKinematics(
+ new Translation2d(1, 1),
+ new Translation2d(1, -1),
+ new Translation2d(-1, -1),
+ new Translation2d(-1, 1));
+ var estimator =
+ new SwerveDrivePoseEstimator(
+ new Rotation2d(),
+ new Pose2d(),
+ kinematics,
+ VecBuilder.fill(0.1, 0.1, 0.1),
+ VecBuilder.fill(0.005),
+ VecBuilder.fill(0.1, 0.1, 0.1));
+
+ var trajectory =
+ TrajectoryGenerator.generateTrajectory(
+ List.of(
+ new Pose2d(0, 0, Rotation2d.fromDegrees(45)),
+ new Pose2d(3, 0, Rotation2d.fromDegrees(-90)),
+ new Pose2d(0, 0, Rotation2d.fromDegrees(135)),
+ new Pose2d(-3, 0, Rotation2d.fromDegrees(-90)),
+ new Pose2d(0, 0, Rotation2d.fromDegrees(45))),
+ new TrajectoryConfig(0.5, 2));
+
+ // Fixed seed keeps the simulated noise, and therefore the result, reproducible.
+ var rand = new Random(4915);
+
+ final double dt = 0.02;
+ double t = 0.0;
+
+ final double visionUpdateRate = 0.1;
+ Pose2d lastVisionPose = null;
+ double lastVisionUpdateTime = Double.NEGATIVE_INFINITY;
+
+ double maxError = Double.NEGATIVE_INFINITY;
+ double errorSum = 0;
+ // Number of loop iterations actually executed; used for the mean error below.
+ int sampleCount = 0;
+ while (t <= trajectory.getTotalTimeSeconds()) {
+ var groundTruthState = trajectory.sample(t);
+
+ // Once more than visionUpdateRate seconds have passed, submit the previously captured
+ // pose with its original timestamp (i.e. delayed), then capture a new noisy pose.
+ if (lastVisionUpdateTime + visionUpdateRate < t) {
+ if (lastVisionPose != null) {
+ estimator.addVisionMeasurement(lastVisionPose, lastVisionUpdateTime);
+ }
+
+ lastVisionPose =
+ new Pose2d(
+ new Translation2d(
+ groundTruthState.poseMeters.getTranslation().getX() + rand.nextGaussian() * 0.1,
+ groundTruthState.poseMeters.getTranslation().getY()
+ + rand.nextGaussian() * 0.1),
+ new Rotation2d(rand.nextGaussian() * 0.1)
+ .plus(groundTruthState.poseMeters.getRotation()));
+
+ lastVisionUpdateTime = t;
+ }
+
+ var moduleStates =
+ kinematics.toSwerveModuleStates(
+ new ChassisSpeeds(
+ groundTruthState.velocityMetersPerSecond,
+ 0.0,
+ groundTruthState.velocityMetersPerSecond
+ * groundTruthState.curvatureRadPerMeter));
+ for (var moduleState : moduleStates) {
+ moduleState.angle = moduleState.angle.plus(new Rotation2d(rand.nextGaussian() * 0.005));
+ moduleState.speedMetersPerSecond += rand.nextGaussian() * 0.1;
+ }
+
+ var xHat =
+ estimator.updateWithTime(
+ t,
+ groundTruthState
+ .poseMeters
+ .getRotation()
+ .plus(new Rotation2d(rand.nextGaussian() * 0.05)),
+ moduleStates);
+
+ double error =
+ groundTruthState.poseMeters.getTranslation().getDistance(xHat.getTranslation());
+ if (error > maxError) {
+ maxError = error;
+ }
+ errorSum += error;
+ sampleCount++;
+
+ t += dt;
+ }
+
+ // Divide by the real sample count rather than the duration / dt estimate, which is off by
+ // roughly one sample.
+ assertEquals(0.0, errorSum / sampleCount, 0.25, "Incorrect mean error");
+ assertEquals(0.0, maxError, 0.42, "Incorrect max error");
+ }
+}
diff --git a/wpimath/src/test/java/edu/wpi/first/math/estimator/UnscentedKalmanFilterTest.java b/wpimath/src/test/java/edu/wpi/first/math/estimator/UnscentedKalmanFilterTest.java
new file mode 100644
index 0000000..1264591
--- /dev/null
+++ b/wpimath/src/test/java/edu/wpi/first/math/estimator/UnscentedKalmanFilterTest.java
@@ -0,0 +1,368 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.estimator;
+
+import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import edu.wpi.first.math.Matrix;
+import edu.wpi.first.math.Nat;
+import edu.wpi.first.math.StateSpaceUtil;
+import edu.wpi.first.math.VecBuilder;
+import edu.wpi.first.math.geometry.Pose2d;
+import edu.wpi.first.math.geometry.Rotation2d;
+import edu.wpi.first.math.numbers.N1;
+import edu.wpi.first.math.numbers.N2;
+import edu.wpi.first.math.numbers.N4;
+import edu.wpi.first.math.numbers.N6;
+import edu.wpi.first.math.system.Discretization;
+import edu.wpi.first.math.system.NumericalIntegration;
+import edu.wpi.first.math.system.NumericalJacobian;
+import edu.wpi.first.math.system.plant.DCMotor;
+import edu.wpi.first.math.system.plant.LinearSystemId;
+import edu.wpi.first.math.trajectory.TrajectoryConfig;
+import edu.wpi.first.math.trajectory.TrajectoryGenerator;
+import java.util.Arrays;
+import java.util.List;
+import org.junit.jupiter.api.Test;
+
+/**
+ * Tests for {@link UnscentedKalmanFilter} built on a nonlinear differential-drive model.
+ *
+ * <p>The 6-element state is [x, y, cos(heading), sin(heading), vl, vr] and the 2-element input
+ * is [Vl, Vr] (left and right voltages).
+ */
+public class UnscentedKalmanFilterTest {
+  /**
+   * Continuous-time differential-drive dynamics: returns the time derivative of the state. The
+   * drivetrain motors are modeled with {@code DCMotor.getCIM(2)}.
+   *
+   * @param x State vector [x, y, cos(heading), sin(heading), vl, vr].
+   * @param u Input vector [Vl, Vr].
+   * @return The time derivative of x.
+   */
+  @SuppressWarnings({"LocalVariableName", "ParameterName"})
+  public static Matrix<N6, N1> getDynamics(Matrix<N6, N1> x, Matrix<N2, N1> u) {
+    var motors = DCMotor.getCIM(2);
+
+    var gHigh = 7.08;
+    var rb = 0.8382 / 2.0;
+    var r = 0.0746125;
+    var m = 63.503;
+    var J = 5.6;
+
+    var C1 =
+        -Math.pow(gHigh, 2)
+            * motors.KtNMPerAmp
+            / (motors.KvRadPerSecPerVolt * motors.rOhms * r * r);
+    var C2 = gHigh * motors.KtNMPerAmp / (motors.rOhms * r);
+
+    var c = x.get(2, 0);
+    var s = x.get(3, 0);
+    var vl = x.get(4, 0);
+    var vr = x.get(5, 0);
+
+    var Vl = u.get(0, 0);
+    var Vr = u.get(1, 0);
+
+    var k1 = 1.0 / m + rb * rb / J;
+    var k2 = 1.0 / m - rb * rb / J;
+
+    // Chassis linear velocity and rotation rate from the wheel velocities.
+    var xvel = (vl + vr) / 2;
+    var w = (vr - vl) / (2.0 * rb);
+
+    // Heading is stored as (cos, sin), whose time derivative is w * (-sin, cos); the last two
+    // rows are the coupled left/right wheel accelerations.
+    return VecBuilder.fill(
+        xvel * c,
+        xvel * s,
+        -s * w,
+        c * w,
+        k1 * ((C1 * vl) + (C2 * Vl)) + k2 * ((C1 * vr) + (C2 * Vr)),
+        k2 * ((C1 * vl) + (C2 * Vl)) + k1 * ((C1 * vr) + (C2 * Vr)));
+  }
+
+  /**
+   * Local measurement model: the heading (as cos, sin) and both wheel velocities.
+   *
+   * @param x State vector.
+   * @param u Input vector (unused).
+   * @return [cos(heading), sin(heading), vl, vr].
+   */
+  @SuppressWarnings("ParameterName")
+  public static Matrix<N4, N1> getLocalMeasurementModel(Matrix<N6, N1> x, Matrix<N2, N1> u) {
+    return VecBuilder.fill(x.get(2, 0), x.get(3, 0), x.get(4, 0), x.get(5, 0));
+  }
+
+  /**
+   * Global measurement model: the full state is measured directly.
+   *
+   * @param x State vector.
+   * @param u Input vector (unused).
+   * @return A copy of x.
+   */
+  @SuppressWarnings("ParameterName")
+  public static Matrix<N6, N1> getGlobalMeasurementModel(Matrix<N6, N1> x, Matrix<N2, N1> u) {
+    return x.copy();
+  }
+
+  @Test
+  @SuppressWarnings("LocalVariableName")
+  public void testInit() {
+    assertDoesNotThrow(
+        () -> {
+          UnscentedKalmanFilter<N6, N2, N4> observer =
+              new UnscentedKalmanFilter<>(
+                  Nat.N6(),
+                  Nat.N4(),
+                  UnscentedKalmanFilterTest::getDynamics,
+                  UnscentedKalmanFilterTest::getLocalMeasurementModel,
+                  VecBuilder.fill(0.5, 0.5, 0.7, 0.7, 1.0, 1.0),
+                  VecBuilder.fill(0.001, 0.001, 0.5, 0.5),
+                  0.00505);
+
+          var u = VecBuilder.fill(12.0, 12.0);
+          observer.predict(u, 0.00505);
+
+          var localY = getLocalMeasurementModel(observer.getXhat(), u);
+          observer.correct(u, localY);
+        });
+  }
+
+  @SuppressWarnings("LocalVariableName")
+  @Test
+  public void testConvergence() {
+    double dtSeconds = 0.00505;
+    double rbMeters = 0.8382 / 2.0; // Robot radius
+
+    UnscentedKalmanFilter<N6, N2, N4> observer =
+        new UnscentedKalmanFilter<>(
+            Nat.N6(),
+            Nat.N4(),
+            UnscentedKalmanFilterTest::getDynamics,
+            UnscentedKalmanFilterTest::getLocalMeasurementModel,
+            VecBuilder.fill(0.5, 0.5, 0.7, 0.7, 1.0, 1.0),
+            VecBuilder.fill(0.001, 0.001, 0.5, 0.5),
+            dtSeconds);
+
+    List<Pose2d> waypoints =
+        Arrays.asList(
+            new Pose2d(2.75, 22.521, new Rotation2d()),
+            new Pose2d(24.73, 19.68, Rotation2d.fromDegrees(5.846)));
+    var trajectory =
+        TrajectoryGenerator.generateTrajectory(waypoints, new TrajectoryConfig(8.8, 0.1));
+
+    Matrix<N6, N1> nextR;
+    Matrix<N2, N1> u = new Matrix<>(Nat.N2(), Nat.N1());
+
+    // getDynamics is affine in u, so the input Jacobian evaluated at the origin holds for every
+    // state.
+    var B =
+        NumericalJacobian.numericalJacobianU(
+            Nat.N6(),
+            Nat.N2(),
+            UnscentedKalmanFilterTest::getDynamics,
+            new Matrix<>(Nat.N6(), Nat.N1()),
+            u);
+
+    observer.setXhat(VecBuilder.fill(2.75, 22.521, 1.0, 0.0, 0.0, 0.0)); // TODO not hard code this
+
+    var ref = trajectory.sample(0.0);
+
+    Matrix<N6, N1> r =
+        VecBuilder.fill(
+            ref.poseMeters.getTranslation().getX(),
+            ref.poseMeters.getTranslation().getY(),
+            ref.poseMeters.getRotation().getCos(),
+            ref.poseMeters.getRotation().getSin(),
+            ref.velocityMetersPerSecond * (1 - (ref.curvatureRadPerMeter * rbMeters)),
+            ref.velocityMetersPerSecond * (1 + (ref.curvatureRadPerMeter * rbMeters)));
+    nextR = r.copy();
+
+    var trueXhat = observer.getXhat();
+    var noiseStdDev = VecBuilder.fill(0.001, 0.001, 0.5, 0.5);
+
+    double totalTime = trajectory.getTotalTimeSeconds();
+    for (int i = 0; i < (totalTime / dtSeconds); i++) {
+      ref = trajectory.sample(dtSeconds * i);
+      double vl = ref.velocityMetersPerSecond * (1 - (ref.curvatureRadPerMeter * rbMeters));
+      double vr = ref.velocityMetersPerSecond * (1 + (ref.curvatureRadPerMeter * rbMeters));
+
+      nextR.set(0, 0, ref.poseMeters.getTranslation().getX());
+      nextR.set(1, 0, ref.poseMeters.getTranslation().getY());
+      nextR.set(2, 0, ref.poseMeters.getRotation().getCos());
+      nextR.set(3, 0, ref.poseMeters.getRotation().getSin());
+      nextR.set(4, 0, vl);
+      nextR.set(5, 0, vr);
+
+      Matrix<N4, N1> localY = getLocalMeasurementModel(trueXhat, new Matrix<>(Nat.N2(), Nat.N1()));
+
+      observer.correct(u, localY.plus(StateSpaceUtil.makeWhiteNoiseVector(noiseStdDev)));
+
+      var rdot = nextR.minus(r).div(dtSeconds);
+      u = new Matrix<>(B.solve(rdot.minus(getDynamics(r, new Matrix<>(Nat.N2(), Nat.N1())))));
+
+      // Copy rather than alias: nextR is mutated in place at the top of the next iteration, and
+      // aliasing it would make rdot identically zero from then on.
+      r = nextR.copy();
+      observer.predict(u, dtSeconds);
+      trueXhat =
+          NumericalIntegration.rk4(UnscentedKalmanFilterTest::getDynamics, trueXhat, u, dtSeconds);
+    }
+
+    var localY = getLocalMeasurementModel(trueXhat, u);
+    observer.correct(u, localY);
+
+    var globalY = getGlobalMeasurementModel(trueXhat, u);
+    var R = StateSpaceUtil.makeCostMatrix(VecBuilder.fill(0.01, 0.01, 0.0001, 0.0001, 0.5, 0.5));
+    observer.correct(
+        Nat.N6(),
+        u,
+        globalY,
+        UnscentedKalmanFilterTest::getGlobalMeasurementModel,
+        R,
+        (sigmas, weights) -> sigmas.times(Matrix.changeBoundsUnchecked(weights)),
+        Matrix::minus,
+        Matrix::minus,
+        Matrix::plus);
+
+    final var finalPosition = trajectory.sample(trajectory.getTotalTimeSeconds());
+
+    // States 2 and 3 hold cos(heading) and sin(heading), not the heading angle itself.
+    assertEquals(finalPosition.poseMeters.getTranslation().getX(), observer.getXhat(0), 0.25);
+    assertEquals(finalPosition.poseMeters.getTranslation().getY(), observer.getXhat(1), 0.25);
+    assertEquals(finalPosition.poseMeters.getRotation().getCos(), observer.getXhat(2), 1.0);
+    assertEquals(finalPosition.poseMeters.getRotation().getSin(), observer.getXhat(3), 1.0);
+    assertEquals(0.0, observer.getXhat(4), 1.0);
+    assertEquals(0.0, observer.getXhat(5), 1.0);
+  }
+
+  @Test
+  @SuppressWarnings({"LocalVariableName", "ParameterName"})
+  public void testLinearUKF() {
+    var dt = 0.020;
+    var plant = LinearSystemId.identifyVelocitySystem(0.02, 0.006);
+    var observer =
+        new UnscentedKalmanFilter<>(
+            Nat.N1(),
+            Nat.N1(),
+            (x, u) -> plant.getA().times(x).plus(plant.getB().times(u)),
+            plant::calculateY,
+            VecBuilder.fill(0.05),
+            VecBuilder.fill(1.0),
+            dt);
+
+    var discABPair = Discretization.discretizeAB(plant.getA(), plant.getB(), dt);
+    var discA = discABPair.getFirst();
+    var discB = discABPair.getSecond();
+
+    Matrix<N1, N1> ref = VecBuilder.fill(100);
+    Matrix<N1, N1> u = VecBuilder.fill(0);
+
+    // Prediction only: u solves ref = discA * ref + discB * u, i.e. the input that holds ref
+    // under the discrete model, so the open-loop estimate should settle near ref.
+    for (int i = 0; i < (2.0 / dt); i++) {
+      observer.predict(u, dt);
+
+      u = discB.solve(ref.minus(discA.times(ref)));
+    }
+
+    assertEquals(ref.get(0, 0), observer.getXhat(0), 5);
+  }
+
+  @Test
+  public void testUnscentedTransform() {
+    // From FilterPy
+    var ret =
+        UnscentedKalmanFilter.unscentedTransform(
+            Nat.N4(),
+            Nat.N4(),
+            Matrix.mat(Nat.N4(), Nat.N9())
+                .fill(
+                    -0.9,
+                    -0.822540333075852,
+                    -0.8922540333075852,
+                    -0.9,
+                    -0.9,
+                    -0.9774596669241481,
+                    -0.9077459666924148,
+                    -0.9,
+                    -0.9,
+                    1.0,
+                    1.0,
+                    1.077459666924148,
+                    1.0,
+                    1.0,
+                    1.0,
+                    0.9225403330758519,
+                    1.0,
+                    1.0,
+                    -0.9,
+                    -0.9,
+                    -0.9,
+                    -0.822540333075852,
+                    -0.8922540333075852,
+                    -0.9,
+                    -0.9,
+                    -0.9774596669241481,
+                    -0.9077459666924148,
+                    1.0,
+                    1.0,
+                    1.0,
+                    1.0,
+                    1.077459666924148,
+                    1.0,
+                    1.0,
+                    1.0,
+                    0.9225403330758519),
+            VecBuilder.fill(
+                -132.33333333,
+                16.66666667,
+                16.66666667,
+                16.66666667,
+                16.66666667,
+                16.66666667,
+                16.66666667,
+                16.66666667,
+                16.66666667),
+            VecBuilder.fill(
+                -129.34333333,
+                16.66666667,
+                16.66666667,
+                16.66666667,
+                16.66666667,
+                16.66666667,
+                16.66666667,
+                16.66666667,
+                16.66666667),
+            (sigmas, weights) -> sigmas.times(Matrix.changeBoundsUnchecked(weights)),
+            Matrix::minus);
+
+    assertTrue(VecBuilder.fill(-0.9, 1, -0.9, 1).isEqual(ret.getFirst(), 1E-5));
+
+    assertTrue(
+        Matrix.mat(Nat.N4(), Nat.N4())
+            .fill(
+                2.02000002e-01,
+                2.00000500e-02,
+                -2.69044710e-29,
+                -4.59511477e-29,
+                2.00000500e-02,
+                2.00001000e-01,
+                -2.98781068e-29,
+                -5.12759588e-29,
+                -2.73372625e-29,
+                -3.09882635e-29,
+                2.02000002e-01,
+                2.00000500e-02,
+                -4.67065917e-29,
+                -5.10705197e-29,
+                2.00000500e-02,
+                2.00001000e-01)
+            .isEqual(ret.getSecond(), 1E-5));
+  }
+}
diff --git a/wpimath/src/test/java/edu/wpi/first/math/filter/LinearFilterTest.java b/wpimath/src/test/java/edu/wpi/first/math/filter/LinearFilterTest.java
new file mode 100644
index 0000000..805129f
--- /dev/null
+++ b/wpimath/src/test/java/edu/wpi/first/math/filter/LinearFilterTest.java
@@ -0,0 +1,229 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.filter;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.junit.jupiter.params.provider.Arguments.arguments;
+
+import java.util.Random;
+import java.util.function.DoubleFunction;
+import java.util.stream.Stream;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+
+/** Tests for {@link LinearFilter}. */
+class LinearFilterTest {
+  private static final double kFilterStep = 0.005;
+  private static final double kFilterTime = 2.0;
+  private static final double kSinglePoleIIRTimeConstant = 0.015915;
+  private static final double kHighPassTimeConstant = 0.006631;
+  private static final int kMovAvgTaps = 6;
+
+  private static final double kSinglePoleIIRExpectedOutput = -3.2172003;
+  private static final double kHighPassExpectedOutput = 10.074717;
+  private static final double kMovAvgExpectedOutput = -10.191644;
+
+  /** Signal 100 sin(2 pi t) + 20 cos(50 pi t): 1 Hz and 25 Hz terms when t is in seconds. */
+  private static double getData(double t) {
+    return 100.0 * Math.sin(2.0 * Math.PI * t) + 20.0 * Math.cos(50.0 * Math.PI * t);
+  }
+
+  /** Returns 1.0 when t is within 0.001 of 1.0 and 0.0 otherwise. */
+  private static double getPulseData(double t) {
+    if (Math.abs(t - 1.0) < 0.001) {
+      return 1.0;
+    } else {
+      return 0.0;
+    }
+  }
+
+  @Test
+  void illegalTapNumberTest() {
+    assertThrows(IllegalArgumentException.class, () -> LinearFilter.movingAverage(0));
+  }
+
+  /** Test if the filter reduces the noise produced by a signal generator. */
+  @ParameterizedTest
+  @MethodSource("noiseFilterProvider")
+  void noiseReduceTest(final LinearFilter filter) {
+    double noiseGenError = 0.0;
+    double filterError = 0.0;
+
+    // Fixed seed so that any failure is reproducible.
+    final Random gen = new Random(1234);
+    final double kStdDev = 10.0;
+
+    for (double t = 0; t < kFilterTime; t += kFilterStep) {
+      final double theory = getData(t);
+      final double noise = gen.nextGaussian() * kStdDev;
+      filterError += Math.abs(filter.calculate(theory + noise) - theory);
+      // NOTE(review): the error of the unfiltered sample (theory + noise) is |noise|, not
+      // |noise - theory|, so this baseline is very lenient. A strict |noise| baseline may not
+      // hold because these low-pass filters also attenuate getData's 25 Hz term; confirm intent.
+      noiseGenError += Math.abs(noise - theory);
+    }
+
+    assertTrue(
+        noiseGenError > filterError,
+        "Filter should have reduced noise accumulation from "
+            + noiseGenError
+            + " but failed. The filter error was "
+            + filterError);
+  }
+
+  static Stream<LinearFilter> noiseFilterProvider() {
+    return Stream.of(
+        LinearFilter.singlePoleIIR(kSinglePoleIIRTimeConstant, kFilterStep),
+        LinearFilter.movingAverage(kMovAvgTaps));
+  }
+
+  /** Test if the linear filters produce consistent output for a given data set. */
+  @ParameterizedTest
+  @MethodSource("outputFilterProvider")
+  void outputTest(
+      final LinearFilter filter, final DoubleFunction<Double> data, final double expectedOutput) {
+    double filterOutput = 0.0;
+    for (double t = 0.0; t < kFilterTime; t += kFilterStep) {
+      filterOutput = filter.calculate(data.apply(t));
+    }
+
+    assertEquals(expectedOutput, filterOutput, 5e-5, "Filter output was incorrect.");
+  }
+
+  static Stream<Arguments> outputFilterProvider() {
+    return Stream.of(
+        arguments(
+            LinearFilter.singlePoleIIR(kSinglePoleIIRTimeConstant, kFilterStep),
+            (DoubleFunction<Double>) LinearFilterTest::getData,
+            kSinglePoleIIRExpectedOutput),
+        arguments(
+            LinearFilter.highPass(kHighPassTimeConstant, kFilterStep),
+            (DoubleFunction<Double>) LinearFilterTest::getData,
+            kHighPassExpectedOutput),
+        arguments(
+            LinearFilter.movingAverage(kMovAvgTaps),
+            (DoubleFunction<Double>) LinearFilterTest::getData,
+            kMovAvgExpectedOutput),
+        arguments(
+            LinearFilter.movingAverage(kMovAvgTaps),
+            (DoubleFunction<Double>) LinearFilterTest::getPulseData,
+            0.0));
+  }
+
+  /** Test backward finite difference. */
+  @Test
+  void backwardFiniteDifferenceTest() {
+    double h = 0.005;
+
+    assertResults(
+        1,
+        2,
+        // f(x) = x²
+        (double x) -> x * x,
+        // df/dx = 2x
+        (double x) -> 2.0 * x,
+        h,
+        -20.0,
+        20.0);
+
+    assertResults(
+        1,
+        2,
+        // f(x) = sin(x)
+        (double x) -> Math.sin(x),
+        // df/dx = cos(x)
+        (double x) -> Math.cos(x),
+        h,
+        -20.0,
+        20.0);
+
+    assertResults(
+        1,
+        2,
+        // f(x) = ln(x)
+        (double x) -> Math.log(x),
+        // df/dx = 1 / x
+        (double x) -> 1.0 / x,
+        h,
+        1.0,
+        20.0);
+
+    assertResults(
+        2,
+        4,
+        // f(x) = x²
+        (double x) -> x * x,
+        // d²f/dx² = 2
+        (double x) -> 2.0,
+        h,
+        -20.0,
+        20.0);
+
+    assertResults(
+        2,
+        4,
+        // f(x) = sin(x)
+        (double x) -> Math.sin(x),
+        // d²f/dx² = -sin(x)
+        (double x) -> -Math.sin(x),
+        h,
+        -20.0,
+        20.0);
+
+    assertResults(
+        2,
+        4,
+        // f(x) = ln(x)
+        (double x) -> Math.log(x),
+        // d²f/dx² = -1 / x²
+        (double x) -> -1.0 / (x * x),
+        h,
+        1.0,
+        20.0);
+  }
+
+  /**
+   * Helper for checking results of backward finite difference.
+   *
+   * @param derivative The order of the derivative.
+   * @param samples The number of sample points.
+   * @param f Function of which to take derivative.
+   * @param dfdx Derivative of f.
+   * @param h Sample period in seconds.
+   * @param min Minimum of f's domain to test.
+   * @param max Maximum of f's domain to test.
+   */
+  void assertResults(
+      int derivative,
+      int samples,
+      DoubleFunction<Double> f,
+      DoubleFunction<Double> dfdx,
+      double h,
+      double min,
+      double max) {
+    var filter = LinearFilter.backwardFiniteDifference(derivative, samples, h);
+
+    final int first = (int) (min / h);
+    final int last = (int) (max / h);
+    for (int i = first; i < last; ++i) {
+      // Let filter initialize
+      if (i < first + samples) {
+        filter.calculate(f.apply(i * h));
+        continue;
+      }
+
+      // The order of accuracy is O(h^(N - d)) where N is number of stencil
+      // points and d is order of derivative
+      assertEquals(
+          dfdx.apply(i * h),
+          filter.calculate(f.apply(i * h)),
+          10.0 * Math.pow(h, samples - derivative));
+    }
+  }
+}
diff --git a/wpimath/src/test/java/edu/wpi/first/wpilibj/MedianFilterTest.java b/wpimath/src/test/java/edu/wpi/first/math/filter/MedianFilterTest.java
similarity index 69%
rename from wpimath/src/test/java/edu/wpi/first/wpilibj/MedianFilterTest.java
rename to wpimath/src/test/java/edu/wpi/first/math/filter/MedianFilterTest.java
index b2b596c..06b3d01 100644
--- a/wpimath/src/test/java/edu/wpi/first/wpilibj/MedianFilterTest.java
+++ b/wpimath/src/test/java/edu/wpi/first/math/filter/MedianFilterTest.java
@@ -1,16 +1,13 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
-package edu.wpi.first.wpilibj;
-
-import org.junit.jupiter.api.Test;
+package edu.wpi.first.math.filter;
import static org.junit.jupiter.api.Assertions.assertEquals;
+import org.junit.jupiter.api.Test;
+
public class MedianFilterTest {
@Test
void medianFilterNotFullTestEven() {
diff --git a/wpimath/src/test/java/edu/wpi/first/math/filter/SlewRateLimiterTest.java b/wpimath/src/test/java/edu/wpi/first/math/filter/SlewRateLimiterTest.java
new file mode 100644
index 0000000..9c1f2f3
--- /dev/null
+++ b/wpimath/src/test/java/edu/wpi/first/math/filter/SlewRateLimiterTest.java
@@ -0,0 +1,40 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.filter;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import edu.wpi.first.util.WPIUtilJNI;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.Test;
+
+/** Tests for {@link SlewRateLimiter} driven by the WPIUtil mock clock. */
+public class SlewRateLimiterTest {
+  /** Rewinds the mock clock after every test, even one whose assertion failed. */
+  @AfterEach
+  void resetMockTime() {
+    WPIUtilJNI.setMockTime(0L);
+  }
+
+  @Test
+  void slewRateLimitTest() {
+    WPIUtilJNI.enableMockTime();
+
+    var limiter = new SlewRateLimiter(1);
+    // 1000000L is presumably 1 s of mock time (microseconds).
+    WPIUtilJNI.setMockTime(1000000L);
+    assertTrue(limiter.calculate(2) < 2);
+  }
+
+  @Test
+  void slewRateNoLimitTest() {
+    WPIUtilJNI.enableMockTime();
+
+    var limiter = new SlewRateLimiter(1);
+    WPIUtilJNI.setMockTime(1000000L);
+    assertEquals(0.5, limiter.calculate(0.5));
+  }
+}
diff --git a/wpimath/src/test/java/edu/wpi/first/wpilibj/geometry/Pose2dTest.java b/wpimath/src/test/java/edu/wpi/first/math/geometry/Pose2dTest.java
similarity index 74%
rename from wpimath/src/test/java/edu/wpi/first/wpilibj/geometry/Pose2dTest.java
rename to wpimath/src/test/java/edu/wpi/first/math/geometry/Pose2dTest.java
index 14bad1f..b6e66af 100644
--- a/wpimath/src/test/java/edu/wpi/first/wpilibj/geometry/Pose2dTest.java
+++ b/wpimath/src/test/java/edu/wpi/first/math/geometry/Pose2dTest.java
@@ -1,34 +1,29 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
-package edu.wpi.first.wpilibj.geometry;
-
-import org.junit.jupiter.api.Test;
+package edu.wpi.first.math.geometry;
import static org.junit.jupiter.api.Assertions.assertAll;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
+import org.junit.jupiter.api.Test;
+
class Pose2dTest {
private static final double kEpsilon = 1E-9;
@Test
void testTransformBy() {
var initial = new Pose2d(new Translation2d(1.0, 2.0), Rotation2d.fromDegrees(45.0));
- var transformation = new Transform2d(new Translation2d(5.0, 0.0),
- Rotation2d.fromDegrees(5.0));
+ var transformation = new Transform2d(new Translation2d(5.0, 0.0), Rotation2d.fromDegrees(5.0));
var transformed = initial.plus(transformation);
assertAll(
() -> assertEquals(transformed.getX(), 1 + 5.0 / Math.sqrt(2.0), kEpsilon),
() -> assertEquals(transformed.getY(), 2 + 5.0 / Math.sqrt(2.0), kEpsilon),
- () -> assertEquals(transformed.getRotation().getDegrees(), 50.0, kEpsilon)
- );
+ () -> assertEquals(transformed.getRotation().getDegrees(), 50.0, kEpsilon));
}
@Test
@@ -39,11 +34,9 @@
var finalRelativeToInitial = last.relativeTo(initial);
assertAll(
- () -> assertEquals(finalRelativeToInitial.getX(), 5.0 * Math.sqrt(2.0),
- kEpsilon),
+ () -> assertEquals(finalRelativeToInitial.getX(), 5.0 * Math.sqrt(2.0), kEpsilon),
() -> assertEquals(finalRelativeToInitial.getY(), 0.0, kEpsilon),
- () -> assertEquals(finalRelativeToInitial.getRotation().getDegrees(), 0.0, kEpsilon)
- );
+ () -> assertEquals(finalRelativeToInitial.getRotation().getDegrees(), 0.0, kEpsilon));
}
@Test
@@ -69,7 +62,6 @@
assertAll(
() -> assertEquals(transform.getX(), 5.0 * Math.sqrt(2.0), kEpsilon),
() -> assertEquals(transform.getY(), 0.0, kEpsilon),
- () -> assertEquals(transform.getRotation().getDegrees(), 0.0, kEpsilon)
- );
+ () -> assertEquals(transform.getRotation().getDegrees(), 0.0, kEpsilon));
}
}
diff --git a/wpimath/src/test/java/edu/wpi/first/math/geometry/Rotation2dTest.java b/wpimath/src/test/java/edu/wpi/first/math/geometry/Rotation2dTest.java
new file mode 100644
index 0000000..cb3f0f3
--- /dev/null
+++ b/wpimath/src/test/java/edu/wpi/first/math/geometry/Rotation2dTest.java
@@ -0,0 +1,81 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.geometry;
+
+import static org.junit.jupiter.api.Assertions.assertAll;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotEquals;
+
+import org.junit.jupiter.api.Test;
+
+/** Tests for {@link Rotation2d}; assertions use JUnit's (expected, actual) argument order. */
+class Rotation2dTest {
+  private static final double kEpsilon = 1E-9;
+
+  @Test
+  void testRadiansToDegrees() {
+    var rot1 = new Rotation2d(Math.PI / 3);
+    var rot2 = new Rotation2d(Math.PI / 4);
+
+    assertAll(
+        () -> assertEquals(60.0, rot1.getDegrees(), kEpsilon),
+        () -> assertEquals(45.0, rot2.getDegrees(), kEpsilon));
+  }
+
+  @Test
+  void testRadiansAndDegrees() {
+    var rot1 = Rotation2d.fromDegrees(45.0);
+    var rot2 = Rotation2d.fromDegrees(30.0);
+
+    assertAll(
+        () -> assertEquals(Math.PI / 4, rot1.getRadians(), kEpsilon),
+        () -> assertEquals(Math.PI / 6, rot2.getRadians(), kEpsilon));
+  }
+
+  @Test
+  void testRotateByFromZero() {
+    var zero = new Rotation2d();
+    var rotated = zero.rotateBy(Rotation2d.fromDegrees(90.0));
+
+    assertAll(
+        () -> assertEquals(Math.PI / 2.0, rotated.getRadians(), kEpsilon),
+        () -> assertEquals(90.0, rotated.getDegrees(), kEpsilon));
+  }
+
+  @Test
+  void testRotateByNonZero() {
+    var rot = Rotation2d.fromDegrees(90.0);
+    rot = rot.plus(Rotation2d.fromDegrees(30.0));
+
+    assertEquals(120.0, rot.getDegrees(), kEpsilon);
+  }
+
+  @Test
+  void testMinus() {
+    var rot1 = Rotation2d.fromDegrees(70.0);
+    var rot2 = Rotation2d.fromDegrees(30.0);
+
+    assertEquals(40.0, rot1.minus(rot2).getDegrees(), kEpsilon);
+  }
+
+  @Test
+  void testEquality() {
+    var rot1 = Rotation2d.fromDegrees(43.0);
+    var rot2 = Rotation2d.fromDegrees(43.0);
+    assertEquals(rot1, rot2);
+
+    // -180 and 180 degrees describe the same rotation, so they must compare equal.
+    var rot3 = Rotation2d.fromDegrees(-180.0);
+    var rot4 = Rotation2d.fromDegrees(180.0);
+    assertEquals(rot3, rot4);
+  }
+
+  @Test
+  void testInequality() {
+    var rot1 = Rotation2d.fromDegrees(43.0);
+    var rot2 = Rotation2d.fromDegrees(43.5);
+    assertNotEquals(rot1, rot2);
+  }
+}
diff --git a/wpimath/src/test/java/edu/wpi/first/math/geometry/Transform2dTest.java b/wpimath/src/test/java/edu/wpi/first/math/geometry/Transform2dTest.java
new file mode 100644
index 0000000..7265e25
--- /dev/null
+++ b/wpimath/src/test/java/edu/wpi/first/math/geometry/Transform2dTest.java
@@ -0,0 +1,51 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.geometry;
+
+import static org.junit.jupiter.api.Assertions.assertAll;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+import org.junit.jupiter.api.Test;
+
+/** Tests for applying {@link Transform2d} to a {@link Pose2d}. */
+class Transform2dTest {
+  private static final double kEpsilon = 1E-9;
+
+  /** Asserts that two poses agree in x, y, and heading (degrees) to within kEpsilon. */
+  private static void assertPoseEquals(Pose2d expected, Pose2d actual) {
+    assertAll(
+        () -> assertEquals(expected.getX(), actual.getX(), kEpsilon),
+        () -> assertEquals(expected.getY(), actual.getY(), kEpsilon),
+        () ->
+            assertEquals(
+                expected.getRotation().getDegrees(),
+                actual.getRotation().getDegrees(),
+                kEpsilon));
+  }
+
+  @Test
+  void testInverse() {
+    final var start = new Pose2d(new Translation2d(1.0, 2.0), Rotation2d.fromDegrees(45.0));
+    final var offset = new Transform2d(new Translation2d(5.0, 0.0), Rotation2d.fromDegrees(5.0));
+
+    // Applying a transform and then its inverse must return to the starting pose.
+    final var roundTrip = start.plus(offset).plus(offset.inverse());
+
+    assertPoseEquals(start, roundTrip);
+  }
+
+  @Test
+  void testComposition() {
+    final var start = new Pose2d(new Translation2d(1.0, 2.0), Rotation2d.fromDegrees(45.0));
+    final var first = new Transform2d(new Translation2d(5.0, 0.0), Rotation2d.fromDegrees(5.0));
+    final var second = new Transform2d(new Translation2d(0.0, 2.0), Rotation2d.fromDegrees(5.0));
+
+    // Applying two transforms one after another must match applying their composition.
+    final var stepwise = start.plus(first).plus(second);
+    final var composed = start.plus(first.plus(second));
+
+    assertPoseEquals(stepwise, composed);
+  }
+}
diff --git a/wpimath/src/test/java/edu/wpi/first/wpilibj/geometry/Translation2dTest.java b/wpimath/src/test/java/edu/wpi/first/math/geometry/Translation2dTest.java
similarity index 72%
rename from wpimath/src/test/java/edu/wpi/first/wpilibj/geometry/Translation2dTest.java
rename to wpimath/src/test/java/edu/wpi/first/math/geometry/Translation2dTest.java
index d6844a5..2d8eeaa 100644
--- a/wpimath/src/test/java/edu/wpi/first/wpilibj/geometry/Translation2dTest.java
+++ b/wpimath/src/test/java/edu/wpi/first/math/geometry/Translation2dTest.java
@@ -1,18 +1,15 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
-package edu.wpi.first.wpilibj.geometry;
-
-import org.junit.jupiter.api.Test;
+package edu.wpi.first.math.geometry;
import static org.junit.jupiter.api.Assertions.assertAll;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
+import org.junit.jupiter.api.Test;
+
class Translation2dTest {
private static final double kEpsilon = 1E-9;
@@ -25,8 +22,7 @@
assertAll(
() -> assertEquals(sum.getX(), 3.0, kEpsilon),
- () -> assertEquals(sum.getY(), 8.0, kEpsilon)
- );
+ () -> assertEquals(sum.getY(), 8.0, kEpsilon));
}
@Test
@@ -38,8 +34,7 @@
assertAll(
() -> assertEquals(difference.getX(), -1.0, kEpsilon),
- () -> assertEquals(difference.getY(), -2.0, kEpsilon)
- );
+ () -> assertEquals(difference.getY(), -2.0, kEpsilon));
}
@Test
@@ -49,8 +44,7 @@
assertAll(
() -> assertEquals(rotated.getX(), 0.0, kEpsilon),
- () -> assertEquals(rotated.getY(), 3.0, kEpsilon)
- );
+ () -> assertEquals(rotated.getY(), 3.0, kEpsilon));
}
@Test
@@ -60,8 +54,7 @@
assertAll(
() -> assertEquals(mult.getX(), 9.0, kEpsilon),
- () -> assertEquals(mult.getY(), 15.0, kEpsilon)
- );
+ () -> assertEquals(mult.getY(), 15.0, kEpsilon));
}
@Test
@@ -71,8 +64,7 @@
assertAll(
() -> assertEquals(div.getX(), 1.5, kEpsilon),
- () -> assertEquals(div.getY(), 2.5, kEpsilon)
- );
+ () -> assertEquals(div.getY(), 2.5, kEpsilon));
}
@Test
@@ -95,8 +87,7 @@
assertAll(
() -> assertEquals(inverted.getX(), 4.5, kEpsilon),
- () -> assertEquals(inverted.getY(), -7, kEpsilon)
- );
+ () -> assertEquals(inverted.getY(), -7, kEpsilon));
}
@Test
@@ -121,7 +112,6 @@
() -> assertEquals(one.getX(), 1.0, kEpsilon),
() -> assertEquals(one.getY(), 1.0, kEpsilon),
() -> assertEquals(two.getX(), 1.0, kEpsilon),
- () -> assertEquals(two.getY(), Math.sqrt(3), kEpsilon)
- );
+ () -> assertEquals(two.getY(), Math.sqrt(3), kEpsilon));
}
}
diff --git a/wpimath/src/test/java/edu/wpi/first/wpilibj/geometry/Twist2dTest.java b/wpimath/src/test/java/edu/wpi/first/math/geometry/Twist2dTest.java
similarity index 76%
rename from wpimath/src/test/java/edu/wpi/first/wpilibj/geometry/Twist2dTest.java
rename to wpimath/src/test/java/edu/wpi/first/math/geometry/Twist2dTest.java
index 18ea6d9..c13bb09 100644
--- a/wpimath/src/test/java/edu/wpi/first/wpilibj/geometry/Twist2dTest.java
+++ b/wpimath/src/test/java/edu/wpi/first/math/geometry/Twist2dTest.java
@@ -1,18 +1,15 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
-package edu.wpi.first.wpilibj.geometry;
-
-import org.junit.jupiter.api.Test;
+package edu.wpi.first.math.geometry;
import static org.junit.jupiter.api.Assertions.assertAll;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
+import org.junit.jupiter.api.Test;
+
class Twist2dTest {
private static final double kEpsilon = 1E-9;
@@ -24,8 +21,7 @@
assertAll(
() -> assertEquals(straightPose.getX(), 5.0, kEpsilon),
() -> assertEquals(straightPose.getY(), 0.0, kEpsilon),
- () -> assertEquals(straightPose.getRotation().getRadians(), 0.0, kEpsilon)
- );
+ () -> assertEquals(straightPose.getRotation().getRadians(), 0.0, kEpsilon));
}
@Test
@@ -36,8 +32,7 @@
assertAll(
() -> assertEquals(quarterCirclePose.getX(), 5.0, kEpsilon),
() -> assertEquals(quarterCirclePose.getY(), 5.0, kEpsilon),
- () -> assertEquals(quarterCirclePose.getRotation().getDegrees(), 90.0, kEpsilon)
- );
+ () -> assertEquals(quarterCirclePose.getRotation().getDegrees(), 90.0, kEpsilon));
}
@Test
@@ -48,8 +43,7 @@
assertAll(
() -> assertEquals(diagonalPose.getX(), 2.0, kEpsilon),
() -> assertEquals(diagonalPose.getY(), 2.0, kEpsilon),
- () -> assertEquals(diagonalPose.getRotation().getDegrees(), 0.0, kEpsilon)
- );
+ () -> assertEquals(diagonalPose.getRotation().getDegrees(), 0.0, kEpsilon));
}
@Test
@@ -75,7 +69,6 @@
assertAll(
() -> assertEquals(twist.dx, 5.0 / 2.0 * Math.PI, kEpsilon),
() -> assertEquals(twist.dy, 0.0, kEpsilon),
- () -> assertEquals(twist.dtheta, Math.PI / 2.0, kEpsilon)
- );
+ () -> assertEquals(twist.dtheta, Math.PI / 2.0, kEpsilon));
}
}
diff --git a/wpimath/src/test/java/edu/wpi/first/math/kinematics/ChassisSpeedsTest.java b/wpimath/src/test/java/edu/wpi/first/math/kinematics/ChassisSpeedsTest.java
new file mode 100644
index 0000000..b9c3785
--- /dev/null
+++ b/wpimath/src/test/java/edu/wpi/first/math/kinematics/ChassisSpeedsTest.java
@@ -0,0 +1,26 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.kinematics;
+
+import static org.junit.jupiter.api.Assertions.assertAll;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+import edu.wpi.first.math.geometry.Rotation2d;
+import org.junit.jupiter.api.Test;
+
+class ChassisSpeedsTest {
+ private static final double kEpsilon = 1E-9;
+
+ @Test
+ void testFieldRelativeConstruction() {
+ final var chassisSpeeds =
+ ChassisSpeeds.fromFieldRelativeSpeeds(1.0, 0.0, 0.5, Rotation2d.fromDegrees(-90.0));
+
+ assertAll(
+ () -> assertEquals(0.0, chassisSpeeds.vxMetersPerSecond, kEpsilon),
+ () -> assertEquals(1.0, chassisSpeeds.vyMetersPerSecond, kEpsilon),
+ () -> assertEquals(0.5, chassisSpeeds.omegaRadiansPerSecond, kEpsilon));
+ }
+}
diff --git a/wpimath/src/test/java/edu/wpi/first/wpilibj/kinematics/DifferentialDriveKinematicsTest.java b/wpimath/src/test/java/edu/wpi/first/math/kinematics/DifferentialDriveKinematicsTest.java
similarity index 76%
rename from wpimath/src/test/java/edu/wpi/first/wpilibj/kinematics/DifferentialDriveKinematicsTest.java
rename to wpimath/src/test/java/edu/wpi/first/math/kinematics/DifferentialDriveKinematicsTest.java
index 9d2ad4e..adee41f 100644
--- a/wpimath/src/test/java/edu/wpi/first/wpilibj/kinematics/DifferentialDriveKinematicsTest.java
+++ b/wpimath/src/test/java/edu/wpi/first/math/kinematics/DifferentialDriveKinematicsTest.java
@@ -1,21 +1,18 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
-package edu.wpi.first.wpilibj.kinematics;
-
-import org.junit.jupiter.api.Test;
+package edu.wpi.first.math.kinematics;
import static org.junit.jupiter.api.Assertions.assertAll;
import static org.junit.jupiter.api.Assertions.assertEquals;
+import org.junit.jupiter.api.Test;
+
class DifferentialDriveKinematicsTest {
private static final double kEpsilon = 1E-9;
- private final DifferentialDriveKinematics m_kinematics
- = new DifferentialDriveKinematics(0.381 * 2);
+ private final DifferentialDriveKinematics m_kinematics =
+ new DifferentialDriveKinematics(0.381 * 2);
@Test
void testInverseKinematicsForZeros() {
@@ -24,8 +21,7 @@
assertAll(
() -> assertEquals(0.0, wheelSpeeds.leftMetersPerSecond, kEpsilon),
- () -> assertEquals(0.0, wheelSpeeds.rightMetersPerSecond, kEpsilon)
- );
+ () -> assertEquals(0.0, wheelSpeeds.rightMetersPerSecond, kEpsilon));
}
@Test
@@ -36,8 +32,7 @@
assertAll(
() -> assertEquals(0.0, chassisSpeeds.vxMetersPerSecond, kEpsilon),
() -> assertEquals(0.0, chassisSpeeds.vyMetersPerSecond, kEpsilon),
- () -> assertEquals(0.0, chassisSpeeds.omegaRadiansPerSecond, kEpsilon)
- );
+ () -> assertEquals(0.0, chassisSpeeds.omegaRadiansPerSecond, kEpsilon));
}
@Test
@@ -47,8 +42,7 @@
assertAll(
() -> assertEquals(3.0, wheelSpeeds.leftMetersPerSecond, kEpsilon),
- () -> assertEquals(3.0, wheelSpeeds.rightMetersPerSecond, kEpsilon)
- );
+ () -> assertEquals(3.0, wheelSpeeds.rightMetersPerSecond, kEpsilon));
}
@Test
@@ -59,8 +53,7 @@
assertAll(
() -> assertEquals(3.0, chassisSpeeds.vxMetersPerSecond, kEpsilon),
() -> assertEquals(0.0, chassisSpeeds.vyMetersPerSecond, kEpsilon),
- () -> assertEquals(0.0, chassisSpeeds.omegaRadiansPerSecond, kEpsilon)
- );
+ () -> assertEquals(0.0, chassisSpeeds.omegaRadiansPerSecond, kEpsilon));
}
@Test
@@ -70,8 +63,7 @@
assertAll(
() -> assertEquals(-0.381 * Math.PI, wheelSpeeds.leftMetersPerSecond, kEpsilon),
- () -> assertEquals(+0.381 * Math.PI, wheelSpeeds.rightMetersPerSecond, kEpsilon)
- );
+ () -> assertEquals(+0.381 * Math.PI, wheelSpeeds.rightMetersPerSecond, kEpsilon));
}
@Test
@@ -82,7 +74,6 @@
assertAll(
() -> assertEquals(0.0, chassisSpeeds.vxMetersPerSecond, kEpsilon),
() -> assertEquals(0.0, chassisSpeeds.vyMetersPerSecond, kEpsilon),
- () -> assertEquals(-Math.PI, chassisSpeeds.omegaRadiansPerSecond, kEpsilon)
- );
+ () -> assertEquals(-Math.PI, chassisSpeeds.omegaRadiansPerSecond, kEpsilon));
}
}
diff --git a/wpimath/src/test/java/edu/wpi/first/math/kinematics/DifferentialDriveOdometryTest.java b/wpimath/src/test/java/edu/wpi/first/math/kinematics/DifferentialDriveOdometryTest.java
new file mode 100644
index 0000000..f85e8fb
--- /dev/null
+++ b/wpimath/src/test/java/edu/wpi/first/math/kinematics/DifferentialDriveOdometryTest.java
@@ -0,0 +1,29 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.kinematics;
+
+import static org.junit.jupiter.api.Assertions.assertAll;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+import edu.wpi.first.math.geometry.Pose2d;
+import edu.wpi.first.math.geometry.Rotation2d;
+import org.junit.jupiter.api.Test;
+
+class DifferentialDriveOdometryTest {
+ private static final double kEpsilon = 1E-9;
+ private final DifferentialDriveOdometry m_odometry =
+ new DifferentialDriveOdometry(new Rotation2d());
+
+ @Test
+ void testOdometryWithEncoderDistances() {
+ m_odometry.resetPosition(new Pose2d(), Rotation2d.fromDegrees(45));
+ var pose = m_odometry.update(Rotation2d.fromDegrees(135.0), 0.0, 5 * Math.PI);
+
+ assertAll(
+ () -> assertEquals(pose.getX(), 5.0, kEpsilon),
+ () -> assertEquals(pose.getY(), 5.0, kEpsilon),
+ () -> assertEquals(pose.getRotation().getDegrees(), 90.0, kEpsilon));
+ }
+}
diff --git a/wpimath/src/test/java/edu/wpi/first/math/kinematics/MecanumDriveKinematicsTest.java b/wpimath/src/test/java/edu/wpi/first/math/kinematics/MecanumDriveKinematicsTest.java
new file mode 100644
index 0000000..d334679
--- /dev/null
+++ b/wpimath/src/test/java/edu/wpi/first/math/kinematics/MecanumDriveKinematicsTest.java
@@ -0,0 +1,175 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.kinematics;
+
+import static org.junit.jupiter.api.Assertions.assertAll;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+import edu.wpi.first.math.geometry.Translation2d;
+import org.junit.jupiter.api.Test;
+
+class MecanumDriveKinematicsTest {
+ private static final double kEpsilon = 1E-9;
+
+ private final Translation2d m_fl = new Translation2d(12, 12);
+ private final Translation2d m_fr = new Translation2d(12, -12);
+ private final Translation2d m_bl = new Translation2d(-12, 12);
+ private final Translation2d m_br = new Translation2d(-12, -12);
+
+ private final MecanumDriveKinematics m_kinematics =
+ new MecanumDriveKinematics(m_fl, m_fr, m_bl, m_br);
+
+ @Test
+ void testStraightLineInverseKinematics() {
+ ChassisSpeeds speeds = new ChassisSpeeds(5, 0, 0);
+ var moduleStates = m_kinematics.toWheelSpeeds(speeds);
+
+ assertAll(
+ () -> assertEquals(5.0, moduleStates.frontLeftMetersPerSecond, 0.1),
+ () -> assertEquals(5.0, moduleStates.frontRightMetersPerSecond, 0.1),
+ () -> assertEquals(5.0, moduleStates.rearLeftMetersPerSecond, 0.1),
+ () -> assertEquals(5.0, moduleStates.rearRightMetersPerSecond, 0.1));
+ }
+
+ @Test
+ void testStraightLineForwardKinematicsKinematics() {
+ var wheelSpeeds = new MecanumDriveWheelSpeeds(3.536, 3.536, 3.536, 3.536);
+ var moduleStates = m_kinematics.toChassisSpeeds(wheelSpeeds);
+
+ assertAll(
+ () -> assertEquals(3.536, moduleStates.vxMetersPerSecond, 0.1),
+ () -> assertEquals(0, moduleStates.vyMetersPerSecond, 0.1),
+ () -> assertEquals(0, moduleStates.omegaRadiansPerSecond, 0.1));
+ }
+
+ @Test
+ void testStrafeInverseKinematics() {
+ ChassisSpeeds speeds = new ChassisSpeeds(0, 4, 0);
+ var moduleStates = m_kinematics.toWheelSpeeds(speeds);
+
+ assertAll(
+ () -> assertEquals(-4.0, moduleStates.frontLeftMetersPerSecond, 0.1),
+ () -> assertEquals(4.0, moduleStates.frontRightMetersPerSecond, 0.1),
+ () -> assertEquals(4.0, moduleStates.rearLeftMetersPerSecond, 0.1),
+ () -> assertEquals(-4.0, moduleStates.rearRightMetersPerSecond, 0.1));
+ }
+
+ @Test
+ void testStrafeForwardKinematicsKinematics() {
+ var wheelSpeeds = new MecanumDriveWheelSpeeds(-2.828427, 2.828427, 2.828427, -2.828427);
+ var moduleStates = m_kinematics.toChassisSpeeds(wheelSpeeds);
+
+ assertAll(
+ () -> assertEquals(0, moduleStates.vxMetersPerSecond, 0.1),
+ () -> assertEquals(2.8284, moduleStates.vyMetersPerSecond, 0.1),
+ () -> assertEquals(0, moduleStates.omegaRadiansPerSecond, 0.1));
+ }
+
+ @Test
+ void testRotationInverseKinematics() {
+ ChassisSpeeds speeds = new ChassisSpeeds(0, 0, 2 * Math.PI);
+ var moduleStates = m_kinematics.toWheelSpeeds(speeds);
+
+ assertAll(
+ () -> assertEquals(-150.79645, moduleStates.frontLeftMetersPerSecond, 0.1),
+ () -> assertEquals(150.79645, moduleStates.frontRightMetersPerSecond, 0.1),
+ () -> assertEquals(-150.79645, moduleStates.rearLeftMetersPerSecond, 0.1),
+ () -> assertEquals(150.79645, moduleStates.rearRightMetersPerSecond, 0.1));
+ }
+
+ @Test
+ void testRotationForwardKinematicsKinematics() {
+ var wheelSpeeds = new MecanumDriveWheelSpeeds(-150.79645, 150.79645, -150.79645, 150.79645);
+ var moduleStates = m_kinematics.toChassisSpeeds(wheelSpeeds);
+
+ assertAll(
+ () -> assertEquals(0, moduleStates.vxMetersPerSecond, 0.1),
+ () -> assertEquals(0, moduleStates.vyMetersPerSecond, 0.1),
+ () -> assertEquals(2 * Math.PI, moduleStates.omegaRadiansPerSecond, 0.1));
+ }
+
+ @Test
+ void testMixedTranslationRotationInverseKinematics() {
+ ChassisSpeeds speeds = new ChassisSpeeds(2, 3, 1);
+ var moduleStates = m_kinematics.toWheelSpeeds(speeds);
+
+ assertAll(
+ () -> assertEquals(-25.0, moduleStates.frontLeftMetersPerSecond, 0.1),
+ () -> assertEquals(29.0, moduleStates.frontRightMetersPerSecond, 0.1),
+ () -> assertEquals(-19.0, moduleStates.rearLeftMetersPerSecond, 0.1),
+ () -> assertEquals(23.0, moduleStates.rearRightMetersPerSecond, 0.1));
+ }
+
+ @Test
+ void testMixedTranslationRotationForwardKinematicsKinematics() {
+ var wheelSpeeds = new MecanumDriveWheelSpeeds(-17.677670, 20.51, -13.44, 16.26);
+ var moduleStates = m_kinematics.toChassisSpeeds(wheelSpeeds);
+
+ assertAll(
+ () -> assertEquals(1.413, moduleStates.vxMetersPerSecond, 0.1),
+ () -> assertEquals(2.122, moduleStates.vyMetersPerSecond, 0.1),
+ () -> assertEquals(0.707, moduleStates.omegaRadiansPerSecond, 0.1));
+ }
+
+ @Test
+ void testOffCenterRotationInverseKinematics() {
+ ChassisSpeeds speeds = new ChassisSpeeds(0, 0, 1);
+ var moduleStates = m_kinematics.toWheelSpeeds(speeds, m_fl);
+
+ assertAll(
+ () -> assertEquals(0, moduleStates.frontLeftMetersPerSecond, 0.1),
+ () -> assertEquals(24.0, moduleStates.frontRightMetersPerSecond, 0.1),
+ () -> assertEquals(-24.0, moduleStates.rearLeftMetersPerSecond, 0.1),
+ () -> assertEquals(48.0, moduleStates.rearRightMetersPerSecond, 0.1));
+ }
+
+ @Test
+ void testOffCenterRotationForwardKinematicsKinematics() {
+ var wheelSpeeds = new MecanumDriveWheelSpeeds(0, 16.971, -16.971, 33.941);
+ var moduleStates = m_kinematics.toChassisSpeeds(wheelSpeeds);
+
+ assertAll(
+ () -> assertEquals(8.48525, moduleStates.vxMetersPerSecond, 0.1),
+ () -> assertEquals(-8.48525, moduleStates.vyMetersPerSecond, 0.1),
+ () -> assertEquals(0.707, moduleStates.omegaRadiansPerSecond, 0.1));
+ }
+
+ @Test
+ void testOffCenterTranslationRotationInverseKinematics() {
+ ChassisSpeeds speeds = new ChassisSpeeds(5, 2, 1);
+ var moduleStates = m_kinematics.toWheelSpeeds(speeds, m_fl);
+
+ assertAll(
+ () -> assertEquals(3.0, moduleStates.frontLeftMetersPerSecond, 0.1),
+ () -> assertEquals(31.0, moduleStates.frontRightMetersPerSecond, 0.1),
+ () -> assertEquals(-17.0, moduleStates.rearLeftMetersPerSecond, 0.1),
+ () -> assertEquals(51.0, moduleStates.rearRightMetersPerSecond, 0.1));
+ }
+
+ @Test
+ void testOffCenterRotationTranslationForwardKinematicsKinematics() {
+ var wheelSpeeds = new MecanumDriveWheelSpeeds(2.12, 21.92, -12.02, 36.06);
+ var moduleStates = m_kinematics.toChassisSpeeds(wheelSpeeds);
+
+ assertAll(
+ () -> assertEquals(12.02, moduleStates.vxMetersPerSecond, 0.1),
+ () -> assertEquals(-7.07, moduleStates.vyMetersPerSecond, 0.1),
+ () -> assertEquals(0.707, moduleStates.omegaRadiansPerSecond, 0.1));
+ }
+
+ @Test
+ void testNormalize() {
+ var wheelSpeeds = new MecanumDriveWheelSpeeds(5, 6, 4, 7);
+ wheelSpeeds.normalize(5.5);
+
+ double factor = 5.5 / 7.0;
+
+ assertAll(
+ () -> assertEquals(5.0 * factor, wheelSpeeds.frontLeftMetersPerSecond, kEpsilon),
+ () -> assertEquals(6.0 * factor, wheelSpeeds.frontRightMetersPerSecond, kEpsilon),
+ () -> assertEquals(4.0 * factor, wheelSpeeds.rearLeftMetersPerSecond, kEpsilon),
+ () -> assertEquals(7.0 * factor, wheelSpeeds.rearRightMetersPerSecond, kEpsilon));
+ }
+}
diff --git a/wpimath/src/test/java/edu/wpi/first/wpilibj/kinematics/MecanumDriveOdometryTest.java b/wpimath/src/test/java/edu/wpi/first/math/kinematics/MecanumDriveOdometryTest.java
similarity index 64%
rename from wpimath/src/test/java/edu/wpi/first/wpilibj/kinematics/MecanumDriveOdometryTest.java
rename to wpimath/src/test/java/edu/wpi/first/math/kinematics/MecanumDriveOdometryTest.java
index 5ece28b..3f43109 100644
--- a/wpimath/src/test/java/edu/wpi/first/wpilibj/kinematics/MecanumDriveOdometryTest.java
+++ b/wpimath/src/test/java/edu/wpi/first/math/kinematics/MecanumDriveOdometryTest.java
@@ -1,33 +1,28 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
-package edu.wpi.first.wpilibj.kinematics;
-
-import org.junit.jupiter.api.Test;
-
-import edu.wpi.first.wpilibj.geometry.Pose2d;
-import edu.wpi.first.wpilibj.geometry.Rotation2d;
-import edu.wpi.first.wpilibj.geometry.Translation2d;
+package edu.wpi.first.math.kinematics;
import static org.junit.jupiter.api.Assertions.assertAll;
import static org.junit.jupiter.api.Assertions.assertEquals;
+import edu.wpi.first.math.geometry.Pose2d;
+import edu.wpi.first.math.geometry.Rotation2d;
+import edu.wpi.first.math.geometry.Translation2d;
+import org.junit.jupiter.api.Test;
+
class MecanumDriveOdometryTest {
private final Translation2d m_fl = new Translation2d(12, 12);
private final Translation2d m_fr = new Translation2d(12, -12);
private final Translation2d m_bl = new Translation2d(-12, 12);
private final Translation2d m_br = new Translation2d(-12, -12);
-
private final MecanumDriveKinematics m_kinematics =
new MecanumDriveKinematics(m_fl, m_fr, m_bl, m_br);
- private final MecanumDriveOdometry m_odometry = new MecanumDriveOdometry(m_kinematics,
- new Rotation2d());
+ private final MecanumDriveOdometry m_odometry =
+ new MecanumDriveOdometry(m_kinematics, new Rotation2d());
@Test
void testMultipleConsecutiveUpdates() {
@@ -39,8 +34,7 @@
assertAll(
() -> assertEquals(secondPose.getX(), 0.0, 0.01),
() -> assertEquals(secondPose.getY(), 0.0, 0.01),
- () -> assertEquals(secondPose.getRotation().getDegrees(), 0.0, 0.01)
- );
+ () -> assertEquals(secondPose.getRotation().getDegrees(), 0.0, 0.01));
}
@Test
@@ -52,10 +46,9 @@
var pose = m_odometry.updateWithTime(0.10, new Rotation2d(), wheelSpeeds);
assertAll(
- () -> assertEquals(5.0 / 10.0, pose.getX(), 0.01),
- () -> assertEquals(0, pose.getY(), 0.01),
- () -> assertEquals(0.0, pose.getRotation().getDegrees(), 0.01)
- );
+ () -> assertEquals(0.3536, pose.getX(), 0.01),
+ () -> assertEquals(0.0, pose.getY(), 0.01),
+ () -> assertEquals(0.0, pose.getRotation().getDegrees(), 0.01));
}
@Test
@@ -68,10 +61,9 @@
final var pose = m_odometry.updateWithTime(1.0, Rotation2d.fromDegrees(90.0), wheelSpeeds);
assertAll(
- () -> assertEquals(12.0, pose.getX(), 0.01),
- () -> assertEquals(12.0, pose.getY(), 0.01),
- () -> assertEquals(90.0, pose.getRotation().getDegrees(), 0.01)
- );
+ () -> assertEquals(8.4855, pose.getX(), 0.01),
+ () -> assertEquals(8.4855, pose.getY(), 0.01),
+ () -> assertEquals(90.0, pose.getRotation().getDegrees(), 0.01));
}
@Test
@@ -79,16 +71,13 @@
var gyro = Rotation2d.fromDegrees(90.0);
var fieldAngle = Rotation2d.fromDegrees(0.0);
m_odometry.resetPosition(new Pose2d(new Translation2d(), fieldAngle), gyro);
- var speeds = new MecanumDriveWheelSpeeds(3.536, 3.536,
- 3.536, 3.536);
+ var speeds = new MecanumDriveWheelSpeeds(3.536, 3.536, 3.536, 3.536);
m_odometry.updateWithTime(0.0, gyro, new MecanumDriveWheelSpeeds());
var pose = m_odometry.updateWithTime(1.0, gyro, speeds);
assertAll(
- () -> assertEquals(5.0, pose.getX(), 0.1),
- () -> assertEquals(0.00, pose.getY(), 0.1),
- () -> assertEquals(0.00, pose.getRotation().getRadians(), 0.1)
- );
+ () -> assertEquals(3.536, pose.getX(), 0.1),
+ () -> assertEquals(0.0, pose.getY(), 0.1),
+ () -> assertEquals(0.0, pose.getRotation().getRadians(), 0.1));
}
-
}
diff --git a/wpimath/src/test/java/edu/wpi/first/wpilibj/kinematics/SwerveDriveKinematicsTest.java b/wpimath/src/test/java/edu/wpi/first/math/kinematics/SwerveDriveKinematicsTest.java
similarity index 86%
rename from wpimath/src/test/java/edu/wpi/first/wpilibj/kinematics/SwerveDriveKinematicsTest.java
rename to wpimath/src/test/java/edu/wpi/first/math/kinematics/SwerveDriveKinematicsTest.java
index e9fbcd1..eec6aa2 100644
--- a/wpimath/src/test/java/edu/wpi/first/wpilibj/kinematics/SwerveDriveKinematicsTest.java
+++ b/wpimath/src/test/java/edu/wpi/first/math/kinematics/SwerveDriveKinematicsTest.java
@@ -1,20 +1,16 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
-package edu.wpi.first.wpilibj.kinematics;
-
-import org.junit.jupiter.api.Test;
-
-import edu.wpi.first.wpilibj.geometry.Rotation2d;
-import edu.wpi.first.wpilibj.geometry.Translation2d;
+package edu.wpi.first.math.kinematics;
import static org.junit.jupiter.api.Assertions.assertAll;
import static org.junit.jupiter.api.Assertions.assertEquals;
+import edu.wpi.first.math.geometry.Rotation2d;
+import edu.wpi.first.math.geometry.Translation2d;
+import org.junit.jupiter.api.Test;
+
class SwerveDriveKinematicsTest {
private static final double kEpsilon = 1E-9;
@@ -40,8 +36,7 @@
() -> assertEquals(0.0, moduleStates[0].angle.getRadians(), kEpsilon),
() -> assertEquals(0.0, moduleStates[1].angle.getRadians(), kEpsilon),
() -> assertEquals(0.0, moduleStates[2].angle.getRadians(), kEpsilon),
- () -> assertEquals(0.0, moduleStates[3].angle.getRadians(), kEpsilon)
- );
+ () -> assertEquals(0.0, moduleStates[3].angle.getRadians(), kEpsilon));
}
@Test
@@ -52,13 +47,11 @@
assertAll(
() -> assertEquals(0.0, chassisSpeeds.vxMetersPerSecond, kEpsilon),
() -> assertEquals(5.0, chassisSpeeds.vyMetersPerSecond, kEpsilon),
- () -> assertEquals(0.0, chassisSpeeds.omegaRadiansPerSecond, kEpsilon)
- );
+ () -> assertEquals(0.0, chassisSpeeds.omegaRadiansPerSecond, kEpsilon));
}
@Test
void testStraightStrafeInverseKinematics() {
-
ChassisSpeeds speeds = new ChassisSpeeds(0, 5, 0);
var moduleStates = m_kinematics.toSwerveModuleStates(speeds);
@@ -70,8 +63,7 @@
() -> assertEquals(90.0, moduleStates[0].angle.getDegrees(), kEpsilon),
() -> assertEquals(90.0, moduleStates[1].angle.getDegrees(), kEpsilon),
() -> assertEquals(90.0, moduleStates[2].angle.getDegrees(), kEpsilon),
- () -> assertEquals(90.0, moduleStates[3].angle.getDegrees(), kEpsilon)
- );
+ () -> assertEquals(90.0, moduleStates[3].angle.getDegrees(), kEpsilon));
}
@Test
@@ -82,13 +74,11 @@
assertAll(
() -> assertEquals(0.0, chassisSpeeds.vxMetersPerSecond, kEpsilon),
() -> assertEquals(5.0, chassisSpeeds.vyMetersPerSecond, kEpsilon),
- () -> assertEquals(0.0, chassisSpeeds.omegaRadiansPerSecond, kEpsilon)
- );
+ () -> assertEquals(0.0, chassisSpeeds.omegaRadiansPerSecond, kEpsilon));
}
@Test
void testTurnInPlaceInverseKinematics() {
-
ChassisSpeeds speeds = new ChassisSpeeds(0, 0, 2 * Math.PI);
var moduleStates = m_kinematics.toSwerveModuleStates(speeds);
@@ -107,8 +97,7 @@
() -> assertEquals(135.0, moduleStates[0].angle.getDegrees(), kEpsilon),
() -> assertEquals(45.0, moduleStates[1].angle.getDegrees(), kEpsilon),
() -> assertEquals(-135.0, moduleStates[2].angle.getDegrees(), kEpsilon),
- () -> assertEquals(-45.0, moduleStates[3].angle.getDegrees(), kEpsilon)
- );
+ () -> assertEquals(-45.0, moduleStates[3].angle.getDegrees(), kEpsilon));
}
@Test
@@ -123,13 +112,11 @@
assertAll(
() -> assertEquals(0.0, chassisSpeeds.vxMetersPerSecond, kEpsilon),
() -> assertEquals(0.0, chassisSpeeds.vyMetersPerSecond, kEpsilon),
- () -> assertEquals(2 * Math.PI, chassisSpeeds.omegaRadiansPerSecond, 0.1)
- );
+ () -> assertEquals(2 * Math.PI, chassisSpeeds.omegaRadiansPerSecond, 0.1));
}
@Test
void testOffCenterCORRotationInverseKinematics() {
-
ChassisSpeeds speeds = new ChassisSpeeds(0, 0, 2 * Math.PI);
var moduleStates = m_kinematics.toSwerveModuleStates(speeds, m_fl);
@@ -150,8 +137,7 @@
() -> assertEquals(0.0, moduleStates[0].angle.getDegrees(), kEpsilon),
() -> assertEquals(0.0, moduleStates[1].angle.getDegrees(), kEpsilon),
() -> assertEquals(-90.0, moduleStates[2].angle.getDegrees(), kEpsilon),
- () -> assertEquals(-45.0, moduleStates[3].angle.getDegrees(), kEpsilon)
- );
+ () -> assertEquals(-45.0, moduleStates[3].angle.getDegrees(), kEpsilon));
}
@Test
@@ -175,38 +161,42 @@
assertAll(
() -> assertEquals(75.398, chassisSpeeds.vxMetersPerSecond, 0.1),
() -> assertEquals(-75.398, chassisSpeeds.vyMetersPerSecond, 0.1),
- () -> assertEquals(2 * Math.PI, chassisSpeeds.omegaRadiansPerSecond, 0.1)
- );
+ () -> assertEquals(2 * Math.PI, chassisSpeeds.omegaRadiansPerSecond, 0.1));
}
- private void assertModuleState(SwerveModuleState expected, SwerveModuleState actual,
- SwerveModuleState tolerance) {
+ private void assertModuleState(
+ SwerveModuleState expected, SwerveModuleState actual, SwerveModuleState tolerance) {
assertAll(
- () -> assertEquals(expected.speedMetersPerSecond, actual.speedMetersPerSecond,
- tolerance.speedMetersPerSecond),
- () -> assertEquals(expected.angle.getDegrees(), actual.angle.getDegrees(),
- tolerance.angle.getDegrees())
- );
+ () ->
+ assertEquals(
+ expected.speedMetersPerSecond,
+ actual.speedMetersPerSecond,
+ tolerance.speedMetersPerSecond),
+ () ->
+ assertEquals(
+ expected.angle.getDegrees(),
+ actual.angle.getDegrees(),
+ tolerance.angle.getDegrees()));
}
/**
- * Test the rotation of the robot about a non-central point with
- * both linear and angular velocities.
+ * Test the rotation of the robot about a non-central point with both linear and angular
+ * velocities.
*/
@Test
void testOffCenterCORRotationAndTranslationInverseKinematics() {
-
ChassisSpeeds speeds = new ChassisSpeeds(0.0, 3.0, 1.5);
var moduleStates = m_kinematics.toSwerveModuleStates(speeds, new Translation2d(24, 0));
// By equation (13.14) from state-space guide, our wheels/angles will be as follows,
// (+-1 degree or speed):
- SwerveModuleState[] expectedStates = new SwerveModuleState[]{
- new SwerveModuleState(23.43, Rotation2d.fromDegrees(-140.19)),
- new SwerveModuleState(23.43, Rotation2d.fromDegrees(-39.81)),
- new SwerveModuleState(54.08, Rotation2d.fromDegrees(-109.44)),
- new SwerveModuleState(54.08, Rotation2d.fromDegrees(-70.56))
- };
+ SwerveModuleState[] expectedStates =
+ new SwerveModuleState[] {
+ new SwerveModuleState(23.43, Rotation2d.fromDegrees(-140.19)),
+ new SwerveModuleState(23.43, Rotation2d.fromDegrees(-39.81)),
+ new SwerveModuleState(54.08, Rotation2d.fromDegrees(-109.44)),
+ new SwerveModuleState(54.08, Rotation2d.fromDegrees(-70.56))
+ };
var stateTolerance = new SwerveModuleState(0.1, Rotation2d.fromDegrees(0.1));
for (int i = 0; i < expectedStates.length; i++) {
@@ -235,8 +225,7 @@
assertAll(
() -> assertEquals(0.0, chassisSpeeds.vxMetersPerSecond, 0.1),
() -> assertEquals(-33.0, chassisSpeeds.vyMetersPerSecond, 0.1),
- () -> assertEquals(1.5, chassisSpeeds.omegaRadiansPerSecond, 0.1)
- );
+ () -> assertEquals(1.5, chassisSpeeds.omegaRadiansPerSecond, 0.1));
}
@Test
@@ -255,8 +244,6 @@
() -> assertEquals(5.0 * factor, arr[0].speedMetersPerSecond, kEpsilon),
() -> assertEquals(6.0 * factor, arr[1].speedMetersPerSecond, kEpsilon),
() -> assertEquals(4.0 * factor, arr[2].speedMetersPerSecond, kEpsilon),
- () -> assertEquals(7.0 * factor, arr[3].speedMetersPerSecond, kEpsilon)
- );
+ () -> assertEquals(7.0 * factor, arr[3].speedMetersPerSecond, kEpsilon));
}
-
}
diff --git a/wpimath/src/test/java/edu/wpi/first/wpilibj/kinematics/SwerveDriveOdometryTest.java b/wpimath/src/test/java/edu/wpi/first/math/kinematics/SwerveDriveOdometryTest.java
similarity index 62%
rename from wpimath/src/test/java/edu/wpi/first/wpilibj/kinematics/SwerveDriveOdometryTest.java
rename to wpimath/src/test/java/edu/wpi/first/math/kinematics/SwerveDriveOdometryTest.java
index f1ee907..cb6dfdf 100644
--- a/wpimath/src/test/java/edu/wpi/first/wpilibj/kinematics/SwerveDriveOdometryTest.java
+++ b/wpimath/src/test/java/edu/wpi/first/math/kinematics/SwerveDriveOdometryTest.java
@@ -1,21 +1,17 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
-package edu.wpi.first.wpilibj.kinematics;
-
-import org.junit.jupiter.api.Test;
-
-import edu.wpi.first.wpilibj.geometry.Pose2d;
-import edu.wpi.first.wpilibj.geometry.Rotation2d;
-import edu.wpi.first.wpilibj.geometry.Translation2d;
+package edu.wpi.first.math.kinematics;
import static org.junit.jupiter.api.Assertions.assertAll;
import static org.junit.jupiter.api.Assertions.assertEquals;
+import edu.wpi.first.math.geometry.Pose2d;
+import edu.wpi.first.math.geometry.Rotation2d;
+import edu.wpi.first.math.geometry.Translation2d;
+import org.junit.jupiter.api.Test;
+
class SwerveDriveOdometryTest {
private final Translation2d m_fl = new Translation2d(12, 12);
private final Translation2d m_fr = new Translation2d(12, -12);
@@ -25,29 +21,32 @@
private final SwerveDriveKinematics m_kinematics =
new SwerveDriveKinematics(m_fl, m_fr, m_bl, m_br);
- private final SwerveDriveOdometry m_odometry = new SwerveDriveOdometry(m_kinematics,
- new Rotation2d());
+ private final SwerveDriveOdometry m_odometry =
+ new SwerveDriveOdometry(m_kinematics, new Rotation2d());
@Test
void testTwoIterations() {
// 5 units/sec in the x axis (forward)
final SwerveModuleState[] wheelSpeeds = {
- new SwerveModuleState(5, Rotation2d.fromDegrees(0)),
- new SwerveModuleState(5, Rotation2d.fromDegrees(0)),
- new SwerveModuleState(5, Rotation2d.fromDegrees(0)),
- new SwerveModuleState(5, Rotation2d.fromDegrees(0))
+ new SwerveModuleState(5, Rotation2d.fromDegrees(0)),
+ new SwerveModuleState(5, Rotation2d.fromDegrees(0)),
+ new SwerveModuleState(5, Rotation2d.fromDegrees(0)),
+ new SwerveModuleState(5, Rotation2d.fromDegrees(0))
};
- m_odometry.updateWithTime(0.0, new Rotation2d(),
- new SwerveModuleState(), new SwerveModuleState(),
- new SwerveModuleState(), new SwerveModuleState());
+ m_odometry.updateWithTime(
+ 0.0,
+ new Rotation2d(),
+ new SwerveModuleState(),
+ new SwerveModuleState(),
+ new SwerveModuleState(),
+ new SwerveModuleState());
var pose = m_odometry.updateWithTime(0.10, new Rotation2d(), wheelSpeeds);
assertAll(
() -> assertEquals(5.0 / 10.0, pose.getX(), 0.01),
() -> assertEquals(0, pose.getY(), 0.01),
- () -> assertEquals(0.0, pose.getRotation().getDegrees(), 0.01)
- );
+ () -> assertEquals(0.0, pose.getRotation().getDegrees(), 0.01));
}
@Test
@@ -59,10 +58,10 @@
// Module 3: speed 42.14888838624436 angle -26.565051177077986
final SwerveModuleState[] wheelSpeeds = {
- new SwerveModuleState(18.85, Rotation2d.fromDegrees(90.0)),
- new SwerveModuleState(42.15, Rotation2d.fromDegrees(26.565)),
- new SwerveModuleState(18.85, Rotation2d.fromDegrees(-90)),
- new SwerveModuleState(42.15, Rotation2d.fromDegrees(-26.565))
+ new SwerveModuleState(18.85, Rotation2d.fromDegrees(90.0)),
+ new SwerveModuleState(42.15, Rotation2d.fromDegrees(26.565)),
+ new SwerveModuleState(18.85, Rotation2d.fromDegrees(-90)),
+ new SwerveModuleState(42.15, Rotation2d.fromDegrees(-26.565))
};
final var zero = new SwerveModuleState();
@@ -72,8 +71,7 @@
assertAll(
() -> assertEquals(12.0, pose.getX(), 0.01),
() -> assertEquals(12.0, pose.getY(), 0.01),
- () -> assertEquals(90.0, pose.getRotation().getDegrees(), 0.01)
- );
+ () -> assertEquals(90.0, pose.getRotation().getDegrees(), 0.01));
}
@Test
@@ -89,8 +87,6 @@
assertAll(
() -> assertEquals(1.0, pose.getX(), 0.1),
() -> assertEquals(0.00, pose.getY(), 0.1),
- () -> assertEquals(0.00, pose.getRotation().getRadians(), 0.1)
- );
+ () -> assertEquals(0.00, pose.getRotation().getRadians(), 0.1));
}
-
}
diff --git a/wpimath/src/test/java/edu/wpi/first/math/kinematics/SwerveModuleStateTest.java b/wpimath/src/test/java/edu/wpi/first/math/kinematics/SwerveModuleStateTest.java
new file mode 100644
index 0000000..01815be
--- /dev/null
+++ b/wpimath/src/test/java/edu/wpi/first/math/kinematics/SwerveModuleStateTest.java
@@ -0,0 +1,53 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.kinematics;
+
+import static org.junit.jupiter.api.Assertions.assertAll;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+import edu.wpi.first.math.geometry.Rotation2d;
+import org.junit.jupiter.api.Test;
+
+class SwerveModuleStateTest {
+ private static final double kEpsilon = 1E-9;
+
+ @Test
+ void testOptimize() {
+ var angleA = Rotation2d.fromDegrees(45);
+ var refA = new SwerveModuleState(-2.0, Rotation2d.fromDegrees(180));
+ var optimizedA = SwerveModuleState.optimize(refA, angleA);
+
+ assertAll(
+ () -> assertEquals(2.0, optimizedA.speedMetersPerSecond, kEpsilon),
+ () -> assertEquals(0.0, optimizedA.angle.getDegrees(), kEpsilon));
+
+ var angleB = Rotation2d.fromDegrees(-50);
+ var refB = new SwerveModuleState(4.7, Rotation2d.fromDegrees(41));
+ var optimizedB = SwerveModuleState.optimize(refB, angleB);
+
+ assertAll(
+ () -> assertEquals(-4.7, optimizedB.speedMetersPerSecond, kEpsilon),
+ () -> assertEquals(-139.0, optimizedB.angle.getDegrees(), kEpsilon));
+ }
+
+ @Test
+ void testNoOptimize() {
+ var angleA = Rotation2d.fromDegrees(0);
+ var refA = new SwerveModuleState(2.0, Rotation2d.fromDegrees(89));
+ var optimizedA = SwerveModuleState.optimize(refA, angleA);
+
+ assertAll(
+ () -> assertEquals(2.0, optimizedA.speedMetersPerSecond, kEpsilon),
+ () -> assertEquals(89.0, optimizedA.angle.getDegrees(), kEpsilon));
+
+ var angleB = Rotation2d.fromDegrees(0);
+ var refB = new SwerveModuleState(-2.0, Rotation2d.fromDegrees(-2));
+ var optimizedB = SwerveModuleState.optimize(refB, angleB);
+
+ assertAll(
+ () -> assertEquals(-2.0, optimizedB.speedMetersPerSecond, kEpsilon),
+ () -> assertEquals(-2.0, optimizedB.angle.getDegrees(), kEpsilon));
+ }
+}
diff --git a/wpimath/src/test/java/edu/wpi/first/math/spline/CubicHermiteSplineTest.java b/wpimath/src/test/java/edu/wpi/first/math/spline/CubicHermiteSplineTest.java
new file mode 100644
index 0000000..dd08e45
--- /dev/null
+++ b/wpimath/src/test/java/edu/wpi/first/math/spline/CubicHermiteSplineTest.java
@@ -0,0 +1,162 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.spline;
+
+import static org.junit.jupiter.api.Assertions.assertAll;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import edu.wpi.first.math.geometry.Pose2d;
+import edu.wpi.first.math.geometry.Rotation2d;
+import edu.wpi.first.math.geometry.Translation2d;
+import edu.wpi.first.math.spline.SplineParameterizer.MalformedSplineException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import org.junit.jupiter.api.Test;
+
+class CubicHermiteSplineTest {
+  private static final double kMaxDx = 0.127;
+  private static final double kMaxDy = 0.00127;
+  private static final double kMaxDtheta = 0.0872;
+
+  /**
+   * Generates and parameterizes cubic splines from {@code a} through {@code waypoints} to {@code
+   * b}, then checks the twist between consecutive samples, both endpoints, and every waypoint.
+   */
+  @SuppressWarnings("ParameterName")
+  private void run(Pose2d a, List<Translation2d> waypoints, Pose2d b) {
+    // Start the timer.
+    // var start = System.nanoTime();
+
+    // Generate and parameterize the spline.
+    var interiorWaypoints = waypoints.toArray(new Translation2d[0]);
+    var controlVectors =
+        SplineHelper.getCubicControlVectorsFromWaypoints(a, interiorWaypoints, b);
+    var splines =
+        SplineHelper.getCubicSplinesFromControlVectors(
+            controlVectors[0], interiorWaypoints, controlVectors[1]);
+
+    var poses = new ArrayList<PoseWithCurvature>();
+
+    poses.add(splines[0].getPoint(0.0));
+
+    for (var spline : splines) {
+      poses.addAll(SplineParameterizer.parameterize(spline));
+    }
+
+    // End the timer.
+    // var end = System.nanoTime();
+
+    // Calculate the duration (used when benchmarking)
+    // var durationMicroseconds = (end - start) / 1000.0;
+
+    for (int i = 0; i < poses.size() - 1; i++) {
+      var p0 = poses.get(i);
+      var p1 = poses.get(i + 1);
+
+      // Make sure the twist is under the tolerance defined by the Spline class.
+      var twist = p0.poseMeters.log(p1.poseMeters);
+      assertAll(
+          () -> assertTrue(Math.abs(twist.dx) < kMaxDx),
+          () -> assertTrue(Math.abs(twist.dy) < kMaxDy),
+          () -> assertTrue(Math.abs(twist.dtheta) < kMaxDtheta));
+    }
+
+    // Check first point
+    assertAll(
+        () -> assertEquals(a.getX(), poses.get(0).poseMeters.getX(), 1E-9),
+        () -> assertEquals(a.getY(), poses.get(0).poseMeters.getY(), 1E-9),
+        () ->
+            assertEquals(
+                a.getRotation().getRadians(),
+                poses.get(0).poseMeters.getRotation().getRadians(),
+                1E-9));
+
+    // Check interior waypoints: each must coincide with some sample. Asserting per waypoint names
+    // the missing one on failure, and uses the same 1E-9 tolerance as the endpoint checks rather
+    // than exact floating-point equality.
+    for (var waypoint : waypoints) {
+      assertTrue(
+          poses.stream()
+              .anyMatch(state -> waypoint.getDistance(state.poseMeters.getTranslation()) < 1E-9),
+          "Interior waypoint " + waypoint + " was not found in the parameterized spline");
+    }
+
+    // Check last point
+    assertAll(
+        () -> assertEquals(b.getX(), poses.get(poses.size() - 1).poseMeters.getX(), 1E-9),
+        () -> assertEquals(b.getY(), poses.get(poses.size() - 1).poseMeters.getY(), 1E-9),
+        () ->
+            assertEquals(
+                b.getRotation().getRadians(),
+                poses.get(poses.size() - 1).poseMeters.getRotation().getRadians(),
+                1E-9));
+  }
+
+  @SuppressWarnings("PMD.JUnitTestsShouldIncludeAssert")
+  @Test
+  void testStraightLine() {
+    run(new Pose2d(), new ArrayList<>(), new Pose2d(3, 0, new Rotation2d()));
+  }
+
+  @SuppressWarnings("PMD.JUnitTestsShouldIncludeAssert")
+  @Test
+  void testSCurve() {
+    var start = new Pose2d(0, 0, Rotation2d.fromDegrees(90.0));
+    ArrayList<Translation2d> waypoints = new ArrayList<>();
+    waypoints.add(new Translation2d(1, 1));
+    waypoints.add(new Translation2d(2, -1));
+    var end = new Pose2d(3, 0, Rotation2d.fromDegrees(90.0));
+
+    run(start, waypoints, end);
+  }
+
+  @SuppressWarnings("PMD.JUnitTestsShouldIncludeAssert")
+  @Test
+  void testOneInterior() {
+    var start = new Pose2d(0, 0, Rotation2d.fromDegrees(0.0));
+    ArrayList<Translation2d> waypoints = new ArrayList<>();
+    waypoints.add(new Translation2d(2.0, 0.0));
+    var end = new Pose2d(4, 0, Rotation2d.fromDegrees(0.0));
+
+    run(start, waypoints, end);
+  }
+
+  @SuppressWarnings("PMD.JUnitTestsShouldIncludeAssert")
+  @Test
+  void testWindyPath() {
+    final var start = new Pose2d(0, 0, Rotation2d.fromDegrees(0.0));
+    final ArrayList<Translation2d> waypoints = new ArrayList<>();
+    waypoints.add(new Translation2d(0.5, 0.5));
+    waypoints.add(new Translation2d(0.5, 0.5));
+    waypoints.add(new Translation2d(1.0, 0.0));
+    waypoints.add(new Translation2d(1.5, 0.5));
+    waypoints.add(new Translation2d(2.0, 0.0));
+    waypoints.add(new Translation2d(2.5, 0.5));
+    final var end = new Pose2d(3.0, 0.0, Rotation2d.fromDegrees(0.0));
+
+    run(start, waypoints, end);
+  }
+
+  @Test
+  void testMalformed() {
+    assertThrows(
+        MalformedSplineException.class,
+        () ->
+            run(
+                new Pose2d(0, 0, Rotation2d.fromDegrees(0)),
+                new ArrayList<>(),
+                new Pose2d(1, 0, Rotation2d.fromDegrees(180))));
+    assertThrows(
+        MalformedSplineException.class,
+        () ->
+            run(
+                new Pose2d(10, 10, Rotation2d.fromDegrees(90)),
+                Arrays.asList(new Translation2d(10, 10.5)),
+                new Pose2d(10, 11, Rotation2d.fromDegrees(-90))));
+  }
+}
diff --git a/wpimath/src/test/java/edu/wpi/first/math/spline/QuinticHermiteSplineTest.java b/wpimath/src/test/java/edu/wpi/first/math/spline/QuinticHermiteSplineTest.java
new file mode 100644
index 0000000..8367070
--- /dev/null
+++ b/wpimath/src/test/java/edu/wpi/first/math/spline/QuinticHermiteSplineTest.java
@@ -0,0 +1,106 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.spline;
+
+import static org.junit.jupiter.api.Assertions.assertAll;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import edu.wpi.first.math.geometry.Pose2d;
+import edu.wpi.first.math.geometry.Rotation2d;
+import edu.wpi.first.math.spline.SplineParameterizer.MalformedSplineException;
+import java.util.List;
+import org.junit.jupiter.api.Test;
+
+class QuinticHermiteSplineTest {
+ private static final double kMaxDx = 0.127;
+ private static final double kMaxDy = 0.00127;
+ private static final double kMaxDtheta = 0.0872;
+
+ @SuppressWarnings("ParameterName")
+ private void run(Pose2d a, Pose2d b) {
+ // Start the timer.
+ // var start = System.nanoTime();
+
+ // Generate and parameterize the spline.
+ var spline = SplineHelper.getQuinticSplinesFromWaypoints(List.of(a, b))[0];
+ var poses = SplineParameterizer.parameterize(spline);
+
+ // End the timer.
+ // var end = System.nanoTime();
+
+ // Calculate the duration (used when benchmarking)
+ // var durationMicroseconds = (end - start) / 1000.0;
+
+ for (int i = 0; i < poses.size() - 1; i++) {
+ var p0 = poses.get(i);
+ var p1 = poses.get(i + 1);
+
+ // Make sure the twist is under the tolerance defined by the Spline class.
+ var twist = p0.poseMeters.log(p1.poseMeters);
+ assertAll(
+ () -> assertTrue(Math.abs(twist.dx) < kMaxDx),
+ () -> assertTrue(Math.abs(twist.dy) < kMaxDy),
+ () -> assertTrue(Math.abs(twist.dtheta) < kMaxDtheta));
+ }
+
+ // Check first point
+ assertAll(
+ () -> assertEquals(a.getX(), poses.get(0).poseMeters.getX(), 1E-9),
+ () -> assertEquals(a.getY(), poses.get(0).poseMeters.getY(), 1E-9),
+ () ->
+ assertEquals(
+ a.getRotation().getRadians(),
+ poses.get(0).poseMeters.getRotation().getRadians(),
+ 1E-9));
+
+ // Check last point
+ assertAll(
+ () -> assertEquals(b.getX(), poses.get(poses.size() - 1).poseMeters.getX(), 1E-9),
+ () -> assertEquals(b.getY(), poses.get(poses.size() - 1).poseMeters.getY(), 1E-9),
+ () ->
+ assertEquals(
+ b.getRotation().getRadians(),
+ poses.get(poses.size() - 1).poseMeters.getRotation().getRadians(),
+ 1E-9));
+ }
+
+ @SuppressWarnings("PMD.JUnitTestsShouldIncludeAssert")
+ @Test
+ void testStraightLine() {
+ run(new Pose2d(), new Pose2d(3, 0, new Rotation2d()));
+ }
+
+ @SuppressWarnings("PMD.JUnitTestsShouldIncludeAssert")
+ @Test
+ void testSimpleSCurve() {
+ run(new Pose2d(), new Pose2d(1, 1, new Rotation2d()));
+ }
+
+ @SuppressWarnings("PMD.JUnitTestsShouldIncludeAssert")
+ @Test
+ void testSquiggly() {
+ run(
+ new Pose2d(0, 0, Rotation2d.fromDegrees(90)),
+ new Pose2d(-1, 0, Rotation2d.fromDegrees(90)));
+ }
+
+ @Test
+ void testMalformed() {
+ assertThrows(
+ MalformedSplineException.class,
+ () ->
+ run(
+ new Pose2d(0, 0, Rotation2d.fromDegrees(0)),
+ new Pose2d(1, 0, Rotation2d.fromDegrees(180))));
+ assertThrows(
+ MalformedSplineException.class,
+ () ->
+ run(
+ new Pose2d(10, 10, Rotation2d.fromDegrees(90)),
+ new Pose2d(10, 11, Rotation2d.fromDegrees(-90))));
+ }
+}
diff --git a/wpimath/src/test/java/edu/wpi/first/math/system/DiscretizationTest.java b/wpimath/src/test/java/edu/wpi/first/math/system/DiscretizationTest.java
new file mode 100644
index 0000000..add1afb
--- /dev/null
+++ b/wpimath/src/test/java/edu/wpi/first/math/system/DiscretizationTest.java
@@ -0,0 +1,217 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.system;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import edu.wpi.first.math.MatBuilder;
+import edu.wpi.first.math.Matrix;
+import edu.wpi.first.math.Nat;
+import edu.wpi.first.math.VecBuilder;
+import edu.wpi.first.math.numbers.N2;
+import org.junit.jupiter.api.Test;
+
+public class DiscretizationTest {
+ // Check that for a simple second-order system that we can easily analyze
+ // analytically,
+ @Test
+ public void testDiscretizeA() {
+ final var contA = new MatBuilder<>(Nat.N2(), Nat.N2()).fill(0, 1, 0, 0);
+ final var x0 = VecBuilder.fill(1, 1);
+
+ final var discA = Discretization.discretizeA(contA, 1.0);
+ final var x1Discrete = discA.times(x0);
+
+ // We now have pos = vel = 1 and accel = 0, which should give us:
+ final var x1Truth =
+ VecBuilder.fill(
+ 1.0 * x0.get(0, 0) + 1.0 * x0.get(1, 0), 0.0 * x0.get(0, 0) + 1.0 * x0.get(1, 0));
+
+ assertEquals(x1Truth, x1Discrete);
+ }
+
+ // Check that for a simple second-order system that we can easily analyze
+ // analytically,
+ @Test
+ public void testDiscretizeAB() {
+ final var contA = new MatBuilder<>(Nat.N2(), Nat.N2()).fill(0, 1, 0, 0);
+ final var contB = new MatBuilder<>(Nat.N2(), Nat.N1()).fill(0, 1);
+
+ final var x0 = VecBuilder.fill(1, 1);
+ final var u = VecBuilder.fill(1);
+
+ var discABPair = Discretization.discretizeAB(contA, contB, 1.0);
+ var discA = discABPair.getFirst();
+ var discB = discABPair.getSecond();
+
+ var x1Discrete = discA.times(x0).plus(discB.times(u));
+
+ // We now have pos = vel = accel = 1, which should give us:
+ final var x1Truth =
+ VecBuilder.fill(
+ 1.0 * x0.get(0, 0) + 1.0 * x0.get(1, 0) + 0.5 * u.get(0, 0),
+ 0.0 * x0.get(0, 0) + 1.0 * x0.get(1, 0) + 1.0 * u.get(0, 0));
+
+ assertEquals(x1Truth, x1Discrete);
+ }
+
+ // dt
+ // Test that the discrete approximation of Q ≈ ∫ e^(Aτ) Q e^(Aᵀτ) dτ
+ // 0
+ @Test
+ public void testDiscretizeSlowModelAQ() {
+ final var contA = new MatBuilder<>(Nat.N2(), Nat.N2()).fill(0, 1, 0, 0);
+ final var contQ = new MatBuilder<>(Nat.N2(), Nat.N2()).fill(1, 0, 0, 1);
+
+ final double dt = 1.0;
+
+ final var discQIntegrated =
+ RungeKuttaTimeVarying.rungeKuttaTimeVarying(
+ (Double t, Matrix<N2, N2> x) ->
+ contA.times(t).exp().times(contQ).times(contA.transpose().times(t).exp()),
+ 0.0,
+ new Matrix<>(Nat.N2(), Nat.N2()),
+ dt);
+
+ var discAQPair = Discretization.discretizeAQ(contA, contQ, dt);
+ var discQ = discAQPair.getSecond();
+
+ assertTrue(
+ discQIntegrated.minus(discQ).normF() < 1e-10,
+ "Expected these to be nearly equal:\ndiscQ:\n"
+ + discQ
+ + "\ndiscQIntegrated:\n"
+ + discQIntegrated);
+ }
+
+ // dt
+ // Test that the discrete approximation of Q ≈ ∫ e^(Aτ) Q e^(Aᵀτ) dτ
+ // 0
+ @Test
+ public void testDiscretizeFastModelAQ() {
+ final var contA = new MatBuilder<>(Nat.N2(), Nat.N2()).fill(0, 1, 0, -1406.29);
+ final var contQ = new MatBuilder<>(Nat.N2(), Nat.N2()).fill(0.0025, 0, 0, 1);
+
+ final var dt = 0.005;
+
+ final var discQIntegrated =
+ RungeKuttaTimeVarying.rungeKuttaTimeVarying(
+ (Double t, Matrix<N2, N2> x) ->
+ contA.times(t).exp().times(contQ).times(contA.transpose().times(t).exp()),
+ 0.0,
+ new Matrix<>(Nat.N2(), Nat.N2()),
+ dt);
+
+ var discAQPair = Discretization.discretizeAQ(contA, contQ, dt);
+ var discQ = discAQPair.getSecond();
+
+ assertTrue(
+ discQIntegrated.minus(discQ).normF() < 1e-3,
+ "Expected these to be nearly equal:\ndiscQ:\n"
+ + discQ
+ + "\ndiscQIntegrated:\n"
+ + discQIntegrated);
+ }
+
+ // Test that the Taylor series discretization produces nearly identical results.
+ @Test
+ public void testDiscretizeSlowModelAQTaylor() {
+ final var contA = new MatBuilder<>(Nat.N2(), Nat.N2()).fill(0, 1, 0, 0);
+ final var contQ = new MatBuilder<>(Nat.N2(), Nat.N2()).fill(1, 0, 0, 1);
+
+ final var dt = 1.0;
+
+ // Continuous Q should be positive semidefinite
+ final var esCont = contQ.getStorage().eig();
+ for (int i = 0; i < contQ.getNumRows(); ++i) {
+ assertTrue(esCont.getEigenvalue(i).real >= 0);
+ }
+
+ final var discQIntegrated =
+ RungeKuttaTimeVarying.rungeKuttaTimeVarying(
+ (Double t, Matrix<N2, N2> x) ->
+ contA.times(t).exp().times(contQ).times(contA.transpose().times(t).exp()),
+ 0.0,
+ new Matrix<>(Nat.N2(), Nat.N2()),
+ dt);
+
+ var discA = Discretization.discretizeA(contA, dt);
+
+ var discAQPair = Discretization.discretizeAQ(contA, contQ, dt);
+ var discATaylor = discAQPair.getFirst();
+ var discQTaylor = discAQPair.getSecond();
+
+ assertTrue(
+ discQIntegrated.minus(discQTaylor).normF() < 1e-10,
+ "Expected these to be nearly equal:\ndiscQTaylor:\n"
+ + discQTaylor
+ + "\ndiscQIntegrated:\n"
+ + discQIntegrated);
+ assertTrue(discA.minus(discATaylor).normF() < 1e-10);
+
+ // Discrete Q should be positive semidefinite
+ final var esDisc = discQTaylor.getStorage().eig();
+ for (int i = 0; i < discQTaylor.getNumRows(); ++i) {
+ assertTrue(esDisc.getEigenvalue(i).real >= 0);
+ }
+ }
+
+ // Test that the Taylor series discretization produces nearly identical results.
+ @Test
+ public void testDiscretizeFastModelAQTaylor() {
+ final var contA = new MatBuilder<>(Nat.N2(), Nat.N2()).fill(0, 1, 0, -1500);
+ final var contQ = new MatBuilder<>(Nat.N2(), Nat.N2()).fill(0.0025, 0, 0, 1);
+
+ final var dt = 0.005;
+
+ // Continuous Q should be positive semidefinite
+ final var esCont = contQ.getStorage().eig();
+ for (int i = 0; i < contQ.getNumRows(); ++i) {
+ assertTrue(esCont.getEigenvalue(i).real >= 0);
+ }
+
+ final var discQIntegrated =
+ RungeKuttaTimeVarying.rungeKuttaTimeVarying(
+ (Double t, Matrix<N2, N2> x) ->
+ contA.times(t).exp().times(contQ).times(contA.transpose().times(t).exp()),
+ 0.0,
+ new Matrix<>(Nat.N2(), Nat.N2()),
+ dt);
+
+ var discA = Discretization.discretizeA(contA, dt);
+
+ var discAQPair = Discretization.discretizeAQ(contA, contQ, dt);
+ var discATaylor = discAQPair.getFirst();
+ var discQTaylor = discAQPair.getSecond();
+
+ assertTrue(
+ discQIntegrated.minus(discQTaylor).normF() < 1e-3,
+ "Expected these to be nearly equal:\ndiscQTaylor:\n"
+ + discQTaylor
+ + "\ndiscQIntegrated:\n"
+ + discQIntegrated);
+ assertTrue(discA.minus(discATaylor).normF() < 1e-10);
+
+ // Discrete Q should be positive semidefinite
+ final var esDisc = discQTaylor.getStorage().eig();
+ for (int i = 0; i < discQTaylor.getNumRows(); ++i) {
+ assertTrue(esDisc.getEigenvalue(i).real >= 0);
+ }
+ }
+
+ // Test that DiscretizeR() works
+ @Test
+ public void testDiscretizeR() {
+ var contR = Matrix.mat(Nat.N2(), Nat.N2()).fill(2.0, 0.0, 0.0, 1.0);
+ var discRTruth = Matrix.mat(Nat.N2(), Nat.N2()).fill(4.0, 0.0, 0.0, 2.0);
+
+ var discR = Discretization.discretizeR(contR, 0.5);
+
+ assertTrue(
+ discRTruth.minus(discR).normF() < 1e-10,
+ "Expected these to be nearly equal:\ndiscR:\n" + discR + "\ndiscRTruth:\n" + discRTruth);
+ }
+}
diff --git a/wpimath/src/test/java/edu/wpi/first/math/system/LinearSystemIDTest.java b/wpimath/src/test/java/edu/wpi/first/math/system/LinearSystemIDTest.java
new file mode 100644
index 0000000..37cb1ec
--- /dev/null
+++ b/wpimath/src/test/java/edu/wpi/first/math/system/LinearSystemIDTest.java
@@ -0,0 +1,91 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.system;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import edu.wpi.first.math.Matrix;
+import edu.wpi.first.math.Nat;
+import edu.wpi.first.math.VecBuilder;
+import edu.wpi.first.math.system.plant.DCMotor;
+import edu.wpi.first.math.system.plant.LinearSystemId;
+import org.junit.jupiter.api.Test;
+
+class LinearSystemIDTest {
+  @Test
+  public void testDrivetrainVelocitySystem() {
+    var model =
+        LinearSystemId.createDrivetrainVelocitySystem(DCMotor.getNEO(4), 70, 0.05, 0.4, 6.0, 6);
+    assertTrue(
+        model
+            .getA()
+            .isEqual(
+                Matrix.mat(Nat.N2(), Nat.N2()).fill(-10.14132, 3.06598, 3.06598, -10.14132),
+                0.001));
+
+    assertTrue(
+        model
+            .getB()
+            .isEqual(
+                Matrix.mat(Nat.N2(), Nat.N2()).fill(4.2590, -1.28762, -1.2876, 4.2590), 0.001));
+
+    assertTrue(
+        model.getC().isEqual(Matrix.mat(Nat.N2(), Nat.N2()).fill(1.0, 0.0, 0.0, 1.0), 0.001));
+
+    assertTrue(
+        model.getD().isEqual(Matrix.mat(Nat.N2(), Nat.N2()).fill(0.0, 0.0, 0.0, 0.0), 0.001));
+  }
+
+  @Test
+  public void testElevatorSystem() {
+    var model = LinearSystemId.createElevatorSystem(DCMotor.getNEO(2), 5, 0.05, 12);
+    assertTrue(
+        model.getA().isEqual(Matrix.mat(Nat.N2(), Nat.N2()).fill(0, 1, 0, -99.05473), 0.001));
+
+    assertTrue(model.getB().isEqual(VecBuilder.fill(0, 20.8), 0.001));
+
+    assertTrue(model.getC().isEqual(Matrix.mat(Nat.N1(), Nat.N2()).fill(1, 0), 0.001));
+
+    assertTrue(model.getD().isEqual(VecBuilder.fill(0), 0.001));
+  }
+
+  @Test
+  public void testFlywheelSystem() {
+    var model = LinearSystemId.createFlywheelSystem(DCMotor.getNEO(2), 0.00032, 1.0);
+    assertTrue(model.getA().isEqual(VecBuilder.fill(-26.87032), 0.001));
+
+    assertTrue(model.getB().isEqual(VecBuilder.fill(1354.166667), 0.001));
+
+    assertTrue(model.getC().isEqual(VecBuilder.fill(1), 0.001));
+
+    assertTrue(model.getD().isEqual(VecBuilder.fill(0), 0.001));
+  }
+
+  @Test
+  public void testIdentifyPositionSystem() {
+    // From "Controls Engineering in FRC":
+    // x-dot = [0 1 | 0 -kv/ka] x + [0 | 1/ka] u
+    var kv = 1.0;
+    var ka = 0.5;
+    var model = LinearSystemId.identifyPositionSystem(kv, ka);
+
+    assertEquals(Matrix.mat(Nat.N2(), Nat.N2()).fill(0, 1, 0, -kv / ka), model.getA());
+    assertEquals(VecBuilder.fill(0, 1 / ka), model.getB());
+  }
+
+  @Test
+  public void testIdentifyVelocitySystem() {
+    // From "Controls Engineering in FRC":
+    // V = kv * velocity + ka * acceleration
+    // v-dot = -kv/ka * v + 1/ka * V
+    var kv = 1.0;
+    var ka = 0.5;
+    var model = LinearSystemId.identifyVelocitySystem(kv, ka);
+
+    assertEquals(VecBuilder.fill(-kv / ka), model.getA());
+    assertEquals(VecBuilder.fill(1 / ka), model.getB());
+  }
+}
diff --git a/wpimath/src/test/java/edu/wpi/first/math/system/NumericalIntegrationTest.java b/wpimath/src/test/java/edu/wpi/first/math/system/NumericalIntegrationTest.java
new file mode 100644
index 0000000..b3fe7e6
--- /dev/null
+++ b/wpimath/src/test/java/edu/wpi/first/math/system/NumericalIntegrationTest.java
@@ -0,0 +1,73 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.system;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+import edu.wpi.first.math.Matrix;
+import edu.wpi.first.math.Nat;
+import edu.wpi.first.math.VecBuilder;
+import edu.wpi.first.math.numbers.N1;
+import org.junit.jupiter.api.Test;
+
+public class NumericalIntegrationTest {
+  // Every test integrates the autonomous ODE dy/dt = e^y from y(0) = 0 over 0.1 s. Separating
+  // variables (e^-y dy = dt) gives the exact solution y(t) = -ln(1 - t).
+  private static final double kExpectedY = -Math.log(1.0 - 0.1);
+
+  @Test
+  public void testExponential() {
+    Matrix<N1, N1> y0 = VecBuilder.fill(0.0);
+
+    var y1 =
+        NumericalIntegration.rk4(
+            (Matrix<N1, N1> x) -> {
+              var y = new Matrix<>(Nat.N1(), Nat.N1());
+              y.set(0, 0, Math.exp(x.get(0, 0)));
+              return y;
+            },
+            y0,
+            0.1);
+
+    // One classical RK4 step of 0.1 s has ~3e-8 error here, so a tight tolerance is safe.
+    assertEquals(kExpectedY, y1.get(0, 0), 1e-6);
+  }
+
+  @Test
+  public void testExponentialRKF45() {
+    Matrix<N1, N1> y0 = VecBuilder.fill(0.0);
+
+    var y1 =
+        NumericalIntegration.rkf45(
+            (x, u) -> {
+              var y = new Matrix<>(Nat.N1(), Nat.N1());
+              y.set(0, 0, Math.exp(x.get(0, 0)));
+              return y;
+            },
+            y0,
+            VecBuilder.fill(0),
+            0.1);
+
+    assertEquals(kExpectedY, y1.get(0, 0), 1e-3);
+  }
+
+  @Test
+  public void testExponentialRKDP() {
+    Matrix<N1, N1> y0 = VecBuilder.fill(0.0);
+
+    var y1 =
+        NumericalIntegration.rkdp(
+            (x, u) -> {
+              var y = new Matrix<>(Nat.N1(), Nat.N1());
+              y.set(0, 0, Math.exp(x.get(0, 0)));
+              return y;
+            },
+            y0,
+            VecBuilder.fill(0),
+            0.1);
+
+    assertEquals(kExpectedY, y1.get(0, 0), 1e-3);
+  }
+}
diff --git a/wpimath/src/test/java/edu/wpi/first/math/system/RungeKuttaTimeVarying.java b/wpimath/src/test/java/edu/wpi/first/math/system/RungeKuttaTimeVarying.java
new file mode 100644
index 0000000..7b5e844
--- /dev/null
+++ b/wpimath/src/test/java/edu/wpi/first/math/system/RungeKuttaTimeVarying.java
@@ -0,0 +1,41 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.system;
+
+import edu.wpi.first.math.Matrix;
+import edu.wpi.first.math.Num;
+import java.util.function.BiFunction;
+
+public final class RungeKuttaTimeVarying {
+  private RungeKuttaTimeVarying() {
+    // Utility class
+  }
+
+  /**
+   * Performs 4th order Runge-Kutta integration of dy/dt = f(t, y) for dtSeconds.
+   *
+   * @param <Rows> Rows in y.
+   * @param <Cols> Columns in y.
+   * @param f The function to integrate. It must take two arguments t and y.
+   * @param t The initial value of t.
+   * @param y The initial value of y.
+   * @param dtSeconds The time over which to integrate.
+   * @return The value of y at t + dtSeconds, estimated with a single RK4 step.
+   */
+  @SuppressWarnings("MethodTypeParameterName")
+  public static <Rows extends Num, Cols extends Num> Matrix<Rows, Cols> rungeKuttaTimeVarying(
+      BiFunction<Double, Matrix<Rows, Cols>, Matrix<Rows, Cols>> f,
+      double t,
+      Matrix<Rows, Cols> y,
+      double dtSeconds) {
+    final var h = dtSeconds;
+    Matrix<Rows, Cols> k1 = f.apply(t, y);
+    Matrix<Rows, Cols> k2 = f.apply(t + h * 0.5, y.plus(k1.times(h * 0.5)));
+    Matrix<Rows, Cols> k3 = f.apply(t + h * 0.5, y.plus(k2.times(h * 0.5)));
+    Matrix<Rows, Cols> k4 = f.apply(t + h, y.plus(k3.times(h)));
+
+    return y.plus((k1.plus(k2.times(2.0)).plus(k3.times(2.0)).plus(k4)).times(h / 6.0));
+  }
+}
diff --git a/wpimath/src/test/java/edu/wpi/first/math/system/RungeKuttaTimeVaryingTest.java b/wpimath/src/test/java/edu/wpi/first/math/system/RungeKuttaTimeVaryingTest.java
new file mode 100644
index 0000000..ee843ab
--- /dev/null
+++ b/wpimath/src/test/java/edu/wpi/first/math/system/RungeKuttaTimeVaryingTest.java
@@ -0,0 +1,43 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.system;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+import edu.wpi.first.math.MatBuilder;
+import edu.wpi.first.math.Matrix;
+import edu.wpi.first.math.Nat;
+import edu.wpi.first.math.numbers.N1;
+import org.junit.jupiter.api.Test;
+
+public class RungeKuttaTimeVaryingTest {
+  /** Returns the analytical solution x(t) = 12e^t / (e^t + 1)^2 as a 1x1 matrix. */
+  private static Matrix<N1, N1> rungeKuttaTimeVaryingSolution(double t) {
+    final double expT = Math.exp(t);
+    return new MatBuilder<>(Nat.N1(), Nat.N1()).fill(12.0 * expT / Math.pow(expT + 1.0, 2.0));
+  }
+
+  // Tests RK4 on a time-varying ODE taken from
+  // http://www2.hawaii.edu/~jmcfatri/math407/RungeKuttaTest.html:
+  //   x' = x (2 / (e^t + 1) - 1)
+  //
+  // whose analytical (true) solution is x(t) = 12 * e^t / ((e^t + 1)^2).
+  @Test
+  public void rungeKuttaTimeVaryingTest() {
+    final var y0 = rungeKuttaTimeVaryingSolution(5.0);
+
+    final var y1 =
+        RungeKuttaTimeVarying.rungeKuttaTimeVarying(
+            (Double t, Matrix<N1, N1> x) ->
+                new MatBuilder<>(Nat.N1(), Nat.N1())
+                    .fill(x.get(0, 0) * (2.0 / (Math.exp(t) + 1.0) - 1.0)),
+            5.0,
+            y0,
+            1.0);
+
+    // Integrating from t = 5 over 1.0 should land on the analytical solution at t = 6.
+    assertEquals(rungeKuttaTimeVaryingSolution(6.0).get(0, 0), y1.get(0, 0), 1e-3);
+  }
+}
diff --git a/wpimath/src/test/java/edu/wpi/first/math/trajectory/CentripetalAccelerationConstraintTest.java b/wpimath/src/test/java/edu/wpi/first/math/trajectory/CentripetalAccelerationConstraintTest.java
new file mode 100644
index 0000000..1805589
--- /dev/null
+++ b/wpimath/src/test/java/edu/wpi/first/math/trajectory/CentripetalAccelerationConstraintTest.java
@@ -0,0 +1,37 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.trajectory;
+
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import edu.wpi.first.math.trajectory.constraint.CentripetalAccelerationConstraint;
+import edu.wpi.first.math.util.Units;
+import java.util.Collections;
+import org.junit.jupiter.api.Test;
+
+class CentripetalAccelerationConstraintTest {
+  @SuppressWarnings("LocalVariableName")
+  @Test
+  void testCentripetalAccelerationConstraint() {
+    double maxCentripetalAcceleration = Units.feetToMeters(7.0); // 7 feet per second squared
+    var constraint = new CentripetalAccelerationConstraint(maxCentripetalAcceleration);
+
+    Trajectory trajectory =
+        TrajectoryGeneratorTest.getTrajectory(Collections.singletonList(constraint));
+
+    var duration = trajectory.getTotalTimeSeconds();
+    var t = 0.0;
+    var dt = 0.02;
+
+    while (t < duration) {
+      var point = trajectory.sample(t);
+      // Curvature is signed by turn direction, so bound the magnitude of v^2 * curvature.
+      var centripetalAcceleration =
+          Math.abs(Math.pow(point.velocityMetersPerSecond, 2) * point.curvatureRadPerMeter);
+      t += dt;
+      assertTrue(centripetalAcceleration <= maxCentripetalAcceleration + 0.05);
+    }
+  }
+}
diff --git a/wpimath/src/test/java/edu/wpi/first/math/trajectory/DifferentialDriveKinematicsConstraintTest.java b/wpimath/src/test/java/edu/wpi/first/math/trajectory/DifferentialDriveKinematicsConstraintTest.java
new file mode 100644
index 0000000..4c8a8ba
--- /dev/null
+++ b/wpimath/src/test/java/edu/wpi/first/math/trajectory/DifferentialDriveKinematicsConstraintTest.java
@@ -0,0 +1,48 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.trajectory;
+
+import static org.junit.jupiter.api.Assertions.assertAll;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import edu.wpi.first.math.kinematics.ChassisSpeeds;
+import edu.wpi.first.math.kinematics.DifferentialDriveKinematics;
+import edu.wpi.first.math.trajectory.constraint.DifferentialDriveKinematicsConstraint;
+import edu.wpi.first.math.util.Units;
+import java.util.Collections;
+import org.junit.jupiter.api.Test;
+
+class DifferentialDriveKinematicsConstraintTest {
+  @SuppressWarnings("LocalVariableName")
+  @Test
+  void testDifferentialDriveKinematicsConstraint() {
+    double maxVelocity = Units.feetToMeters(12.0); // 12 feet per second
+    var kinematics = new DifferentialDriveKinematics(Units.inchesToMeters(27));
+    var constraint = new DifferentialDriveKinematicsConstraint(kinematics, maxVelocity);
+
+    Trajectory trajectory =
+        TrajectoryGeneratorTest.getTrajectory(Collections.singletonList(constraint));
+
+    var duration = trajectory.getTotalTimeSeconds();
+    var t = 0.0;
+    var dt = 0.02;
+
+    while (t < duration) {
+      var point = trajectory.sample(t);
+      var chassisSpeeds =
+          new ChassisSpeeds(
+              point.velocityMetersPerSecond,
+              0,
+              point.velocityMetersPerSecond * point.curvatureRadPerMeter);
+      var wheelSpeeds = kinematics.toWheelSpeeds(chassisSpeeds);
+      t += dt;
+
+      // Wheel speeds are signed; bound their magnitudes so negative speeds are checked too.
+      assertAll(
+          () -> assertTrue(Math.abs(wheelSpeeds.leftMetersPerSecond) <= maxVelocity + 0.05),
+          () -> assertTrue(Math.abs(wheelSpeeds.rightMetersPerSecond) <= maxVelocity + 0.05));
+    }
+  }
+}
diff --git a/wpimath/src/test/java/edu/wpi/first/math/trajectory/DifferentialDriveVoltageConstraintTest.java b/wpimath/src/test/java/edu/wpi/first/math/trajectory/DifferentialDriveVoltageConstraintTest.java
new file mode 100644
index 0000000..87c1bd9
--- /dev/null
+++ b/wpimath/src/test/java/edu/wpi/first/math/trajectory/DifferentialDriveVoltageConstraintTest.java
@@ -0,0 +1,97 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.trajectory;
+
+import static org.junit.jupiter.api.Assertions.assertAll;
+import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import edu.wpi.first.math.controller.SimpleMotorFeedforward;
+import edu.wpi.first.math.geometry.Pose2d;
+import edu.wpi.first.math.geometry.Rotation2d;
+import edu.wpi.first.math.geometry.Translation2d;
+import edu.wpi.first.math.kinematics.ChassisSpeeds;
+import edu.wpi.first.math.kinematics.DifferentialDriveKinematics;
+import edu.wpi.first.math.trajectory.constraint.DifferentialDriveVoltageConstraint;
+import java.util.ArrayList;
+import java.util.Collections;
+import org.junit.jupiter.api.Test;
+
+class DifferentialDriveVoltageConstraintTest {
+  /**
+   * Samples a voltage-constrained trajectory and checks that the feedforward voltage for each
+   * wheel stays within +/- maxVoltage (plus a small tolerance).
+   */
+  @SuppressWarnings("LocalVariableName")
+  @Test
+  void testDifferentialDriveVoltageConstraint() {
+    // Pick an unreasonably large kA to ensure the constraint has to do some work
+    var feedforward = new SimpleMotorFeedforward(1, 1, 3);
+    var kinematics = new DifferentialDriveKinematics(0.5);
+    double maxVoltage = 10;
+    var constraint = new DifferentialDriveVoltageConstraint(feedforward, kinematics, maxVoltage);
+
+    Trajectory trajectory =
+        TrajectoryGeneratorTest.getTrajectory(Collections.singletonList(constraint));
+
+    final double dt = 0.02;
+    final double duration = trajectory.getTotalTimeSeconds();
+    for (double t = 0.0; t < duration; t += dt) {
+      var point = trajectory.sample(t);
+      var wheelSpeeds =
+          kinematics.toWheelSpeeds(
+              new ChassisSpeeds(
+                  point.velocityMetersPerSecond,
+                  0,
+                  point.velocityMetersPerSecond * point.curvatureRadPerMeter));
+
+      // Not really a strictly-correct test as we're using the chassis accel instead of the
+      // wheel accel, but much easier than doing it "properly" and a reasonable check anyway
+      double leftVoltage =
+          feedforward.calculate(
+              wheelSpeeds.leftMetersPerSecond, point.accelerationMetersPerSecondSq);
+      double rightVoltage =
+          feedforward.calculate(
+              wheelSpeeds.rightMetersPerSecond, point.accelerationMetersPerSecondSq);
+      assertAll(
+          () -> assertTrue(leftVoltage <= maxVoltage + 0.05),
+          () -> assertTrue(leftVoltage >= -maxVoltage - 0.05),
+          () -> assertTrue(rightVoltage <= maxVoltage + 0.05),
+          () -> assertTrue(rightVoltage >= -maxVoltage - 0.05));
+    }
+  }
+
+  /**
+   * Generates quarter-circle trajectories (forward and reversed) whose radius of curvature is
+   * smaller than half the trackwidth, and checks that generation does not throw.
+   */
+  @Test
+  void testEndpointHighCurvature() {
+    var feedforward = new SimpleMotorFeedforward(1, 1, 3);
+
+    // Large trackwidth - need to test with radius of curvature less than half of trackwidth
+    var kinematics = new DifferentialDriveKinematics(3);
+    double maxVoltage = 10;
+    var constraint = new DifferentialDriveVoltageConstraint(feedforward, kinematics, maxVoltage);
+
+    var config = new TrajectoryConfig(12, 12).addConstraint(constraint);
+    var quarterTurnStart = new Pose2d(1, 0, Rotation2d.fromDegrees(90));
+    var quarterTurnEnd = new Pose2d(0, 1, Rotation2d.fromDegrees(180));
+
+    // Radius of curvature should be ~1 meter.
+    assertDoesNotThrow(
+        () ->
+            TrajectoryGenerator.generateTrajectory(
+                quarterTurnStart, new ArrayList<Translation2d>(), quarterTurnEnd, config));
+
+    assertDoesNotThrow(
+        () ->
+            TrajectoryGenerator.generateTrajectory(
+                quarterTurnEnd,
+                new ArrayList<Translation2d>(),
+                quarterTurnStart,
+                config.setReversed(true)));
+  }
+}
diff --git a/wpimath/src/test/java/edu/wpi/first/math/trajectory/EllipticalRegionConstraintTest.java b/wpimath/src/test/java/edu/wpi/first/math/trajectory/EllipticalRegionConstraintTest.java
new file mode 100644
index 0000000..f9e3c18
--- /dev/null
+++ b/wpimath/src/test/java/edu/wpi/first/math/trajectory/EllipticalRegionConstraintTest.java
@@ -0,0 +1,89 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.trajectory;
+
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import edu.wpi.first.math.geometry.Pose2d;
+import edu.wpi.first.math.geometry.Rotation2d;
+import edu.wpi.first.math.geometry.Translation2d;
+import edu.wpi.first.math.trajectory.constraint.EllipticalRegionConstraint;
+import edu.wpi.first.math.trajectory.constraint.MaxVelocityConstraint;
+import edu.wpi.first.math.util.Units;
+import java.util.List;
+import org.junit.jupiter.api.Test;
+
+/** Tests for {@link EllipticalRegionConstraint}. */
+public class EllipticalRegionConstraintTest {
+  /**
+   * Checks that the velocity limit holds for every state inside the ellipse and that the
+   * trajectory exceeds it somewhere outside, i.e. the limit applies only within the region.
+   */
+  @Test
+  void testConstraint() {
+    // Create constraints
+    double maxVelocity = Units.feetToMeters(3.0);
+    var maxVelocityConstraint = new MaxVelocityConstraint(maxVelocity);
+    var regionConstraint =
+        new EllipticalRegionConstraint(
+            new Translation2d(Units.feetToMeters(5.0), Units.feetToMeters(5.0)),
+            Units.feetToMeters(10.0),
+            Units.feetToMeters(5.0),
+            Rotation2d.fromDegrees(180.0),
+            maxVelocityConstraint);
+
+    // Get trajectory
+    var trajectory = TrajectoryGeneratorTest.getTrajectory(List.of(regionConstraint));
+
+    // Classify each state with the constraint's own membership test so the in-region
+    // assertion covers exactly the ellipse rather than an axis-aligned approximation of it.
+    boolean exceededConstraintOutsideRegion = false;
+    for (var point : trajectory.getStates()) {
+      if (regionConstraint.isPoseInRegion(point.poseMeters)) {
+        assertTrue(Math.abs(point.velocityMetersPerSecond) < maxVelocity + 0.05);
+      } else if (Math.abs(point.velocityMetersPerSecond) >= maxVelocity + 0.05) {
+        exceededConstraintOutsideRegion = true;
+      }
+    }
+    assertTrue(exceededConstraintOutsideRegion);
+  }
+
+  /**
+   * Checks membership of the same point against an unrotated ellipse and one rotated by 90
+   * degrees; the rotation swaps which axis each width lies along.
+   */
+  @Test
+  void testIsPoseWithinRegion() {
+    double maxVelocity = Units.feetToMeters(3.0);
+    var maxVelocityConstraint = new MaxVelocityConstraint(maxVelocity);
+
+    var regionConstraintNoRotation =
+        new EllipticalRegionConstraint(
+            new Translation2d(Units.feetToMeters(1.0), Units.feetToMeters(1.0)),
+            Units.feetToMeters(2.0),
+            Units.feetToMeters(4.0),
+            new Rotation2d(),
+            maxVelocityConstraint);
+
+    // (2.1, 1) ft is 1.1 ft from the center along x, past the unrotated 1 ft half-width.
+    assertFalse(
+        regionConstraintNoRotation.isPoseInRegion(
+            new Pose2d(Units.feetToMeters(2.1), Units.feetToMeters(1.0), new Rotation2d())));
+
+    var regionConstraintWithRotation =
+        new EllipticalRegionConstraint(
+            new Translation2d(Units.feetToMeters(1.0), Units.feetToMeters(1.0)),
+            Units.feetToMeters(2.0),
+            Units.feetToMeters(4.0),
+            Rotation2d.fromDegrees(90.0),
+            maxVelocityConstraint);
+
+    // Rotated by 90 degrees, the 4 ft width lies along x, so the same point is inside.
+    assertTrue(
+        regionConstraintWithRotation.isPoseInRegion(
+            new Pose2d(Units.feetToMeters(2.1), Units.feetToMeters(1.0), new Rotation2d())));
+  }
+}
diff --git a/wpimath/src/test/java/edu/wpi/first/math/trajectory/RectangularRegionConstraintTest.java b/wpimath/src/test/java/edu/wpi/first/math/trajectory/RectangularRegionConstraintTest.java
new file mode 100644
index 0000000..1ab826e
--- /dev/null
+++ b/wpimath/src/test/java/edu/wpi/first/math/trajectory/RectangularRegionConstraintTest.java
@@ -0,0 +1,64 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.trajectory;
+
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import edu.wpi.first.math.geometry.Pose2d;
+import edu.wpi.first.math.geometry.Rotation2d;
+import edu.wpi.first.math.geometry.Translation2d;
+import edu.wpi.first.math.trajectory.constraint.MaxVelocityConstraint;
+import edu.wpi.first.math.trajectory.constraint.RectangularRegionConstraint;
+import edu.wpi.first.math.util.Units;
+import java.util.List;
+import org.junit.jupiter.api.Test;
+
+/** Tests for {@link RectangularRegionConstraint}. */
+public class RectangularRegionConstraintTest {
+  /** The velocity limit must hold inside the rectangle and be exceeded somewhere outside it. */
+  @Test
+  void testConstraint() {
+    // Create constraints
+    double maxVelocity = Units.feetToMeters(3.0);
+    var maxVelocityConstraint = new MaxVelocityConstraint(maxVelocity);
+    var regionConstraint =
+        new RectangularRegionConstraint(
+            new Translation2d(Units.feetToMeters(1.0), Units.feetToMeters(1.0)),
+            new Translation2d(Units.feetToMeters(7.0), Units.feetToMeters(27.0)),
+            maxVelocityConstraint);
+
+    // Get trajectory
+    var trajectory = TrajectoryGeneratorTest.getTrajectory(List.of(regionConstraint));
+
+    // Iterate through trajectory and check constraints
+    boolean exceededConstraintOutsideRegion = false;
+    for (var point : trajectory.getStates()) {
+      if (regionConstraint.isPoseInRegion(point.poseMeters)) {
+        assertTrue(Math.abs(point.velocityMetersPerSecond) < maxVelocity + 0.05);
+      } else if (Math.abs(point.velocityMetersPerSecond) >= maxVelocity + 0.05) {
+        exceededConstraintOutsideRegion = true;
+      }
+    }
+    assertTrue(exceededConstraintOutsideRegion);
+  }
+
+  /** The origin lies outside the (1, 1)-(7, 27) ft rectangle; (3, 14.5) ft lies inside it. */
+  @Test
+  void testIsPoseWithinRegion() {
+    double maxVelocity = Units.feetToMeters(3.0);
+    var maxVelocityConstraint = new MaxVelocityConstraint(maxVelocity);
+    var regionConstraint =
+        new RectangularRegionConstraint(
+            new Translation2d(Units.feetToMeters(1.0), Units.feetToMeters(1.0)),
+            new Translation2d(Units.feetToMeters(7.0), Units.feetToMeters(27.0)),
+            maxVelocityConstraint);
+
+    assertFalse(regionConstraint.isPoseInRegion(new Pose2d()));
+    assertTrue(
+        regionConstraint.isPoseInRegion(
+            new Pose2d(Units.feetToMeters(3.0), Units.feetToMeters(14.5), new Rotation2d())));
+  }
+}
diff --git a/wpimath/src/test/java/edu/wpi/first/math/trajectory/TrajectoryConcatenateTest.java b/wpimath/src/test/java/edu/wpi/first/math/trajectory/TrajectoryConcatenateTest.java
new file mode 100644
index 0000000..2e80d7b
--- /dev/null
+++ b/wpimath/src/test/java/edu/wpi/first/math/trajectory/TrajectoryConcatenateTest.java
@@ -0,0 +1,62 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.trajectory;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import edu.wpi.first.math.geometry.Pose2d;
+import edu.wpi.first.math.geometry.Rotation2d;
+import java.util.List;
+import org.junit.jupiter.api.Test;
+
+/** Tests for {@link Trajectory#concatenate(Trajectory)}. */
+class TrajectoryConcatenateTest {
+  /**
+   * Concatenates two trajectories that meet at (1, 1) and checks that the result is t1's states
+   * followed by t2's remaining states (time-shifted by t1's duration), with strictly increasing
+   * timestamps.
+   */
+  @Test
+  void testStates() {
+    var t1 =
+        TrajectoryGenerator.generateTrajectory(
+            new Pose2d(),
+            List.of(),
+            new Pose2d(1, 1, new Rotation2d()),
+            new TrajectoryConfig(2, 2));
+
+    var t2 =
+        TrajectoryGenerator.generateTrajectory(
+            new Pose2d(1, 1, new Rotation2d()),
+            List.of(),
+            new Pose2d(2, 2, Rotation2d.fromDegrees(45)),
+            new TrajectoryConfig(2, 2));
+
+    var t = t1.concatenate(t2);
+
+    // The loop below expects t2's first state (at the shared (1, 1) pose) to be omitted. Pin the
+    // length too, so a result that drops or adds other states cannot slip past it.
+    assertEquals(t1.getStates().size() + t2.getStates().size() - 1, t.getStates().size());
+
+    double time = -1.0;
+    for (int i = 0; i < t.getStates().size(); ++i) {
+      var state = t.getStates().get(i);
+
+      // Make sure that the timestamps are strictly increasing.
+      assertTrue(state.timeSeconds > time);
+      time = state.timeSeconds;
+
+      // Ensure that the states in t are the same as those in t1 and t2.
+      if (i < t1.getStates().size()) {
+        assertEquals(t1.getStates().get(i), state);
+      } else {
+        var st = t2.getStates().get(i - t1.getStates().size() + 1);
+        st.timeSeconds += t1.getTotalTimeSeconds();
+        assertEquals(st, state);
+      }
+    }
+  }
+}
diff --git a/wpimath/src/test/java/edu/wpi/first/math/trajectory/TrajectoryGeneratorTest.java b/wpimath/src/test/java/edu/wpi/first/math/trajectory/TrajectoryGeneratorTest.java
new file mode 100644
index 0000000..97c1858
--- /dev/null
+++ b/wpimath/src/test/java/edu/wpi/first/math/trajectory/TrajectoryGeneratorTest.java
@@ -0,0 +1,97 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.trajectory;
+
+import static edu.wpi.first.math.util.Units.feetToMeters;
+import static org.junit.jupiter.api.Assertions.assertAll;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import edu.wpi.first.math.geometry.Pose2d;
+import edu.wpi.first.math.geometry.Rotation2d;
+import edu.wpi.first.math.geometry.Transform2d;
+import edu.wpi.first.math.geometry.Translation2d;
+import edu.wpi.first.math.trajectory.constraint.TrajectoryConstraint;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import org.junit.jupiter.api.Test;
+
+class TrajectoryGeneratorTest {
+  // Limits used by getTrajectory() and checked in testGenerationAndConstraints().
+  private static final double kMaxVelocity = feetToMeters(12.0);
+  private static final double kMaxAccel = feetToMeters(12.0);
+
+  /**
+   * Generates a reversed trajectory through the 2018 cross-scale auto waypoints.
+   *
+   * @param constraints Extra constraints added on top of the velocity/acceleration limits.
+   * @return The generated trajectory.
+   */
+  static Trajectory getTrajectory(List<? extends TrajectoryConstraint> constraints) {
+    // 2018 cross scale auto waypoints.
+    var sideStart =
+        new Pose2d(feetToMeters(1.54), feetToMeters(23.23), Rotation2d.fromDegrees(-180));
+    var crossScale =
+        new Pose2d(feetToMeters(23.7), feetToMeters(6.8), Rotation2d.fromDegrees(-160));
+
+    var waypoints = new ArrayList<Pose2d>();
+    waypoints.add(sideStart);
+    waypoints.add(
+        sideStart.plus(
+            new Transform2d(
+                new Translation2d(feetToMeters(-13), feetToMeters(0)), new Rotation2d())));
+    waypoints.add(
+        sideStart.plus(
+            new Transform2d(
+                new Translation2d(feetToMeters(-19.5), feetToMeters(5)),
+                Rotation2d.fromDegrees(-90))));
+    waypoints.add(crossScale);
+
+    TrajectoryConfig config =
+        new TrajectoryConfig(kMaxVelocity, kMaxAccel).setReversed(true).addConstraints(constraints);
+
+    return TrajectoryGenerator.generateTrajectory(waypoints, config);
+  }
+
+  /**
+   * Samples the trajectory (no extra constraints) every 20 ms and checks that velocity and
+   * acceleration stay within the configured limits, plus a small tolerance.
+   */
+  @Test
+  @SuppressWarnings("LocalVariableName")
+  void testGenerationAndConstraints() {
+    Trajectory trajectory = getTrajectory(new ArrayList<>());
+
+    double duration = trajectory.getTotalTimeSeconds();
+    double t = 0.0;
+    double dt = 0.02;
+
+    while (t < duration) {
+      var point = trajectory.sample(t);
+      t += dt;
+      assertAll(
+          () -> assertTrue(Math.abs(point.velocityMetersPerSecond) < kMaxVelocity + 0.05),
+          () -> assertTrue(Math.abs(point.accelerationMetersPerSecondSq) < kMaxAccel + 0.05));
+    }
+  }
+
+  /**
+   * A path from (0, 0) facing 0 degrees to (1, 0) facing 180 degrees is malformed; generation is
+   * expected to return a single state and zero total time.
+   */
+  @Test
+  void testMalformedTrajectory() {
+    var traj =
+        TrajectoryGenerator.generateTrajectory(
+            Arrays.asList(
+                new Pose2d(0, 0, Rotation2d.fromDegrees(0)),
+                new Pose2d(1, 0, Rotation2d.fromDegrees(180))),
+            new TrajectoryConfig(feetToMeters(12), feetToMeters(12)));
+
+    assertEquals(1, traj.getStates().size());
+    assertEquals(0.0, traj.getTotalTimeSeconds());
+  }
+}
diff --git a/wpimath/src/test/java/edu/wpi/first/math/trajectory/TrajectoryJsonTest.java b/wpimath/src/test/java/edu/wpi/first/math/trajectory/TrajectoryJsonTest.java
new file mode 100644
index 0000000..bb72601
--- /dev/null
+++ b/wpimath/src/test/java/edu/wpi/first/math/trajectory/TrajectoryJsonTest.java
@@ -0,0 +1,33 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.trajectory;
+
+import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+import edu.wpi.first.math.kinematics.DifferentialDriveKinematics;
+import edu.wpi.first.math.trajectory.constraint.DifferentialDriveKinematicsConstraint;
+import java.util.List;
+import org.junit.jupiter.api.Test;
+
+/** Tests for trajectory serialization round-tripping via {@link TrajectoryUtil}. */
+public class TrajectoryJsonTest {
+  /** Serializing a trajectory and deserializing the result must reproduce the same states. */
+  @Test
+  void deserializeMatches() {
+    // Constraint list for the shared test trajectory built by TrajectoryGeneratorTest.
+    var config =
+        List.of(new DifferentialDriveKinematicsConstraint(new DifferentialDriveKinematics(20), 3));
+    var trajectory = TrajectoryGeneratorTest.getTrajectory(config);
+
+    var deserialized =
+        assertDoesNotThrow(
+            () ->
+                TrajectoryUtil.deserializeTrajectory(
+                    TrajectoryUtil.serializeTrajectory(trajectory)));
+
+    assertEquals(trajectory.getStates(), deserialized.getStates());
+  }
+}
diff --git a/wpimath/src/test/java/edu/wpi/first/math/trajectory/TrajectoryTransformTest.java b/wpimath/src/test/java/edu/wpi/first/math/trajectory/TrajectoryTransformTest.java
new file mode 100644
index 0000000..6268768
--- /dev/null
+++ b/wpimath/src/test/java/edu/wpi/first/math/trajectory/TrajectoryTransformTest.java
@@ -0,0 +1,77 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package edu.wpi.first.math.trajectory;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+import edu.wpi.first.math.geometry.Pose2d;
+import edu.wpi.first.math.geometry.Rotation2d;
+import edu.wpi.first.math.geometry.Transform2d;
+import edu.wpi.first.math.geometry.Translation2d;
+import java.util.List;
+import org.junit.jupiter.api.Test;
+
+class TrajectoryTransformTest {
+  /** Transforming a trajectory moves its start pose but keeps the state-to-state shape. */
+  @Test
+  void testTransformBy() {
+    var config = new TrajectoryConfig(3, 3);
+    var trajectory =
+        TrajectoryGenerator.generateTrajectory(
+            new Pose2d(), List.of(), new Pose2d(1, 1, Rotation2d.fromDegrees(90)), config);
+
+    var transformedTrajectory =
+        trajectory.transformBy(
+            new Transform2d(new Translation2d(1, 2), Rotation2d.fromDegrees(30)));
+
+    // Test initial pose.
+    assertEquals(
+        new Pose2d(1, 2, Rotation2d.fromDegrees(30)), transformedTrajectory.sample(0).poseMeters);
+
+    testSameShapedTrajectory(trajectory.getStates(), transformedTrajectory.getStates());
+  }
+
+  /** Re-expressing a trajectory relative to its own start pose puts that start at the origin. */
+  @Test
+  void testRelativeTo() {
+    var config = new TrajectoryConfig(3, 3);
+    var trajectory =
+        TrajectoryGenerator.generateTrajectory(
+            new Pose2d(1, 2, Rotation2d.fromDegrees(30.0)),
+            List.of(),
+            new Pose2d(5, 7, Rotation2d.fromDegrees(90)),
+            config);
+
+    var transformedTrajectory = trajectory.relativeTo(new Pose2d(1, 2, Rotation2d.fromDegrees(30)));
+
+    // Test initial pose.
+    assertEquals(new Pose2d(), transformedTrajectory.sample(0).poseMeters);
+
+    testSameShapedTrajectory(trajectory.getStates(), transformedTrajectory.getStates());
+  }
+
+  /**
+   * Asserts that two state lists have the same length and the same pose of each state relative to
+   * its predecessor, i.e. they trace the same shape regardless of where it is placed.
+   *
+   * @param statesA The reference states.
+   * @param statesB The states to compare against the reference.
+   */
+  void testSameShapedTrajectory(List<Trajectory.State> statesA, List<Trajectory.State> statesB) {
+    // Without this, extra states in statesB would be ignored and missing ones would surface only
+    // as an IndexOutOfBoundsException.
+    assertEquals(statesA.size(), statesB.size());
+
+    for (int i = 0; i < statesA.size() - 1; i++) {
+      var a1 = statesA.get(i).poseMeters;
+      var a2 = statesA.get(i + 1).poseMeters;
+
+      var b1 = statesB.get(i).poseMeters;
+      var b2 = statesB.get(i + 1).poseMeters;
+
+      assertEquals(a2.relativeTo(a1), b2.relativeTo(b1));
+    }
+  }
+}
diff --git a/wpimath/src/test/java/edu/wpi/first/wpilibj/trajectory/TrapezoidProfileTest.java b/wpimath/src/test/java/edu/wpi/first/math/trajectory/TrapezoidProfileTest.java
similarity index 90%
rename from wpimath/src/test/java/edu/wpi/first/wpilibj/trajectory/TrapezoidProfileTest.java
rename to wpimath/src/test/java/edu/wpi/first/math/trajectory/TrapezoidProfileTest.java
index e155188..062563f 100644
--- a/wpimath/src/test/java/edu/wpi/first/wpilibj/trajectory/TrapezoidProfileTest.java
+++ b/wpimath/src/test/java/edu/wpi/first/math/trajectory/TrapezoidProfileTest.java
@@ -1,19 +1,15 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
-package edu.wpi.first.wpilibj.trajectory;
-
-import org.junit.jupiter.api.Test;
+package edu.wpi.first.math.trajectory;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
-@SuppressWarnings("PMD.AvoidInstantiatingObjectsInLoops")
+import org.junit.jupiter.api.Test;
+
class TrapezoidProfileTest {
private static final double kDt = 0.01;
@@ -35,8 +31,9 @@
* @param eps Tolerance for whether values are near to each other.
*/
private static void assertNear(double val1, double val2, double eps) {
- assertTrue(Math.abs(val1 - val2) <= eps, "Difference between " + val1 + " and " + val2
- + " is greater than " + eps);
+ assertTrue(
+ Math.abs(val1 - val2) <= eps,
+ "Difference between " + val1 + " and " + val2 + " is greater than " + eps);
}
/**
@@ -56,8 +53,7 @@
@Test
void reachesGoal() {
- TrapezoidProfile.Constraints constraints =
- new TrapezoidProfile.Constraints(1.75, 0.75);
+ TrapezoidProfile.Constraints constraints = new TrapezoidProfile.Constraints(1.75, 0.75);
TrapezoidProfile.State goal = new TrapezoidProfile.State(3, 0);
TrapezoidProfile.State state = new TrapezoidProfile.State();
@@ -81,7 +77,7 @@
double lastPos = state.position;
for (int i = 0; i < 1600; ++i) {
if (i == 400) {
- constraints.maxVelocity = 0.75;
+ constraints = new TrapezoidProfile.Constraints(0.75, 0.75);
}
profile = new TrapezoidProfile(constraints, goal, state);
diff --git a/wpimath/src/test/java/edu/wpi/first/wpilibj/util/UnitsTest.java b/wpimath/src/test/java/edu/wpi/first/math/util/UnitsTest.java
similarity index 63%
rename from wpimath/src/test/java/edu/wpi/first/wpilibj/util/UnitsTest.java
rename to wpimath/src/test/java/edu/wpi/first/math/util/UnitsTest.java
index 9938660..99a8b5a 100644
--- a/wpimath/src/test/java/edu/wpi/first/wpilibj/util/UnitsTest.java
+++ b/wpimath/src/test/java/edu/wpi/first/math/util/UnitsTest.java
@@ -1,18 +1,14 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
-package edu.wpi.first.wpilibj.util;
-
-import org.junit.jupiter.api.Test;
-
-import edu.wpi.first.wpilibj.UtilityClassTest;
+package edu.wpi.first.math.util;
import static org.junit.jupiter.api.Assertions.assertEquals;
+import edu.wpi.first.wpilibj.UtilityClassTest;
+import org.junit.jupiter.api.Test;
+
class UnitsTest extends UtilityClassTest<Units> {
UnitsTest() {
super(Units.class);
@@ -57,4 +53,24 @@
   void radiansPerSecondToRotationsPerMinute() {
     assertEquals(76.39, Units.radiansPerSecondToRotationsPerMinute(8), 1e-2);
   }
+
+  @Test
+  void millisecondsToSeconds() {
+    assertEquals(0.5, Units.millisecondsToSeconds(500), 1e-2);
+  }
+
+  @Test
+  void secondsToMilliseconds() {
+    assertEquals(1500, Units.secondsToMilliseconds(1.5), 1e-2);
+  }
+
+  @Test
+  void kilogramsToLbsTest() {
+    assertEquals(2.20462, Units.kilogramsToLbs(1), 1e-2);
+  }
+
+  @Test
+  void lbsToKilogramsTest() {
+    assertEquals(0.453592, Units.lbsToKilograms(1), 1e-2);
+  }
 }
diff --git a/wpimath/src/test/java/edu/wpi/first/wpilibj/LinearFilterTest.java b/wpimath/src/test/java/edu/wpi/first/wpilibj/LinearFilterTest.java
deleted file mode 100644
index da58b29..0000000
--- a/wpimath/src/test/java/edu/wpi/first/wpilibj/LinearFilterTest.java
+++ /dev/null
@@ -1,116 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2015-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj;
-
-import java.util.Random;
-import java.util.function.DoubleFunction;
-import java.util.stream.Stream;
-
-import org.junit.jupiter.api.Test;
-import org.junit.jupiter.params.ParameterizedTest;
-import org.junit.jupiter.params.provider.Arguments;
-import org.junit.jupiter.params.provider.MethodSource;
-
-import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertThrows;
-import static org.junit.jupiter.api.Assertions.assertTrue;
-import static org.junit.jupiter.params.provider.Arguments.arguments;
-
-class LinearFilterTest {
- private static final double kFilterStep = 0.005;
- private static final double kFilterTime = 2.0;
- private static final double kSinglePoleIIRTimeConstant = 0.015915;
- private static final double kHighPassTimeConstant = 0.006631;
- private static final int kMovAvgTaps = 6;
-
- private static final double kSinglePoleIIRExpectedOutput = -3.2172003;
- private static final double kHighPassExpectedOutput = 10.074717;
- private static final double kMovAvgExpectedOutput = -10.191644;
-
- @SuppressWarnings("ParameterName")
- private static double getData(double t) {
- return 100.0 * Math.sin(2.0 * Math.PI * t) + 20.0 * Math.cos(50.0 * Math.PI * t);
- }
-
- @SuppressWarnings("ParameterName")
- private static double getPulseData(double t) {
- if (Math.abs(t - 1.0) < 0.001) {
- return 1.0;
- } else {
- return 0.0;
- }
- }
-
- @Test
- void illegalTapNumberTest() {
- assertThrows(IllegalArgumentException.class, () -> LinearFilter.movingAverage(0));
- }
-
- /**
- * Test if the filter reduces the noise produced by a signal generator.
- */
- @ParameterizedTest
- @MethodSource("noiseFilterProvider")
- void noiseReduceTest(final LinearFilter filter) {
- double noiseGenError = 0.0;
- double filterError = 0.0;
-
- final Random gen = new Random();
- final double kStdDev = 10.0;
-
- for (double t = 0; t < kFilterTime; t += kFilterStep) {
- final double theory = getData(t);
- final double noise = gen.nextGaussian() * kStdDev;
- filterError += Math.abs(filter.calculate(theory + noise) - theory);
- noiseGenError += Math.abs(noise - theory);
- }
-
- assertTrue(noiseGenError > filterError,
- "Filter should have reduced noise accumulation from " + noiseGenError
- + " but failed. The filter error was " + filterError);
- }
-
- static Stream<LinearFilter> noiseFilterProvider() {
- return Stream.of(
- LinearFilter.singlePoleIIR(kSinglePoleIIRTimeConstant, kFilterStep),
- LinearFilter.movingAverage(kMovAvgTaps)
- );
- }
-
- /**
- * Test if the linear filters produce consistent output for a given data set.
- */
- @ParameterizedTest
- @MethodSource("outputFilterProvider")
- void outputTest(final LinearFilter filter, final DoubleFunction<Double> data,
- final double expectedOutput) {
- double filterOutput = 0.0;
- for (double t = 0.0; t < kFilterTime; t += kFilterStep) {
- filterOutput = filter.calculate(data.apply(t));
- }
-
- assertEquals(expectedOutput, filterOutput, 5e-5, "Filter output was incorrect.");
- }
-
- static Stream<Arguments> outputFilterProvider() {
- return Stream.of(
- arguments(LinearFilter.singlePoleIIR(kSinglePoleIIRTimeConstant, kFilterStep),
- (DoubleFunction<Double>) LinearFilterTest::getData,
- kSinglePoleIIRExpectedOutput),
- arguments(LinearFilter.highPass(kHighPassTimeConstant, kFilterStep),
- (DoubleFunction<Double>) LinearFilterTest::getData,
- kHighPassExpectedOutput),
- arguments(LinearFilter.movingAverage(kMovAvgTaps),
- (DoubleFunction<Double>) LinearFilterTest::getData,
- kMovAvgExpectedOutput),
- arguments(LinearFilter.movingAverage(kMovAvgTaps),
- (DoubleFunction<Double>) LinearFilterTest::getPulseData,
- 0.0)
- );
- }
-}
diff --git a/wpimath/src/test/java/edu/wpi/first/wpilibj/UtilityClassTest.java b/wpimath/src/test/java/edu/wpi/first/wpilibj/UtilityClassTest.java
index 05c5786..8eed93f 100644
--- a/wpimath/src/test/java/edu/wpi/first/wpilibj/UtilityClassTest.java
+++ b/wpimath/src/test/java/edu/wpi/first/wpilibj/UtilityClassTest.java
@@ -1,28 +1,30 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2018-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
package edu.wpi.first.wpilibj;
-import java.lang.reflect.Constructor;
-import java.lang.reflect.InvocationTargetException;
-import java.lang.reflect.Modifier;
-import java.util.Arrays;
-import java.util.stream.Stream;
-
-import org.junit.jupiter.api.DynamicTest;
-import org.junit.jupiter.api.Test;
-import org.junit.jupiter.api.TestFactory;
-
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.DynamicTest.dynamicTest;
+import java.lang.reflect.Constructor;
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Modifier;
+import java.util.Arrays;
+import java.util.stream.Stream;
+import org.junit.jupiter.api.DynamicTest;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.TestFactory;
+
+// Declaring this class abstract prevents UtilityClassTest from running on itself and throwing the
+// following exception:
+//
+// org.junit.jupiter.api.extension.ParameterResolutionException: No ParameterResolver registered
+// for parameter [java.lang.Class<T> arg0] in constructor [protected
+// edu.wpi.first.wpilibj.UtilityClassTest(java.lang.Class<T>)].
@SuppressWarnings("PMD.AbstractClassWithoutAbstractMethod")
public abstract class UtilityClassTest<T> {
private final Class<T> m_clazz;
@@ -33,8 +35,7 @@
@Test
public void singleConstructorTest() {
- assertEquals(1, m_clazz.getDeclaredConstructors().length,
- "More than one constructor defined");
+ assertEquals(1, m_clazz.getDeclaredConstructors().length, "More than one constructor defined");
}
@Test
@@ -55,7 +56,9 @@
Stream<DynamicTest> publicMethodsStaticTestFactory() {
return Arrays.stream(m_clazz.getDeclaredMethods())
.filter(method -> Modifier.isPublic(method.getModifiers()))
- .map(method -> dynamicTest(method.getName(),
- () -> assertTrue(Modifier.isStatic(method.getModifiers()))));
+ .map(
+ method ->
+ dynamicTest(
+ method.getName(), () -> assertTrue(Modifier.isStatic(method.getModifiers()))));
}
}
diff --git a/wpimath/src/test/java/edu/wpi/first/wpilibj/controller/ControlAffinePlantInversionFeedforwardTest.java b/wpimath/src/test/java/edu/wpi/first/wpilibj/controller/ControlAffinePlantInversionFeedforwardTest.java
deleted file mode 100644
index 75a29e3..0000000
--- a/wpimath/src/test/java/edu/wpi/first/wpilibj/controller/ControlAffinePlantInversionFeedforwardTest.java
+++ /dev/null
@@ -1,71 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.controller;
-
-import org.junit.jupiter.api.Test;
-
-import edu.wpi.first.wpiutil.math.Matrix;
-import edu.wpi.first.wpiutil.math.Nat;
-import edu.wpi.first.wpiutil.math.VecBuilder;
-import edu.wpi.first.wpiutil.math.numbers.N1;
-import edu.wpi.first.wpiutil.math.numbers.N2;
-
-import static org.junit.jupiter.api.Assertions.assertEquals;
-
-class ControlAffinePlantInversionFeedforwardTest {
- @SuppressWarnings("LocalVariableName")
- @Test
- void testCalculate() {
- ControlAffinePlantInversionFeedforward<N2, N1> feedforward =
- new ControlAffinePlantInversionFeedforward<N2, N1>(
- Nat.N2(),
- Nat.N1(),
- this::getDynamics,
- 0.02);
-
- assertEquals(48.0, feedforward.calculate(
- VecBuilder.fill(2, 2),
- VecBuilder.fill(3, 3)).get(0, 0),
- 1e-6);
- }
-
- @SuppressWarnings("LocalVariableName")
- @Test
- void testCalculateState() {
- ControlAffinePlantInversionFeedforward<N2, N1> feedforward =
- new ControlAffinePlantInversionFeedforward<N2, N1>(
- Nat.N2(),
- Nat.N1(),
- this::getDynamics,
- 0.02);
-
- assertEquals(48.0, feedforward.calculate(
- VecBuilder.fill(2, 2),
- VecBuilder.fill(3, 3)).get(0, 0),
- 1e-6);
- }
-
- @SuppressWarnings("ParameterName")
- protected Matrix<N2, N1> getDynamics(Matrix<N2, N1> x, Matrix<N1, N1> u) {
- var result = new Matrix<>(Nat.N2(), Nat.N1());
-
- result = Matrix.mat(Nat.N2(), Nat.N2()).fill(1.000, 0, 0, 1.000).times(x)
- .plus(VecBuilder.fill(0, 1).times(u));
-
- return result;
- }
-
- @SuppressWarnings("ParameterName")
- protected Matrix<N2, N1> getStateDynamics(Matrix<N2, N1> x) {
- var result = new Matrix<>(Nat.N2(), Nat.N1());
-
- result = Matrix.mat(Nat.N2(), Nat.N2()).fill(1.000, 0, 0, 1.000).times(x);
-
- return result;
- }
-}
diff --git a/wpimath/src/test/java/edu/wpi/first/wpilibj/controller/LinearPlantInversionFeedforwardTest.java b/wpimath/src/test/java/edu/wpi/first/wpilibj/controller/LinearPlantInversionFeedforwardTest.java
deleted file mode 100644
index 8a09383..0000000
--- a/wpimath/src/test/java/edu/wpi/first/wpilibj/controller/LinearPlantInversionFeedforwardTest.java
+++ /dev/null
@@ -1,35 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.controller;
-
-import org.junit.jupiter.api.Test;
-
-import edu.wpi.first.wpiutil.math.Matrix;
-import edu.wpi.first.wpiutil.math.Nat;
-import edu.wpi.first.wpiutil.math.VecBuilder;
-import edu.wpi.first.wpiutil.math.numbers.N1;
-import edu.wpi.first.wpiutil.math.numbers.N2;
-
-import static org.junit.jupiter.api.Assertions.assertEquals;
-
-class LinearPlantInversionFeedforwardTest {
- @SuppressWarnings("LocalVariableName")
- @Test
- void testCalculate() {
- Matrix<N2, N2> A = Matrix.mat(Nat.N2(), Nat.N2()).fill(1, 0, 0, 1);
- Matrix<N2, N1> B = VecBuilder.fill(0, 1);
-
- LinearPlantInversionFeedforward<N2, N1, N1> feedforward =
- new LinearPlantInversionFeedforward<N2, N1, N1>(A, B, 0.02);
-
- assertEquals(47.502599, feedforward.calculate(
- VecBuilder.fill(2, 2),
- VecBuilder.fill(3, 3)).get(0, 0),
- 0.002);
- }
-}
diff --git a/wpimath/src/test/java/edu/wpi/first/wpilibj/controller/LinearQuadraticRegulatorTest.java b/wpimath/src/test/java/edu/wpi/first/wpilibj/controller/LinearQuadraticRegulatorTest.java
deleted file mode 100644
index e047198..0000000
--- a/wpimath/src/test/java/edu/wpi/first/wpilibj/controller/LinearQuadraticRegulatorTest.java
+++ /dev/null
@@ -1,116 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.controller;
-
-import org.junit.jupiter.api.Test;
-
-import edu.wpi.first.wpilibj.system.LinearSystem;
-import edu.wpi.first.wpilibj.system.plant.DCMotor;
-import edu.wpi.first.wpilibj.system.plant.LinearSystemId;
-import edu.wpi.first.wpiutil.math.VecBuilder;
-import edu.wpi.first.wpiutil.math.numbers.N1;
-import edu.wpi.first.wpiutil.math.numbers.N2;
-
-
-import static org.junit.jupiter.api.Assertions.assertEquals;
-
-public class LinearQuadraticRegulatorTest {
- public static LinearSystem<N2, N1, N1> elevatorPlant;
- static LinearSystem<N2, N1, N1> armPlant;
-
- static {
- createArm();
- createElevator();
- }
-
- @SuppressWarnings("LocalVariableName")
- public static void createArm() {
- var motors = DCMotor.getVex775Pro(2);
-
- var m = 4.0;
- var r = 0.4;
- var J = 1d / 3d * m * r * r;
- var G = 100.0;
-
- armPlant = LinearSystemId.createSingleJointedArmSystem(motors, J, G);
- }
-
- @SuppressWarnings("LocalVariableName")
- public static void createElevator() {
- var motors = DCMotor.getVex775Pro(2);
-
- var m = 5.0;
- var r = 0.0181864;
- var G = 1.0;
- elevatorPlant = LinearSystemId.createElevatorSystem(motors, m, r, G);
- }
-
- @Test
- @SuppressWarnings("LocalVariableName")
- public void testLQROnElevator() {
-
- var qElms = VecBuilder.fill(0.02, 0.4);
- var rElms = VecBuilder.fill(12.0);
- var dt = 0.00505;
-
- var controller = new LinearQuadraticRegulator<>(
- elevatorPlant, qElms, rElms, dt);
-
- var k = controller.getK();
-
- assertEquals(522.153, k.get(0, 0), 0.1);
- assertEquals(38.2, k.get(0, 1), 0.1);
- }
-
- @Test
- public void testFourMotorElevator() {
-
- var dt = 0.020;
-
- var plant = LinearSystemId.createElevatorSystem(
- DCMotor.getVex775Pro(4),
- 8.0,
- 0.75 * 25.4 / 1000.0,
- 14.67);
-
- var regulator = new LinearQuadraticRegulator<>(
- plant,
- VecBuilder.fill(0.1, 0.2),
- VecBuilder.fill(12.0),
- dt);
-
- assertEquals(10.381, regulator.getK().get(0, 0), 1e-2);
- assertEquals(0.6929, regulator.getK().get(0, 1), 1e-2);
-
- }
-
- @Test
- @SuppressWarnings("LocalVariableName")
- public void testLQROnArm() {
-
- var motors = DCMotor.getVex775Pro(2);
-
- var m = 4.0;
- var r = 0.4;
- var G = 100.0;
-
- var plant = LinearSystemId.createSingleJointedArmSystem(motors, 1d / 3d * m * r * r, G);
-
- var qElms = VecBuilder.fill(0.01745, 0.08726);
- var rElms = VecBuilder.fill(12.0);
- var dt = 0.00505;
-
- var controller = new LinearQuadraticRegulator<>(
- plant, qElms, rElms, dt);
-
- var k = controller.getK();
-
- assertEquals(19.16, k.get(0, 0), 0.1);
- assertEquals(3.32, k.get(0, 1), 0.1);
- }
-}
diff --git a/wpimath/src/test/java/edu/wpi/first/wpilibj/controller/LinearSystemLoopTest.java b/wpimath/src/test/java/edu/wpi/first/wpilibj/controller/LinearSystemLoopTest.java
deleted file mode 100644
index 14a0a9d..0000000
--- a/wpimath/src/test/java/edu/wpi/first/wpilibj/controller/LinearSystemLoopTest.java
+++ /dev/null
@@ -1,169 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.controller;
-
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Random;
-
-import org.junit.jupiter.api.Test;
-
-import edu.wpi.first.wpilibj.estimator.KalmanFilter;
-import edu.wpi.first.wpilibj.system.LinearSystem;
-import edu.wpi.first.wpilibj.system.LinearSystemLoop;
-import edu.wpi.first.wpilibj.system.plant.DCMotor;
-import edu.wpi.first.wpilibj.system.plant.LinearSystemId;
-import edu.wpi.first.wpilibj.trajectory.TrapezoidProfile;
-import edu.wpi.first.wpiutil.math.Matrix;
-import edu.wpi.first.wpiutil.math.Nat;
-import edu.wpi.first.wpiutil.math.VecBuilder;
-import edu.wpi.first.wpiutil.math.numbers.N1;
-import edu.wpi.first.wpiutil.math.numbers.N2;
-
-import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertTrue;
-
-public class LinearSystemLoopTest {
- public static final double kDt = 0.00505;
- private static final double kPositionStddev = 0.0001;
- private static final Random random = new Random();
- private final LinearSystemLoop<N2, N1, N1> m_loop;
-
- @SuppressWarnings("LocalVariableName")
- public LinearSystemLoopTest() {
- LinearSystem<N2, N1, N1> plant = LinearSystemId.createElevatorSystem(DCMotor.getVex775Pro(2), 5,
- 0.0181864, 1.0);
- KalmanFilter<N2, N1, N1> observer = new KalmanFilter<>(Nat.N2(), Nat.N1(), plant,
- VecBuilder.fill(0.05, 1.0),
- VecBuilder.fill(0.0001), kDt);
-
- var qElms = VecBuilder.fill(0.02, 0.4);
- var rElms = VecBuilder.fill(12.0);
- var dt = 0.00505;
-
- var controller = new LinearQuadraticRegulator<>(
- plant, qElms, rElms, dt);
-
- m_loop = new LinearSystemLoop<>(plant, controller, observer, 12, dt);
- }
-
- @SuppressWarnings("LocalVariableName")
- private static void updateTwoState(LinearSystemLoop<N2, N1, N1> loop, double noise) {
- Matrix<N1, N1> y = loop.getPlant().calculateY(loop.getXHat(), loop.getU()).plus(
- VecBuilder.fill(noise)
- );
-
- loop.correct(y);
- loop.predict(kDt);
- }
-
- @Test
- @SuppressWarnings({"LocalVariableName", "PMD.AvoidInstantiatingObjectsInLoops"})
- public void testStateSpaceEnabled() {
-
- m_loop.reset(VecBuilder.fill(0, 0));
- Matrix<N2, N1> references = VecBuilder.fill(2.0, 0.0);
- var constraints = new TrapezoidProfile.Constraints(4, 3);
- m_loop.setNextR(references);
-
- TrapezoidProfile profile;
- TrapezoidProfile.State state;
- for (int i = 0; i < 1000; i++) {
- profile = new TrapezoidProfile(
- constraints, new TrapezoidProfile.State(m_loop.getXHat(0), m_loop.getXHat(1)),
- new TrapezoidProfile.State(references.get(0, 0), references.get(1, 0))
- );
- state = profile.calculate(kDt);
- m_loop.setNextR(VecBuilder.fill(state.position, state.velocity));
-
- updateTwoState(m_loop, (random.nextGaussian()) * kPositionStddev);
- var u = m_loop.getU(0);
-
- assertTrue(u >= -12.1 && u <= 12.1, "U out of bounds! Got " + u);
- }
-
- assertEquals(2.0, m_loop.getXHat(0), 0.05);
- assertEquals(0.0, m_loop.getXHat(1), 0.5);
-
- }
-
- @Test
- @SuppressWarnings({"LocalVariableName", "PMD.AvoidInstantiatingObjectsInLoops"})
- public void testFlywheelEnabled() {
-
- LinearSystem<N1, N1, N1> plant = LinearSystemId.createFlywheelSystem(DCMotor.getNEO(2),
- 0.00289, 1.0);
- KalmanFilter<N1, N1, N1> observer = new KalmanFilter<>(Nat.N1(), Nat.N1(), plant,
- VecBuilder.fill(1.0),
- VecBuilder.fill(kPositionStddev), kDt);
-
- var qElms = VecBuilder.fill(9.0);
- var rElms = VecBuilder.fill(12.0);
-
- var controller = new LinearQuadraticRegulator<>(
- plant, qElms, rElms, kDt);
-
- var feedforward = new LinearPlantInversionFeedforward<>(plant, kDt);
-
- var loop = new LinearSystemLoop<>(plant, controller, feedforward, observer, 12);
-
- loop.reset(VecBuilder.fill(0.0));
- var references = VecBuilder.fill(3000 / 60d * 2 * Math.PI);
-
- loop.setNextR(references);
-
- List<Double> timeData = new ArrayList<>();
- List<Double> xHat = new ArrayList<>();
- List<Double> volt = new ArrayList<>();
- List<Double> reference = new ArrayList<>();
- List<Double> error = new ArrayList<>();
-
- var time = 0.0;
- while (time < 10) {
-
- if (time > 10) {
- break;
- }
-
- loop.setNextR(references);
-
- Matrix<N1, N1> y = loop.getPlant().calculateY(loop.getXHat(), loop.getU()).plus(
- VecBuilder.fill(random.nextGaussian() * kPositionStddev)
- );
-
- loop.correct(y);
- loop.predict(kDt);
- var u = loop.getU(0);
-
- assertTrue(u >= -12.1 && u <= 12.1, "U out of bounds! Got " + u);
-
- xHat.add(loop.getXHat(0) / 2d / Math.PI * 60);
- volt.add(u);
- timeData.add(time);
- reference.add(references.get(0, 0) / 2d / Math.PI * 60);
- error.add(loop.getError(0) / 2d / Math.PI * 60);
- time += kDt;
- }
-
- // XYChart bigChart = new XYChartBuilder().build();
- // bigChart.addSeries("Xhat, RPM", timeData, xHat);
- // bigChart.addSeries("Reference, RPM", timeData, reference);
- // bigChart.addSeries("Error, RPM", timeData, error);
-
- // XYChart smolChart = new XYChartBuilder().build();
- // smolChart.addSeries("Volts, V", timeData, volt);
-
- // try {
- // new SwingWrapper<>(List.of(bigChart, smolChart)).displayChartMatrix();
- // Thread.sleep(10000000);
- // } catch (InterruptedException e) { e.printStackTrace(); }
-
- assertEquals(0.0, loop.getError(0), 0.1);
- }
-
-}
diff --git a/wpimath/src/test/java/edu/wpi/first/wpilibj/estimator/ExtendedKalmanFilterTest.java b/wpimath/src/test/java/edu/wpi/first/wpilibj/estimator/ExtendedKalmanFilterTest.java
deleted file mode 100644
index 6c7cc92..0000000
--- a/wpimath/src/test/java/edu/wpi/first/wpilibj/estimator/ExtendedKalmanFilterTest.java
+++ /dev/null
@@ -1,219 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.estimator;
-
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.List;
-
-import org.junit.jupiter.api.Test;
-
-import edu.wpi.first.wpilibj.geometry.Pose2d;
-import edu.wpi.first.wpilibj.geometry.Rotation2d;
-import edu.wpi.first.wpilibj.math.StateSpaceUtil;
-import edu.wpi.first.wpilibj.system.NumericalJacobian;
-import edu.wpi.first.wpilibj.system.RungeKutta;
-import edu.wpi.first.wpilibj.system.plant.DCMotor;
-import edu.wpi.first.wpilibj.trajectory.TrajectoryConfig;
-import edu.wpi.first.wpilibj.trajectory.TrajectoryGenerator;
-import edu.wpi.first.wpiutil.math.Matrix;
-import edu.wpi.first.wpiutil.math.Nat;
-import edu.wpi.first.wpiutil.math.VecBuilder;
-import edu.wpi.first.wpiutil.math.numbers.N1;
-import edu.wpi.first.wpiutil.math.numbers.N2;
-import edu.wpi.first.wpiutil.math.numbers.N3;
-import edu.wpi.first.wpiutil.math.numbers.N5;
-
-import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
-import static org.junit.jupiter.api.Assertions.assertEquals;
-
-@SuppressWarnings("ParameterName")
-public class ExtendedKalmanFilterTest {
- public static Matrix<N5, N1> getDynamics(final Matrix<N5, N1> x, final Matrix<N2, N1> u) {
- final var motors = DCMotor.getCIM(2);
-
- final var gr = 7.08; // Gear ratio
- final var rb = 0.8382 / 2.0; // Wheelbase radius (track width)
- final var r = 0.0746125; // Wheel radius
- final var m = 63.503; // Robot mass
- final var J = 5.6; // Robot moment of inertia
-
- final var C1 =
- -Math.pow(gr, 2) * motors.m_KtNMPerAmp / (
- motors.m_KvRadPerSecPerVolt * motors.m_rOhms * r * r);
- final var C2 = gr * motors.m_KtNMPerAmp / (motors.m_rOhms * r);
- final var k1 = 1.0 / m + rb * rb / J;
- final var k2 = 1.0 / m - rb * rb / J;
-
- final var vl = x.get(3, 0);
- final var vr = x.get(4, 0);
- final var Vl = u.get(0, 0);
- final var Vr = u.get(1, 0);
-
- final Matrix<N5, N1> result = new Matrix<>(Nat.N5(), Nat.N1());
- final var v = 0.5 * (vl + vr);
- result.set(0, 0, v * Math.cos(x.get(2, 0)));
- result.set(1, 0, v * Math.sin(x.get(2, 0)));
- result.set(2, 0, (vr - vl) / (2.0 * rb));
- result.set(3, 0, k1 * ((C1 * vl) + (C2 * Vl)) + k2 * ((C1 * vr) + (C2 * Vr)));
- result.set(4, 0, k2 * ((C1 * vl) + (C2 * Vl)) + k1 * ((C1 * vr) + (C2 * Vr)));
- return result;
- }
-
- public static Matrix<N3, N1> getLocalMeasurementModel(Matrix<N5, N1> x, Matrix<N2, N1> u) {
- return VecBuilder.fill(x.get(2, 0), x.get(3, 0), x.get(4, 0));
- }
-
- public static Matrix<N5, N1> getGlobalMeasurementModel(Matrix<N5, N1> x, Matrix<N2, N1> u) {
- return VecBuilder.fill(
- x.get(0, 0),
- x.get(1, 0),
- x.get(2, 0),
- x.get(3, 0),
- x.get(4, 0)
- );
- }
-
- @SuppressWarnings("LocalVariableName")
- @Test
- public void testInit() {
- double dtSeconds = 0.00505;
-
- assertDoesNotThrow(() -> {
- ExtendedKalmanFilter<N5, N2, N3> observer =
- new ExtendedKalmanFilter<>(Nat.N5(), Nat.N2(), Nat.N3(),
- ExtendedKalmanFilterTest::getDynamics,
- ExtendedKalmanFilterTest::getLocalMeasurementModel,
- VecBuilder.fill(0.5, 0.5, 10.0, 1.0, 1.0),
- VecBuilder.fill(0.0001, 0.01, 0.01), dtSeconds);
-
- Matrix<N2, N1> u = VecBuilder.fill(12.0, 12.0);
- observer.predict(u, dtSeconds);
-
- var localY = getLocalMeasurementModel(observer.getXhat(), u);
- observer.correct(u, localY);
-
- var globalY = getGlobalMeasurementModel(observer.getXhat(), u);
- var R = StateSpaceUtil.makeCostMatrix(
- VecBuilder.fill(0.01, 0.01, 0.0001, 0.5, 0.5));
- observer.correct(Nat.N5(),
- u, globalY, ExtendedKalmanFilterTest::getGlobalMeasurementModel, R);
- });
-
- }
-
- @SuppressWarnings({"LocalVariableName", "PMD.AvoidInstantiatingObjectsInLoops",
- "PMD.ExcessiveMethodLength"})
- @Test
- public void testConvergence() {
- double dtSeconds = 0.00505;
- double rbMeters = 0.8382 / 2.0; // Robot radius
-
- ExtendedKalmanFilter<N5, N2, N3> observer =
- new ExtendedKalmanFilter<>(Nat.N5(), Nat.N2(), Nat.N3(),
- ExtendedKalmanFilterTest::getDynamics,
- ExtendedKalmanFilterTest::getLocalMeasurementModel,
- VecBuilder.fill(0.5, 0.5, 10.0, 1.0, 1.0),
- VecBuilder.fill(0.001, 0.01, 0.01), dtSeconds);
-
- List<Pose2d> waypoints = Arrays.asList(new Pose2d(2.75, 22.521, new Rotation2d()),
- new Pose2d(24.73, 19.68, Rotation2d.fromDegrees(5.846)));
- var trajectory = TrajectoryGenerator.generateTrajectory(
- waypoints,
- new TrajectoryConfig(8.8, 0.1)
- );
-
- Matrix<N5, N1> r = new Matrix<>(Nat.N5(), Nat.N1());
-
- Matrix<N5, N1> nextR = new Matrix<>(Nat.N5(), Nat.N1());
- Matrix<N2, N1> u = new Matrix<>(Nat.N2(), Nat.N1());
-
- List<Double> trajXs = new ArrayList<>();
- List<Double> trajYs = new ArrayList<>();
-
- List<Double> observerXs = new ArrayList<>();
- List<Double> observerYs = new ArrayList<>();
-
- var B = NumericalJacobian.numericalJacobianU(Nat.N5(), Nat.N2(),
- ExtendedKalmanFilterTest::getDynamics, new Matrix<>(Nat.N5(), Nat.N1()), u);
-
- observer.setXhat(VecBuilder.fill(
- trajectory.getInitialPose().getTranslation().getX(),
- trajectory.getInitialPose().getTranslation().getY(),
- trajectory.getInitialPose().getRotation().getRadians(),
- 0.0, 0.0));
-
- var groundTruthX = observer.getXhat();
-
- double totalTime = trajectory.getTotalTimeSeconds();
- for (int i = 0; i < (totalTime / dtSeconds); i++) {
- var ref = trajectory.sample(dtSeconds * i);
- double vl = ref.velocityMetersPerSecond * (1 - (ref.curvatureRadPerMeter * rbMeters));
- double vr = ref.velocityMetersPerSecond * (1 + (ref.curvatureRadPerMeter * rbMeters));
-
- nextR.set(0, 0, ref.poseMeters.getTranslation().getX());
- nextR.set(1, 0, ref.poseMeters.getTranslation().getY());
- nextR.set(2, 0, ref.poseMeters.getRotation().getRadians());
- nextR.set(3, 0, vl);
- nextR.set(4, 0, vr);
-
- var localY =
- getLocalMeasurementModel(groundTruthX, u);
- var whiteNoiseStdDevs = VecBuilder.fill(0.0001, 0.5, 0.5);
- observer.correct(u,
- localY.plus(StateSpaceUtil.makeWhiteNoiseVector(whiteNoiseStdDevs)));
-
- Matrix<N5, N1> rdot = nextR.minus(r).div(dtSeconds);
- u = new Matrix<>(B.solve(rdot.minus(getDynamics(r, new Matrix<>(Nat.N2(), Nat.N1())))));
-
- observer.predict(u, dtSeconds);
-
- groundTruthX = RungeKutta.rungeKutta(ExtendedKalmanFilterTest::getDynamics,
- groundTruthX, u, dtSeconds);
-
- r = nextR;
-
- trajXs.add(ref.poseMeters.getTranslation().getX());
- trajYs.add(ref.poseMeters.getTranslation().getY());
-
- observerXs.add(observer.getXhat().get(0, 0));
- observerYs.add(observer.getXhat().get(1, 0));
- }
-
- var localY = getLocalMeasurementModel(observer.getXhat(), u);
- observer.correct(u, localY);
-
- var globalY = getGlobalMeasurementModel(observer.getXhat(), u);
- var R = StateSpaceUtil.makeCostMatrix(
- VecBuilder.fill(0.01, 0.01, 0.0001, 0.5, 0.5));
- observer.correct(Nat.N5(), u, globalY, ExtendedKalmanFilterTest::getGlobalMeasurementModel, R);
-
- // var chartBuilder = new XYChartBuilder();
- // chartBuilder.title = "The Magic of Sensor Fusion, now with a "
- // + observer.getClass().getSimpleName();
- // var xyPosChart = chartBuilder.build();
- //
- // xyPosChart.setXAxisTitle("X pos, meters");
- // xyPosChart.setYAxisTitle("Y pos, meters");
- // xyPosChart.addSeries("Trajectory", trajXs, trajYs);
- // xyPosChart.addSeries("xHat", observerXs, observerYs);
- // new SwingWrapper<>(xyPosChart).displayChart();
- // try {
- // Thread.sleep(1000000000);
- // } catch (InterruptedException ex) {
- // ex.printStackTrace();
- // }
-
- var finalPosition = trajectory.sample(trajectory.getTotalTimeSeconds());
- assertEquals(finalPosition.poseMeters.getTranslation().getX(), observer.getXhat(0), 1.0);
- assertEquals(finalPosition.poseMeters.getTranslation().getY(), observer.getXhat(1), 1.0);
- assertEquals(finalPosition.poseMeters.getRotation().getRadians(), observer.getXhat(2), 1.0);
- assertEquals(0.0, observer.getXhat(3), 1.0);
- assertEquals(0.0, observer.getXhat(4), 1.0);
- }
-}
diff --git a/wpimath/src/test/java/edu/wpi/first/wpilibj/estimator/KalmanFilterTest.java b/wpimath/src/test/java/edu/wpi/first/wpilibj/estimator/KalmanFilterTest.java
deleted file mode 100644
index 2a434fb..0000000
--- a/wpimath/src/test/java/edu/wpi/first/wpilibj/estimator/KalmanFilterTest.java
+++ /dev/null
@@ -1,258 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.estimator;
-
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Random;
-
-import org.junit.jupiter.api.Test;
-
-import edu.wpi.first.wpilibj.geometry.Pose2d;
-import edu.wpi.first.wpilibj.geometry.Rotation2d;
-import edu.wpi.first.wpilibj.system.LinearSystem;
-import edu.wpi.first.wpilibj.system.plant.DCMotor;
-import edu.wpi.first.wpilibj.system.plant.LinearSystemId;
-import edu.wpi.first.wpilibj.trajectory.TrajectoryConfig;
-import edu.wpi.first.wpilibj.trajectory.TrajectoryGenerator;
-import edu.wpi.first.wpiutil.math.Matrix;
-import edu.wpi.first.wpiutil.math.Nat;
-import edu.wpi.first.wpiutil.math.VecBuilder;
-import edu.wpi.first.wpiutil.math.numbers.N1;
-import edu.wpi.first.wpiutil.math.numbers.N2;
-import edu.wpi.first.wpiutil.math.numbers.N3;
-import edu.wpi.first.wpiutil.math.numbers.N6;
-
-import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
-import static org.junit.jupiter.api.Assertions.assertEquals;
-
-public class KalmanFilterTest {
- private static LinearSystem<N2, N1, N1> elevatorPlant;
-
- private static final double kDt = 0.00505;
-
- static {
- createElevator();
- }
-
- @SuppressWarnings("LocalVariableName")
- public static void createElevator() {
- var motors = DCMotor.getVex775Pro(2);
-
- var m = 5.0;
- var r = 0.0181864;
- var G = 1.0;
- elevatorPlant = LinearSystemId.createElevatorSystem(motors, m, r, G);
- }
-
- // A swerve drive system where the states are [x, y, theta, vx, vy, vTheta]^T,
- // Y is [x, y, theta]^T and u is [ax, ay, alpha}^T
- LinearSystem<N6, N3, N3> m_swerveObserverSystem = new LinearSystem<>(
- Matrix.mat(Nat.N6(), Nat.N6()).fill( // A
- 0, 0, 0, 1, 0, 0,
- 0, 0, 0, 0, 1, 0,
- 0, 0, 0, 0, 0, 1,
- 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0),
- Matrix.mat(Nat.N6(), Nat.N3()).fill( // B
- 0, 0, 0,
- 0, 0, 0,
- 0, 0, 0,
- 1, 0, 0,
- 0, 1, 0,
- 0, 0, 1
- ),
- Matrix.mat(Nat.N3(), Nat.N6()).fill( // C
- 1, 0, 0, 0, 0, 0,
- 0, 1, 0, 0, 0, 0,
- 0, 0, 1, 0, 0, 0
- ),
- new Matrix<>(Nat.N3(), Nat.N3())); // D
-
- @Test
- @SuppressWarnings("LocalVariableName")
- public void testElevatorKalmanFilter() {
-
- var Q = VecBuilder.fill(0.05, 1.0);
- var R = VecBuilder.fill(0.0001);
-
- assertDoesNotThrow(() -> new KalmanFilter<>(Nat.N2(), Nat.N1(), elevatorPlant, Q, R, kDt));
- }
-
- @Test
- @SuppressWarnings("PMD.AvoidInstantiatingObjectsInLoops")
- public void testSwerveKFStationary() {
-
- var random = new Random();
-
- var filter = new KalmanFilter<>(Nat.N6(), Nat.N3(),
- m_swerveObserverSystem,
- VecBuilder.fill(0.1, 0.1, 0.1, 0.1, 0.1, 0.1), // state
- // weights
- VecBuilder.fill(2, 2, 2), // measurement weights
- 0.020
- );
-
- List<Double> xhatsX = new ArrayList<>();
- List<Double> measurementsX = new ArrayList<>();
- List<Double> xhatsY = new ArrayList<>();
- List<Double> measurementsY = new ArrayList<>();
-
- Matrix<N3, N1> measurement;
- for (int i = 0; i < 100; i++) {
- // the robot is at [0, 0, 0] so we just park here
- measurement = VecBuilder.fill(
- random.nextGaussian(), random.nextGaussian(), random.nextGaussian());
- filter.correct(VecBuilder.fill(0.0, 0.0, 0.0), measurement);
-
- // we continue to not accelerate
- filter.predict(VecBuilder.fill(0.0, 0.0, 0.0), 0.020);
-
- measurementsX.add(measurement.get(0, 0));
- measurementsY.add(measurement.get(1, 0));
- xhatsX.add(filter.getXhat(0));
- xhatsY.add(filter.getXhat(1));
- }
-
- //var chart = new XYChartBuilder().build();
- //chart.addSeries("Xhat, x/y", xhatsX, xhatsY);
- //chart.addSeries("Measured position, x/y", measurementsX, measurementsY);
- //try {
- // new SwingWrapper<>(chart).displayChart();
- // Thread.sleep(10000000);
- //} catch (Exception ign) {
- //}
- assertEquals(0.0, filter.getXhat(0), 0.3);
- assertEquals(0.0, filter.getXhat(0), 0.3);
- }
-
- @Test
- @SuppressWarnings("PMD.AvoidInstantiatingObjectsInLoops")
- public void testSwerveKFMovingWithoutAccelerating() {
-
- var random = new Random();
-
- var filter = new KalmanFilter<>(Nat.N6(), Nat.N3(),
- m_swerveObserverSystem,
- VecBuilder.fill(0.1, 0.1, 0.1, 0.1, 0.1, 0.1), // state
- // weights
- VecBuilder.fill(4, 4, 4), // measurement weights
- 0.020
- );
-
- List<Double> xhatsX = new ArrayList<>();
- List<Double> measurementsX = new ArrayList<>();
- List<Double> xhatsY = new ArrayList<>();
- List<Double> measurementsY = new ArrayList<>();
-
- // we set the velocity of the robot so that it's moving forward slowly
- filter.setXhat(0, 0.5);
- filter.setXhat(1, 0.5);
-
- for (int i = 0; i < 300; i++) {
- // the robot is at [0, 0, 0] so we just park here
- var measurement = VecBuilder.fill(
- random.nextGaussian() / 10d,
- random.nextGaussian() / 10d,
- random.nextGaussian() / 4d // std dev of [1, 1, 1]
- );
-
- filter.correct(VecBuilder.fill(0.0, 0.0, 0.0), measurement);
-
- // we continue to not accelerate
- filter.predict(VecBuilder.fill(0.0, 0.0, 0.0), 0.020);
-
- measurementsX.add(measurement.get(0, 0));
- measurementsY.add(measurement.get(1, 0));
- xhatsX.add(filter.getXhat(0));
- xhatsY.add(filter.getXhat(1));
- }
-
- //var chart = new XYChartBuilder().build();
- //chart.addSeries("Xhat, x/y", xhatsX, xhatsY);
- //chart.addSeries("Measured position, x/y", measurementsX, measurementsY);
- //try {
- // new SwingWrapper<>(chart).displayChart();
- // Thread.sleep(10000000);
- //} catch (Exception ign) {}
-
- assertEquals(0.0, filter.getXhat(0), 0.2);
- assertEquals(0.0, filter.getXhat(1), 0.2);
- }
-
- @Test
- @SuppressWarnings({"LocalVariableName", "PMD.AvoidInstantiatingObjectsInLoops"})
- public void testSwerveKFMovingOverTrajectory() {
-
- var random = new Random();
-
- var filter = new KalmanFilter<>(Nat.N6(), Nat.N3(),
- m_swerveObserverSystem,
- VecBuilder.fill(0.1, 0.1, 0.1, 0.1, 0.1, 0.1), // state
- // weights
- VecBuilder.fill(4, 4, 4), // measurement weights
- 0.020
- );
-
- List<Double> xhatsX = new ArrayList<>();
- List<Double> measurementsX = new ArrayList<>();
- List<Double> xhatsY = new ArrayList<>();
- List<Double> measurementsY = new ArrayList<>();
-
- var trajectory = TrajectoryGenerator.generateTrajectory(
- List.of(
- new Pose2d(0, 0, new Rotation2d()),
- new Pose2d(5, 5, new Rotation2d())
- ), new TrajectoryConfig(2, 2)
- );
- var time = 0.0;
- var lastVelocity = VecBuilder.fill(0.0, 0.0, 0.0);
-
- while (time <= trajectory.getTotalTimeSeconds()) {
- var sample = trajectory.sample(time);
- var measurement = VecBuilder.fill(
- sample.poseMeters.getTranslation().getX() + random.nextGaussian() / 5d,
- sample.poseMeters.getTranslation().getY() + random.nextGaussian() / 5d,
- sample.poseMeters.getRotation().getRadians() + random.nextGaussian() / 3d
- );
-
- var velocity = VecBuilder.fill(
- sample.velocityMetersPerSecond * sample.poseMeters.getRotation().getCos(),
- sample.velocityMetersPerSecond * sample.poseMeters.getRotation().getSin(),
- sample.curvatureRadPerMeter * sample.velocityMetersPerSecond
- );
- var u = (velocity.minus(lastVelocity)).div(0.020);
- lastVelocity = velocity;
-
- filter.correct(u, measurement);
- filter.predict(u, 0.020);
-
- measurementsX.add(measurement.get(0, 0));
- measurementsY.add(measurement.get(1, 0));
- xhatsX.add(filter.getXhat(0));
- xhatsY.add(filter.getXhat(1));
-
- time += 0.020;
- }
-
- //var chart = new XYChartBuilder().build();
- //chart.addSeries("Xhat, x/y", xhatsX, xhatsY);
- //chart.addSeries("Measured position, x/y", measurementsX, measurementsY);
- //try {
- // new SwingWrapper<>(chart).displayChart();
- // Thread.sleep(10000000);
- //} catch (Exception ign) {
- //}
-
- assertEquals(trajectory.sample(trajectory.getTotalTimeSeconds()).poseMeters
- .getTranslation().getX(), filter.getXhat(0), 0.2);
- assertEquals(trajectory.sample(trajectory.getTotalTimeSeconds()).poseMeters
- .getTranslation().getY(), filter.getXhat(1), 0.2);
- }
-}
diff --git a/wpimath/src/test/java/edu/wpi/first/wpilibj/estimator/MerweScaledSigmaPointsTest.java b/wpimath/src/test/java/edu/wpi/first/wpilibj/estimator/MerweScaledSigmaPointsTest.java
deleted file mode 100644
index 529a3c5..0000000
--- a/wpimath/src/test/java/edu/wpi/first/wpilibj/estimator/MerweScaledSigmaPointsTest.java
+++ /dev/null
@@ -1,42 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.estimator;
-
-import org.junit.jupiter.api.Test;
-
-import edu.wpi.first.wpiutil.math.Matrix;
-import edu.wpi.first.wpiutil.math.Nat;
-import edu.wpi.first.wpiutil.math.VecBuilder;
-
-import static org.junit.jupiter.api.Assertions.assertTrue;
-
-public class MerweScaledSigmaPointsTest {
- @Test
- public void testZeroMeanPoints() {
- var merweScaledSigmaPoints = new MerweScaledSigmaPoints<>(Nat.N2());
- var points = merweScaledSigmaPoints.sigmaPoints(VecBuilder.fill(0, 0),
- Matrix.mat(Nat.N2(), Nat.N2()).fill(1, 0, 0, 1));
-
- assertTrue(points.isEqual(Matrix.mat(Nat.N2(), Nat.N5()).fill(
- 0.0, 0.00173205, 0.0, -0.00173205, 0.0,
- 0.0, 0.0, 0.00173205, 0.0, -0.00173205
- ), 1E-6));
- }
-
- @Test
- public void testNonzeroMeanPoints() {
- var merweScaledSigmaPoints = new MerweScaledSigmaPoints<>(Nat.N2());
- var points = merweScaledSigmaPoints.sigmaPoints(VecBuilder.fill(1, 2),
- Matrix.mat(Nat.N2(), Nat.N2()).fill(1, 0, 0, 10));
-
- assertTrue(points.isEqual(Matrix.mat(Nat.N2(), Nat.N5()).fill(
- 1.0, 1.00173205, 1.0, 0.99826795, 1.0,
- 2.0, 2.0, 2.00547723, 2.0, 1.99452277
- ), 1E-6));
- }
-}
diff --git a/wpimath/src/test/java/edu/wpi/first/wpilibj/estimator/UnscentedKalmanFilterTest.java b/wpimath/src/test/java/edu/wpi/first/wpilibj/estimator/UnscentedKalmanFilterTest.java
deleted file mode 100644
index c4340ae..0000000
--- a/wpimath/src/test/java/edu/wpi/first/wpilibj/estimator/UnscentedKalmanFilterTest.java
+++ /dev/null
@@ -1,396 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.estimator;
-
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.List;
-
-import org.junit.jupiter.api.Test;
-
-import edu.wpi.first.wpilibj.geometry.Pose2d;
-import edu.wpi.first.wpilibj.geometry.Rotation2d;
-import edu.wpi.first.wpilibj.math.Discretization;
-import edu.wpi.first.wpilibj.math.StateSpaceUtil;
-import edu.wpi.first.wpilibj.system.NumericalJacobian;
-import edu.wpi.first.wpilibj.system.RungeKutta;
-import edu.wpi.first.wpilibj.system.plant.DCMotor;
-import edu.wpi.first.wpilibj.system.plant.LinearSystemId;
-import edu.wpi.first.wpilibj.trajectory.TrajectoryConfig;
-import edu.wpi.first.wpilibj.trajectory.TrajectoryGenerator;
-import edu.wpi.first.wpiutil.math.Matrix;
-import edu.wpi.first.wpiutil.math.Nat;
-import edu.wpi.first.wpiutil.math.VecBuilder;
-import edu.wpi.first.wpiutil.math.numbers.N1;
-import edu.wpi.first.wpiutil.math.numbers.N2;
-import edu.wpi.first.wpiutil.math.numbers.N4;
-import edu.wpi.first.wpiutil.math.numbers.N6;
-
-import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
-import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertTrue;
-
-
-public class UnscentedKalmanFilterTest {
- @SuppressWarnings({"LocalVariableName", "ParameterName"})
- public static Matrix<N6, N1> getDynamics(Matrix<N6, N1> x, Matrix<N2, N1> u) {
- var motors = DCMotor.getCIM(2);
-
- var gHigh = 7.08;
- var rb = 0.8382 / 2.0;
- var r = 0.0746125;
- var m = 63.503;
- var J = 5.6;
-
- var C1 = -Math.pow(gHigh, 2) * motors.m_KtNMPerAmp
- / (motors.m_KvRadPerSecPerVolt * motors.m_rOhms * r * r);
- var C2 = gHigh * motors.m_KtNMPerAmp / (motors.m_rOhms * r);
-
- var c = x.get(2, 0);
- var s = x.get(3, 0);
- var vl = x.get(4, 0);
- var vr = x.get(5, 0);
-
- var Vl = u.get(0, 0);
- var Vr = u.get(1, 0);
-
- var k1 = 1.0 / m + rb * rb / J;
- var k2 = 1.0 / m - rb * rb / J;
-
- var xvel = (vl + vr) / 2;
- var w = (vr - vl) / (2.0 * rb);
-
- return VecBuilder.fill(
- xvel * c,
- xvel * s,
- -s * w,
- c * w,
- k1 * ((C1 * vl) + (C2 * Vl)) + k2 * ((C1 * vr) + (C2 * Vr)),
- k2 * ((C1 * vl) + (C2 * Vl)) + k1 * ((C1 * vr) + (C2 * Vr))
- );
- }
-
- @SuppressWarnings("ParameterName")
- public static Matrix<N4, N1> getLocalMeasurementModel(Matrix<N6, N1> x, Matrix<N2, N1> u) {
- return VecBuilder.fill(x.get(2, 0), x.get(3, 0), x.get(4, 0), x.get(5, 0));
- }
-
- @SuppressWarnings("ParameterName")
- public static Matrix<N6, N1> getGlobalMeasurementModel(Matrix<N6, N1> x, Matrix<N2, N1> u) {
- return x.copy();
- }
-
- @Test
- @SuppressWarnings("LocalVariableName")
- public void testInit() {
- assertDoesNotThrow(() -> {
- UnscentedKalmanFilter<N6, N2, N4> observer = new UnscentedKalmanFilter<>(
- Nat.N6(), Nat.N4(),
- UnscentedKalmanFilterTest::getDynamics,
- UnscentedKalmanFilterTest::getLocalMeasurementModel,
- VecBuilder.fill(0.5, 0.5, 0.7, 0.7, 1.0, 1.0),
- VecBuilder.fill(0.001, 0.001, 0.5, 0.5),
- 0.00505);
-
- var u = VecBuilder.fill(12.0, 12.0);
- observer.predict(u, 0.00505);
-
- var localY = getLocalMeasurementModel(observer.getXhat(), u);
- observer.correct(u, localY);
- });
- }
-
- @SuppressWarnings({"LocalVariableName", "PMD.AvoidInstantiatingObjectsInLoops",
- "PMD.ExcessiveMethodLength"})
- @Test
- public void testConvergence() {
- double dtSeconds = 0.00505;
- double rbMeters = 0.8382 / 2.0; // Robot radius
-
- List<Double> trajXs = new ArrayList<>();
- List<Double> trajYs = new ArrayList<>();
-
- List<Double> observerXs = new ArrayList<>();
- List<Double> observerYs = new ArrayList<>();
- List<Double> observerC = new ArrayList<>();
- List<Double> observerS = new ArrayList<>();
- List<Double> observervl = new ArrayList<>();
- List<Double> observervr = new ArrayList<>();
-
- List<Double> inputVl = new ArrayList<>();
- List<Double> inputVr = new ArrayList<>();
-
- List<Double> timeData = new ArrayList<>();
- List<Matrix<?, ?>> rdots = new ArrayList<>();
-
- UnscentedKalmanFilter<N6, N2, N4> observer = new UnscentedKalmanFilter<>(
- Nat.N6(), Nat.N4(),
- UnscentedKalmanFilterTest::getDynamics,
- UnscentedKalmanFilterTest::getLocalMeasurementModel,
- VecBuilder.fill(0.5, 0.5, 0.7, 0.7, 1.0, 1.0),
- VecBuilder.fill(0.001, 0.001, 0.5, 0.5),
- dtSeconds);
-
- List<Pose2d> waypoints = Arrays.asList(new Pose2d(2.75, 22.521, new Rotation2d()),
- new Pose2d(24.73, 19.68, Rotation2d.fromDegrees(5.846)));
- var trajectory = TrajectoryGenerator.generateTrajectory(
- waypoints,
- new TrajectoryConfig(8.8, 0.1)
- );
-
- Matrix<N6, N1> nextR;
- Matrix<N2, N1> u = new Matrix<>(Nat.N2(), Nat.N1());
-
- var B = NumericalJacobian.numericalJacobianU(Nat.N6(), Nat.N2(),
- UnscentedKalmanFilterTest::getDynamics, new Matrix<>(Nat.N6(), Nat.N1()), u);
-
- observer.setXhat(VecBuilder.fill(2.75, 22.521, 1.0, 0.0, 0.0, 0.0)); // TODO not hard code this
-
- var ref = trajectory.sample(0.0);
-
- Matrix<N6, N1> r = VecBuilder.fill(
- ref.poseMeters.getTranslation().getX(),
- ref.poseMeters.getTranslation().getY(),
- ref.poseMeters.getRotation().getCos(),
- ref.poseMeters.getRotation().getSin(),
- ref.velocityMetersPerSecond * (1 - (ref.curvatureRadPerMeter * rbMeters)),
- ref.velocityMetersPerSecond * (1 + (ref.curvatureRadPerMeter * rbMeters))
- );
- nextR = r.copy();
-
- var trueXhat = observer.getXhat();
-
- double totalTime = trajectory.getTotalTimeSeconds();
- for (int i = 0; i < (totalTime / dtSeconds); i++) {
-
- ref = trajectory.sample(dtSeconds * i);
- double vl = ref.velocityMetersPerSecond * (1 - (ref.curvatureRadPerMeter * rbMeters));
- double vr = ref.velocityMetersPerSecond * (1 + (ref.curvatureRadPerMeter * rbMeters));
-
- nextR.set(0, 0, ref.poseMeters.getTranslation().getX());
- nextR.set(1, 0, ref.poseMeters.getTranslation().getY());
- nextR.set(2, 0, ref.poseMeters.getRotation().getCos());
- nextR.set(3, 0, ref.poseMeters.getRotation().getSin());
- nextR.set(4, 0, vl);
- nextR.set(5, 0, vr);
-
- Matrix<N4, N1> localY =
- getLocalMeasurementModel(trueXhat, new Matrix<>(Nat.N2(), Nat.N1()));
- var noiseStdDev = VecBuilder.fill(0.001, 0.001, 0.5, 0.5);
-
- observer.correct(u,
- localY.plus(StateSpaceUtil.makeWhiteNoiseVector(
- noiseStdDev)));
-
- var rdot = nextR.minus(r).div(dtSeconds);
- u = new Matrix<>(B.solve(rdot.minus(getDynamics(r, new Matrix<>(Nat.N2(), Nat.N1())))));
-
- rdots.add(rdot);
-
- trajXs.add(ref.poseMeters.getTranslation().getX());
- trajYs.add(ref.poseMeters.getTranslation().getY());
-
- observerXs.add(observer.getXhat().get(0, 0));
- observerYs.add(observer.getXhat().get(1, 0));
-
- observerC.add(observer.getXhat(2));
- observerS.add(observer.getXhat(3));
-
- observervl.add(observer.getXhat(4));
- observervr.add(observer.getXhat(5));
-
- inputVl.add(u.get(0, 0));
- inputVr.add(u.get(1, 0));
-
- timeData.add(i * dtSeconds);
-
- r = nextR;
- observer.predict(u, dtSeconds);
- trueXhat = RungeKutta.rungeKutta(UnscentedKalmanFilterTest::getDynamics,
- trueXhat, u, dtSeconds);
- }
-
- var localY = getLocalMeasurementModel(trueXhat, u);
- observer.correct(u, localY);
-
- var globalY = getGlobalMeasurementModel(trueXhat, u);
- var R = StateSpaceUtil.makeCostMatrix(
- VecBuilder.fill(0.01, 0.01, 0.0001, 0.0001, 0.5, 0.5));
- observer.correct(Nat.N6(), u, globalY,
- UnscentedKalmanFilterTest::getGlobalMeasurementModel, R);
-
- final var finalPosition = trajectory.sample(trajectory.getTotalTimeSeconds());
-
- // var chartBuilder = new XYChartBuilder();
- // chartBuilder.title = "The Magic of Sensor Fusion, now with a "
- // + observer.getClass().getSimpleName();
- // var xyPosChart = chartBuilder.build();
-
- // xyPosChart.setXAxisTitle("X pos, meters");
- // xyPosChart.setYAxisTitle("Y pos, meters");
- // xyPosChart.addSeries("Trajectory", trajXs, trajYs);
- // xyPosChart.addSeries("xHat", observerXs, observerYs);
-
- // var stateChart = new XYChartBuilder()
- // .title("States (x-hat)").build();
- // stateChart.addSeries("Cos", timeData, observerC);
- // stateChart.addSeries("Sin", timeData, observerS);
- // stateChart.addSeries("vl, m/s", timeData, observervl);
- // stateChart.addSeries("vr, m/s", timeData, observervr);
-
- // var inputChart = new XYChartBuilder().title("Inputs").build();
- // inputChart.addSeries("Left voltage", timeData, inputVl);
- // inputChart.addSeries("Right voltage", timeData, inputVr);
-
- // var rdotChart = new XYChartBuilder().title("Rdot").build();
- // rdotChart.addSeries("xdot, or vx", timeData, rdots.stream().map(it -> it.get(0, 0))
- // .collect(Collectors.toList()));
- // rdotChart.addSeries("ydot, or vy", timeData, rdots.stream().map(it -> it.get(1, 0))
- // .collect(Collectors.toList()));
- // rdotChart.addSeries("cos dot", timeData, rdots.stream().map(it -> it.get(2, 0))
- // .collect(Collectors.toList()));
- // rdotChart.addSeries("sin dot", timeData, rdots.stream().map(it -> it.get(3, 0))
- // .collect(Collectors.toList()));
- // rdotChart.addSeries("vl dot, or al", timeData, rdots.stream().map(it -> it.get(4, 0))
- // .collect(Collectors.toList()));
- // rdotChart.addSeries("vr dot, or ar", timeData, rdots.stream().map(it -> it.get(5, 0))
- // .collect(Collectors.toList()));
-
- // List<XYChart> charts = new ArrayList<>();
- // charts.add(xyPosChart);
- // charts.add(stateChart);
- // charts.add(inputChart);
- // charts.add(rdotChart);
- // new SwingWrapper<>(charts).displayChartMatrix();
- // try {
- // Thread.sleep(1000000000);
- // } catch (InterruptedException ex) {
- // ex.printStackTrace();
- // }
-
- assertEquals(finalPosition.poseMeters.getTranslation().getX(), observer.getXhat(0), 0.25);
- assertEquals(finalPosition.poseMeters.getTranslation().getY(), observer.getXhat(1), 0.25);
- assertEquals(finalPosition.poseMeters.getRotation().getRadians(), observer.getXhat(2), 1.0);
- assertEquals(0.0, observer.getXhat(3), 1.0);
- assertEquals(0.0, observer.getXhat(4), 1.0);
- }
-
- @Test
- @SuppressWarnings({"LocalVariableName", "ParameterName", "PMD.AvoidInstantiatingObjectsInLoops"})
- public void testLinearUKF() {
- var dt = 0.020;
- var plant = LinearSystemId.identifyVelocitySystem(0.02, 0.006);
- var observer = new UnscentedKalmanFilter<>(Nat.N1(), Nat.N1(),
- (x, u) -> plant.getA().times(x).plus(plant.getB().times(u)),
- plant::calculateY,
- VecBuilder.fill(0.05),
- VecBuilder.fill(1.0),
- dt);
-
- var time = new ArrayList<Double>();
- var refData = new ArrayList<Double>();
- var xhat = new ArrayList<Double>();
- var udata = new ArrayList<Double>();
- var xdotData = new ArrayList<Double>();
-
- var discABPair = Discretization.discretizeAB(plant.getA(), plant.getB(), dt);
- var discA = discABPair.getFirst();
- var discB = discABPair.getSecond();
-
- Matrix<N1, N1> ref = VecBuilder.fill(100);
- Matrix<N1, N1> u = VecBuilder.fill(0);
-
- Matrix<N1, N1> xdot;
- for (int i = 0; i < (2.0 / dt); i++) {
- observer.predict(u, dt);
-
- u = discB.solve(ref.minus(discA.times(ref)));
-
- xdot = plant.getA().times(observer.getXhat()).plus(plant.getB().times(u));
-
- time.add(i * dt);
- refData.add(ref.get(0, 0));
- xhat.add(observer.getXhat(0));
- udata.add(u.get(0, 0));
- xdotData.add(xdot.get(0, 0));
- }
-
- // var chartBuilder = new XYChartBuilder();
- // chartBuilder.title = "The Magic of Sensor Fusion";
- // var chart = chartBuilder.build();
-
- // chart.addSeries("Ref", time, refData);
- // chart.addSeries("xHat", time, xhat);
- // chart.addSeries("input", time, udata);
- //// chart.addSeries("xdot", time, xdotData);
-
- // new SwingWrapper<>(chart).displayChart();
- // try {
- // Thread.sleep(1000000000);
- // } catch (InterruptedException e) {
- // }
-
- assertEquals(ref.get(0, 0), observer.getXhat(0), 5);
- }
-
- @Test
- public void testUnscentedTransform() {
- // From FilterPy
- var ret = UnscentedKalmanFilter.unscentedTransform(Nat.N4(), Nat.N4(),
- Matrix.mat(Nat.N4(), Nat.N9()).fill(
- -0.9, -0.822540333075852, -0.8922540333075852, -0.9,
- -0.9, -0.9774596669241481, -0.9077459666924148, -0.9, -0.9,
- 1.0, 1.0, 1.077459666924148, 1.0, 1.0, 1.0, 0.9225403330758519, 1.0, 1.0,
- -0.9, -0.9, -0.9, -0.822540333075852, -0.8922540333075852, -0.9,
- -0.9, -0.9774596669241481, -0.9077459666924148,
- 1.0, 1.0, 1.0, 1.0, 1.077459666924148, 1.0, 1.0, 1.0, 0.9225403330758519
- ),
- VecBuilder.fill(
- -132.33333333,
- 16.66666667,
- 16.66666667,
- 16.66666667,
- 16.66666667,
- 16.66666667,
- 16.66666667,
- 16.66666667,
- 16.66666667
- ),
- VecBuilder.fill(
- -129.34333333,
- 16.66666667,
- 16.66666667,
- 16.66666667,
- 16.66666667,
- 16.66666667,
- 16.66666667,
- 16.66666667,
- 16.66666667
- )
- );
-
- assertTrue(
- VecBuilder.fill(-0.9, 1, -0.9, 1).isEqual(
- ret.getFirst(), 1E-5
- ));
-
- assertTrue(
- Matrix.mat(Nat.N4(), Nat.N4()).fill(
- 2.02000002e-01, 2.00000500e-02, -2.69044710e-29,
- -4.59511477e-29,
- 2.00000500e-02, 2.00001000e-01, -2.98781068e-29,
- -5.12759588e-29,
- -2.73372625e-29, -3.09882635e-29, 2.02000002e-01,
- 2.00000500e-02,
- -4.67065917e-29, -5.10705197e-29, 2.00000500e-02,
- 2.00001000e-01
- ).isEqual(
- ret.getSecond(), 1E-5
- ));
- }
-}
diff --git a/wpimath/src/test/java/edu/wpi/first/wpilibj/geometry/Rotation2dTest.java b/wpimath/src/test/java/edu/wpi/first/wpilibj/geometry/Rotation2dTest.java
deleted file mode 100644
index 8a08944..0000000
--- a/wpimath/src/test/java/edu/wpi/first/wpilibj/geometry/Rotation2dTest.java
+++ /dev/null
@@ -1,81 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.geometry;
-
-import org.junit.jupiter.api.Test;
-
-import static org.junit.jupiter.api.Assertions.assertAll;
-import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertNotEquals;
-
-class Rotation2dTest {
- private static final double kEpsilon = 1E-9;
-
- @Test
- void testRadiansToDegrees() {
- var one = new Rotation2d(Math.PI / 3);
- var two = new Rotation2d(Math.PI / 4);
-
- assertAll(
- () -> assertEquals(one.getDegrees(), 60.0, kEpsilon),
- () -> assertEquals(two.getDegrees(), 45.0, kEpsilon)
- );
- }
-
- @Test
- void testRadiansAndDegrees() {
- var one = Rotation2d.fromDegrees(45.0);
- var two = Rotation2d.fromDegrees(30.0);
-
- assertAll(
- () -> assertEquals(one.getRadians(), Math.PI / 4, kEpsilon),
- () -> assertEquals(two.getRadians(), Math.PI / 6, kEpsilon)
- );
- }
-
- @Test
- void testRotateByFromZero() {
- var zero = new Rotation2d();
- var rotated = zero.rotateBy(Rotation2d.fromDegrees(90.0));
-
- assertAll(
- () -> assertEquals(rotated.getRadians(), Math.PI / 2.0, kEpsilon),
- () -> assertEquals(rotated.getDegrees(), 90.0, kEpsilon)
- );
- }
-
- @Test
- void testRotateByNonZero() {
- var rot = Rotation2d.fromDegrees(90.0);
- rot = rot.plus(Rotation2d.fromDegrees(30.0));
-
- assertEquals(rot.getDegrees(), 120.0, kEpsilon);
- }
-
- @Test
- void testMinus() {
- var one = Rotation2d.fromDegrees(70.0);
- var two = Rotation2d.fromDegrees(30.0);
-
- assertEquals(one.minus(two).getDegrees(), 40.0, kEpsilon);
- }
-
- @Test
- void testEquality() {
- var one = Rotation2d.fromDegrees(43.0);
- var two = Rotation2d.fromDegrees(43.0);
- assertEquals(one, two);
- }
-
- @Test
- void testInequality() {
- var one = Rotation2d.fromDegrees(43.0);
- var two = Rotation2d.fromDegrees(43.5);
- assertNotEquals(one, two);
- }
-}
diff --git a/wpimath/src/test/java/edu/wpi/first/wpilibj/geometry/Transform2dTest.java b/wpimath/src/test/java/edu/wpi/first/wpilibj/geometry/Transform2dTest.java
deleted file mode 100644
index c375d22..0000000
--- a/wpimath/src/test/java/edu/wpi/first/wpilibj/geometry/Transform2dTest.java
+++ /dev/null
@@ -1,36 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.geometry;
-
-import org.junit.jupiter.api.Test;
-
-import static org.junit.jupiter.api.Assertions.assertAll;
-import static org.junit.jupiter.api.Assertions.assertEquals;
-
-class Transform2dTest {
- private static final double kEpsilon = 1E-9;
-
- @Test
- void testInverse() {
- var initial = new Pose2d(new Translation2d(1.0, 2.0), Rotation2d.fromDegrees(45.0));
- var transformation = new Transform2d(new Translation2d(5.0, 0.0),
- Rotation2d.fromDegrees(5.0));
-
- var transformed = initial.plus(transformation);
- var untransformed = transformed.plus(transformation.inverse());
-
- assertAll(
- () -> assertEquals(initial.getX(), untransformed.getX(),
- kEpsilon),
- () -> assertEquals(initial.getY(), untransformed.getY(),
- kEpsilon),
- () -> assertEquals(initial.getRotation().getDegrees(),
- untransformed.getRotation().getDegrees(), kEpsilon)
- );
- }
-}
diff --git a/wpimath/src/test/java/edu/wpi/first/wpilibj/kinematics/ChassisSpeedsTest.java b/wpimath/src/test/java/edu/wpi/first/wpilibj/kinematics/ChassisSpeedsTest.java
deleted file mode 100644
index b06abe6..0000000
--- a/wpimath/src/test/java/edu/wpi/first/wpilibj/kinematics/ChassisSpeedsTest.java
+++ /dev/null
@@ -1,32 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.kinematics;
-
-import org.junit.jupiter.api.Test;
-
-import edu.wpi.first.wpilibj.geometry.Rotation2d;
-
-import static org.junit.jupiter.api.Assertions.assertAll;
-import static org.junit.jupiter.api.Assertions.assertEquals;
-
-class ChassisSpeedsTest {
- private static final double kEpsilon = 1E-9;
-
- @Test
- void testFieldRelativeConstruction() {
- final var chassisSpeeds = ChassisSpeeds.fromFieldRelativeSpeeds(
- 1.0, 0.0, 0.5, Rotation2d.fromDegrees(-90.0)
- );
-
- assertAll(
- () -> assertEquals(0.0, chassisSpeeds.vxMetersPerSecond, kEpsilon),
- () -> assertEquals(1.0, chassisSpeeds.vyMetersPerSecond, kEpsilon),
- () -> assertEquals(0.5, chassisSpeeds.omegaRadiansPerSecond, kEpsilon)
- );
- }
-}
diff --git a/wpimath/src/test/java/edu/wpi/first/wpilibj/kinematics/DifferentialDriveOdometryTest.java b/wpimath/src/test/java/edu/wpi/first/wpilibj/kinematics/DifferentialDriveOdometryTest.java
deleted file mode 100644
index e6022de..0000000
--- a/wpimath/src/test/java/edu/wpi/first/wpilibj/kinematics/DifferentialDriveOdometryTest.java
+++ /dev/null
@@ -1,34 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.kinematics;
-
-import org.junit.jupiter.api.Test;
-
-import edu.wpi.first.wpilibj.geometry.Pose2d;
-import edu.wpi.first.wpilibj.geometry.Rotation2d;
-
-import static org.junit.jupiter.api.Assertions.assertAll;
-import static org.junit.jupiter.api.Assertions.assertEquals;
-
-class DifferentialDriveOdometryTest {
- private static final double kEpsilon = 1E-9;
- private final DifferentialDriveOdometry m_odometry = new DifferentialDriveOdometry(
- new Rotation2d());
-
- @Test
- void testOdometryWithEncoderDistances() {
- m_odometry.resetPosition(new Pose2d(), Rotation2d.fromDegrees(45));
- var pose = m_odometry.update(Rotation2d.fromDegrees(135.0), 0.0, 5 * Math.PI);
-
- assertAll(
- () -> assertEquals(pose.getX(), 5.0, kEpsilon),
- () -> assertEquals(pose.getY(), 5.0, kEpsilon),
- () -> assertEquals(pose.getRotation().getDegrees(), 90.0, kEpsilon)
- );
- }
-}
diff --git a/wpimath/src/test/java/edu/wpi/first/wpilibj/kinematics/MecanumDriveKinematicsTest.java b/wpimath/src/test/java/edu/wpi/first/wpilibj/kinematics/MecanumDriveKinematicsTest.java
deleted file mode 100644
index 93c8d6a..0000000
--- a/wpimath/src/test/java/edu/wpi/first/wpilibj/kinematics/MecanumDriveKinematicsTest.java
+++ /dev/null
@@ -1,262 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.kinematics;
-
-import org.junit.jupiter.api.Test;
-
-import edu.wpi.first.wpilibj.geometry.Translation2d;
-
-import static org.junit.jupiter.api.Assertions.assertAll;
-import static org.junit.jupiter.api.Assertions.assertEquals;
-
-class MecanumDriveKinematicsTest {
- private static final double kEpsilon = 1E-9;
-
- private final Translation2d m_fl = new Translation2d(12, 12);
- private final Translation2d m_fr = new Translation2d(12, -12);
- private final Translation2d m_bl = new Translation2d(-12, 12);
- private final Translation2d m_br = new Translation2d(-12, -12);
-
- private final MecanumDriveKinematics m_kinematics =
- new MecanumDriveKinematics(m_fl, m_fr, m_bl, m_br);
-
- @Test
- void testStraightLineInverseKinematics() {
- ChassisSpeeds speeds = new ChassisSpeeds(5, 0, 0);
- var moduleStates = m_kinematics.toWheelSpeeds(speeds);
-
- /*
- By equation (13.12) of the state-space-guide, the wheel speeds should
- be as follows:
- velocities: fl 3.535534 fr 3.535534 rl 3.535534 rr 3.535534
- */
-
- assertAll(
- () -> assertEquals(3.536, moduleStates.frontLeftMetersPerSecond, 0.1),
- () -> assertEquals(3.536, moduleStates.frontRightMetersPerSecond, 0.1),
- () -> assertEquals(3.536, moduleStates.rearLeftMetersPerSecond, 0.1),
- () -> assertEquals(3.536, moduleStates.rearRightMetersPerSecond, 0.1)
- );
- }
-
- @Test
- void testStraightLineForwardKinematicsKinematics() {
-
- var wheelSpeeds = new MecanumDriveWheelSpeeds(3.536, 3.536, 3.536, 3.536);
- var moduleStates = m_kinematics.toChassisSpeeds(wheelSpeeds);
-
- /*
- By equation (13.13) of the state-space-guide, the chassis motion from wheel
- velocities: fl 3.535534 fr 3.535534 rl 3.535534 rr 3.535534 will be [[5][0][0]]
- */
-
- assertAll(
- () -> assertEquals(5, moduleStates.vxMetersPerSecond, 0.1),
- () -> assertEquals(0, moduleStates.vyMetersPerSecond, 0.1),
- () -> assertEquals(0, moduleStates.omegaRadiansPerSecond, 0.1)
- );
- }
-
- @Test
- void testStrafeInverseKinematics() {
- ChassisSpeeds speeds = new ChassisSpeeds(0, 4, 0);
- var moduleStates = m_kinematics.toWheelSpeeds(speeds);
-
- /*
- By equation (13.12) of the state-space-guide, the wheel speeds should
- be as follows:
- velocities: fl -2.828427 fr 2.828427 rl 2.828427 rr -2.828427
- */
-
- assertAll(
- () -> assertEquals(-2.828427, moduleStates.frontLeftMetersPerSecond, 0.1),
- () -> assertEquals(2.828427, moduleStates.frontRightMetersPerSecond, 0.1),
- () -> assertEquals(2.828427, moduleStates.rearLeftMetersPerSecond, 0.1),
- () -> assertEquals(-2.828427, moduleStates.rearRightMetersPerSecond, 0.1)
- );
- }
-
- @Test
- void testStrafeForwardKinematicsKinematics() {
-
- var wheelSpeeds = new MecanumDriveWheelSpeeds(-2.828427, 2.828427, 2.828427, -2.828427);
- var moduleStates = m_kinematics.toChassisSpeeds(wheelSpeeds);
-
- /*
- By equation (13.13) of the state-space-guide, the chassis motion from wheel
- velocities: fl 3.535534 fr 3.535534 rl 3.535534 rr 3.535534 will be [[5][0][0]]
- */
-
- assertAll(
- () -> assertEquals(0, moduleStates.vxMetersPerSecond, 0.1),
- () -> assertEquals(4, moduleStates.vyMetersPerSecond, 0.1),
- () -> assertEquals(0, moduleStates.omegaRadiansPerSecond, 0.1)
- );
- }
-
- @Test
- void testRotationInverseKinematics() {
- ChassisSpeeds speeds = new ChassisSpeeds(0, 0, 2 * Math.PI);
- var moduleStates = m_kinematics.toWheelSpeeds(speeds);
-
- /*
- By equation (13.12) of the state-space-guide, the wheel speeds should
- be as follows:
- velocities: fl -106.629191 fr 106.629191 rl -106.629191 rr 106.629191
- */
-
- assertAll(
- () -> assertEquals(-106.629191, moduleStates.frontLeftMetersPerSecond, 0.1),
- () -> assertEquals(106.629191, moduleStates.frontRightMetersPerSecond, 0.1),
- () -> assertEquals(-106.629191, moduleStates.rearLeftMetersPerSecond, 0.1),
- () -> assertEquals(106.629191, moduleStates.rearRightMetersPerSecond, 0.1)
- );
- }
-
- @Test
- void testRotationForwardKinematicsKinematics() {
- var wheelSpeeds = new MecanumDriveWheelSpeeds(-106.629191, 106.629191, -106.629191, 106.629191);
- var moduleStates = m_kinematics.toChassisSpeeds(wheelSpeeds);
-
- /*
- By equation (13.13) of the state-space-guide, the chassis motion from wheel
- velocities: fl -106.629191 fr 106.629191 rl -106.629191 rr 106.629191 should be [[0][0][2pi]]
- */
-
- assertAll(
- () -> assertEquals(0, moduleStates.vxMetersPerSecond, 0.1),
- () -> assertEquals(0, moduleStates.vyMetersPerSecond, 0.1),
- () -> assertEquals(2 * Math.PI, moduleStates.omegaRadiansPerSecond, 0.1)
- );
- }
-
- @Test
- void testMixedTranslationRotationInverseKinematics() {
- ChassisSpeeds speeds = new ChassisSpeeds(2, 3, 1);
- var moduleStates = m_kinematics.toWheelSpeeds(speeds);
-
- /*
- By equation (13.12) of the state-space-guide, the wheel speeds should
- be as follows:
- velocities: fl -17.677670 fr 20.506097 rl -13.435029 rr 16.263456
- */
-
- assertAll(
- () -> assertEquals(-17.677670, moduleStates.frontLeftMetersPerSecond, 0.1),
- () -> assertEquals(20.506097, moduleStates.frontRightMetersPerSecond, 0.1),
- () -> assertEquals(-13.435, moduleStates.rearLeftMetersPerSecond, 0.1),
- () -> assertEquals(16.26, moduleStates.rearRightMetersPerSecond, 0.1)
- );
- }
-
- @Test
- void testMixedTranslationRotationForwardKinematicsKinematics() {
- var wheelSpeeds = new MecanumDriveWheelSpeeds(-17.677670, 20.51, -13.44, 16.26);
- var moduleStates = m_kinematics.toChassisSpeeds(wheelSpeeds);
-
- /*
- By equation (13.13) of the state-space-guide, the chassis motion from wheel
- velocities: fl -17.677670 fr 20.506097 rl -13.435029 rr 16.263456 should be [[2][3][1]]
- */
-
- assertAll(
- () -> assertEquals(2, moduleStates.vxMetersPerSecond, 0.1),
- () -> assertEquals(3, moduleStates.vyMetersPerSecond, 0.1),
- () -> assertEquals(1, moduleStates.omegaRadiansPerSecond, 0.1)
- );
- }
-
- @Test
- void testOffCenterRotationInverseKinematics() {
- ChassisSpeeds speeds = new ChassisSpeeds(0, 0, 1);
- var moduleStates = m_kinematics.toWheelSpeeds(speeds, m_fl);
-
- /*
- By equation (13.12) of the state-space-guide, the wheel speeds should
- be as follows:
- velocities: fl 0.000000 fr 16.970563 rl -16.970563 rr 33.941125
- */
-
- assertAll(
- () -> assertEquals(0, moduleStates.frontLeftMetersPerSecond, 0.1),
- () -> assertEquals(16.971, moduleStates.frontRightMetersPerSecond, 0.1),
- () -> assertEquals(-16.971, moduleStates.rearLeftMetersPerSecond, 0.1),
- () -> assertEquals(33.941, moduleStates.rearRightMetersPerSecond, 0.1)
- );
- }
-
- @Test
- void testOffCenterRotationForwardKinematicsKinematics() {
- var wheelSpeeds = new MecanumDriveWheelSpeeds(0, 16.971, -16.971, 33.941);
- var moduleStates = m_kinematics.toChassisSpeeds(wheelSpeeds);
-
- /*
- By equation (13.13) of the state-space-guide, the chassis motion from the wheel
- velocities should be [[12][-12][1]]
- */
-
- assertAll(
- () -> assertEquals(12, moduleStates.vxMetersPerSecond, 0.1),
- () -> assertEquals(-12, moduleStates.vyMetersPerSecond, 0.1),
- () -> assertEquals(1, moduleStates.omegaRadiansPerSecond, 0.1)
- );
- }
-
- @Test
- void testOffCenterTranslationRotationInverseKinematics() {
- ChassisSpeeds speeds = new ChassisSpeeds(5, 2, 1);
- var moduleStates = m_kinematics.toWheelSpeeds(speeds, m_fl);
-
- /*
- By equation (13.12) of the state-space-guide, the wheel speeds should
- be as follows:
- velocities: fl 2.121320 fr 21.920310 rl -12.020815 rr 36.062446
- */
-
- assertAll(
- () -> assertEquals(2.12, moduleStates.frontLeftMetersPerSecond, 0.1),
- () -> assertEquals(21.92, moduleStates.frontRightMetersPerSecond, 0.1),
- () -> assertEquals(-12.02, moduleStates.rearLeftMetersPerSecond, 0.1),
- () -> assertEquals(36.06, moduleStates.rearRightMetersPerSecond, 0.1)
- );
- }
-
- @Test
- void testOffCenterRotationTranslationForwardKinematicsKinematics() {
-
- var wheelSpeeds = new MecanumDriveWheelSpeeds(2.12, 21.92, -12.02, 36.06);
- var moduleStates = m_kinematics.toChassisSpeeds(wheelSpeeds);
-
- /*
- By equation (13.13) of the state-space-guide, the chassis motion from the wheel
- velocities should be [[17][-10][1]]
- */
-
- assertAll(
- () -> assertEquals(17, moduleStates.vxMetersPerSecond, 0.1),
- () -> assertEquals(-10, moduleStates.vyMetersPerSecond, 0.1),
- () -> assertEquals(1, moduleStates.omegaRadiansPerSecond, 0.1)
- );
- }
-
- @Test
- void testNormalize() {
- var wheelSpeeds = new MecanumDriveWheelSpeeds(5, 6, 4, 7);
- wheelSpeeds.normalize(5.5);
-
- double factor = 5.5 / 7.0;
-
- assertAll(
- () -> assertEquals(5.0 * factor, wheelSpeeds.frontLeftMetersPerSecond, kEpsilon),
- () -> assertEquals(6.0 * factor, wheelSpeeds.frontRightMetersPerSecond, kEpsilon),
- () -> assertEquals(4.0 * factor, wheelSpeeds.rearLeftMetersPerSecond, kEpsilon),
- () -> assertEquals(7.0 * factor, wheelSpeeds.rearRightMetersPerSecond, kEpsilon)
- );
- }
-
-}
diff --git a/wpimath/src/test/java/edu/wpi/first/wpilibj/math/StateSpaceUtilTest.java b/wpimath/src/test/java/edu/wpi/first/wpilibj/math/StateSpaceUtilTest.java
deleted file mode 100644
index 244ca53..0000000
--- a/wpimath/src/test/java/edu/wpi/first/wpilibj/math/StateSpaceUtilTest.java
+++ /dev/null
@@ -1,199 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.math;
-
-import java.util.ArrayList;
-import java.util.List;
-
-import org.ejml.dense.row.MatrixFeatures_DDRM;
-import org.ejml.simple.SimpleMatrix;
-import org.junit.jupiter.api.Test;
-
-import edu.wpi.first.wpilibj.geometry.Pose2d;
-import edu.wpi.first.wpilibj.geometry.Rotation2d;
-import edu.wpi.first.wpiutil.math.Matrix;
-import edu.wpi.first.wpiutil.math.Nat;
-import edu.wpi.first.wpiutil.math.SimpleMatrixUtils;
-import edu.wpi.first.wpiutil.math.VecBuilder;
-import edu.wpi.first.wpiutil.math.numbers.N1;
-import edu.wpi.first.wpiutil.math.numbers.N2;
-
-import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertFalse;
-import static org.junit.jupiter.api.Assertions.assertTrue;
-
-public class StateSpaceUtilTest {
- @Test
- public void testCostArray() {
- var mat = StateSpaceUtil.makeCostMatrix(
- VecBuilder.fill(1.0, 2.0, 3.0));
-
- assertEquals(1.0, mat.get(0, 0), 1e-3);
- assertEquals(0.0, mat.get(0, 1), 1e-3);
- assertEquals(0.0, mat.get(0, 2), 1e-3);
- assertEquals(0.0, mat.get(1, 0), 1e-3);
- assertEquals(1.0 / 4.0, mat.get(1, 1), 1e-3);
- assertEquals(0.0, mat.get(1, 2), 1e-3);
- assertEquals(0.0, mat.get(0, 2), 1e-3);
- assertEquals(0.0, mat.get(1, 2), 1e-3);
- assertEquals(1.0 / 9.0, mat.get(2, 2), 1e-3);
- }
-
- @Test
- public void testCovArray() {
- var mat = StateSpaceUtil.makeCovarianceMatrix(Nat.N3(),
- VecBuilder.fill(1.0, 2.0, 3.0));
-
- assertEquals(1.0, mat.get(0, 0), 1e-3);
- assertEquals(0.0, mat.get(0, 1), 1e-3);
- assertEquals(0.0, mat.get(0, 2), 1e-3);
- assertEquals(0.0, mat.get(1, 0), 1e-3);
- assertEquals(4.0, mat.get(1, 1), 1e-3);
- assertEquals(0.0, mat.get(1, 2), 1e-3);
- assertEquals(0.0, mat.get(0, 2), 1e-3);
- assertEquals(0.0, mat.get(1, 2), 1e-3);
- assertEquals(9.0, mat.get(2, 2), 1e-3);
- }
-
- @Test
- @SuppressWarnings("LocalVariableName")
- public void testIsStabilizable() {
- Matrix<N2, N2> A;
- Matrix<N2, N1> B = VecBuilder.fill(0, 1);
-
- // First eigenvalue is uncontrollable and unstable.
- // Second eigenvalue is controllable and stable.
- A = Matrix.mat(Nat.N2(), Nat.N2()).fill(1.2, 0, 0, 0.5);
- assertFalse(StateSpaceUtil.isStabilizable(A, B));
-
- // First eigenvalue is uncontrollable and marginally stable.
- // Second eigenvalue is controllable and stable.
- A = Matrix.mat(Nat.N2(), Nat.N2()).fill(1, 0, 0, 0.5);
- assertFalse(StateSpaceUtil.isStabilizable(A, B));
-
- // First eigenvalue is uncontrollable and stable.
- // Second eigenvalue is controllable and stable.
- A = Matrix.mat(Nat.N2(), Nat.N2()).fill(0.2, 0, 0, 0.5);
- assertTrue(StateSpaceUtil.isStabilizable(A, B));
-
- // First eigenvalue is uncontrollable and stable.
- // Second eigenvalue is controllable and unstable.
- A = Matrix.mat(Nat.N2(), Nat.N2()).fill(0.2, 0, 0, 1.2);
- assertTrue(StateSpaceUtil.isStabilizable(A, B));
- }
-
- @Test
- public void testMakeWhiteNoiseVector() {
- var firstData = new ArrayList<Double>();
- var secondData = new ArrayList<Double>();
- for (int i = 0; i < 1000; i++) {
- var noiseVec = StateSpaceUtil.makeWhiteNoiseVector(VecBuilder.fill(1.0, 2.0));
- firstData.add(noiseVec.get(0, 0));
- secondData.add(noiseVec.get(1, 0));
- }
- assertEquals(1.0, calculateStandardDeviation(firstData), 0.2);
- assertEquals(2.0, calculateStandardDeviation(secondData), 0.2);
- }
-
- private double calculateStandardDeviation(List<Double> numArray) {
- double sum = 0.0;
- double standardDeviation = 0.0;
- int length = numArray.size();
-
- for (double num : numArray) {
- sum += num;
- }
-
- double mean = sum / length;
-
- for (double num : numArray) {
- standardDeviation += Math.pow(num - mean, 2);
- }
-
- return Math.sqrt(standardDeviation / length);
- }
-
- @Test
- public void testDiscretizeA() {
- var contA = Matrix.mat(Nat.N2(), Nat.N2()).fill(0, 1, 0, 0);
- var x0 = VecBuilder.fill(1, 1);
- var discA = Discretization.discretizeA(contA, 1.0);
- var x1Discrete = discA.times(x0);
-
- // We now have pos = vel = 1 and accel = 0, which should give us:
- var x1Truth = VecBuilder.fill(x0.get(0, 0) + 1.0 * x0.get(1, 0),
- x0.get(1, 0));
- assertTrue(x1Truth.isEqual(x1Discrete, 1E-4));
- }
-
- @SuppressWarnings("LocalVariableName")
- @Test
- public void testDiscretizeAB() {
- var contA = Matrix.mat(Nat.N2(), Nat.N2()).fill(0, 1, 0, 0);
- var contB = VecBuilder.fill(0, 1);
- var x0 = VecBuilder.fill(1, 1);
- var u = VecBuilder.fill(1);
-
- var abPair = Discretization.discretizeAB(contA, contB, 1.0);
-
- var x1Discrete = abPair.getFirst().times(x0).plus(abPair.getSecond().times(u));
-
- // We now have pos = vel = accel = 1, which should give us:
- var x1Truth = VecBuilder.fill(x0.get(0, 0) + x0.get(1, 0) + 0.5 * u.get(0, 0), x0.get(0, 0)
- + u.get(0, 0));
-
- assertTrue(x1Truth.isEqual(x1Discrete, 1E-4));
- }
-
- @Test
- public void testMatrixExp() {
- Matrix<N2, N2> wrappedMatrix = Matrix.eye(Nat.N2());
- var wrappedResult = wrappedMatrix.exp();
-
- assertTrue(wrappedResult.isEqual(
- Matrix.mat(Nat.N2(), Nat.N2()).fill(Math.E, 0, 0, Math.E), 1E-9));
-
- var matrix = Matrix.mat(Nat.N2(), Nat.N2()).fill(1, 2, 3, 4);
- wrappedResult = matrix.times(0.01).exp();
-
- assertTrue(wrappedResult.isEqual(Matrix.mat(Nat.N2(), Nat.N2()).fill(1.01035625, 0.02050912,
- 0.03076368, 1.04111993), 1E-8));
- }
-
- @Test
- public void testSimpleMatrixExp() {
- SimpleMatrix matrix = SimpleMatrixUtils.eye(2);
- var result = SimpleMatrixUtils.exp(matrix);
-
- assertTrue(MatrixFeatures_DDRM.isIdentical(
- result.getDDRM(),
- new SimpleMatrix(2, 2, true, new double[]{Math.E, 0, 0, Math.E}).getDDRM(),
- 1E-9
- ));
-
- matrix = new SimpleMatrix(2, 2, true, new double[]{1, 2, 3, 4});
- result = SimpleMatrixUtils.exp(matrix.scale(0.01));
-
- assertTrue(MatrixFeatures_DDRM.isIdentical(
- result.getDDRM(),
- new SimpleMatrix(2, 2, true, new double[]{1.01035625, 0.02050912,
- 0.03076368, 1.04111993}).getDDRM(),
- 1E-8
- ));
- }
-
- @Test
- public void testPoseToVector() {
- Pose2d pose = new Pose2d(1, 2, new Rotation2d(3));
- var vector = StateSpaceUtil.poseToVector(pose);
- assertEquals(pose.getTranslation().getX(), vector.get(0, 0), 1e-6);
- assertEquals(pose.getTranslation().getY(), vector.get(1, 0), 1e-6);
- assertEquals(pose.getRotation().getRadians(), vector.get(2, 0), 1e-6);
- }
-
-}
diff --git a/wpimath/src/test/java/edu/wpi/first/wpilibj/spline/CubicHermiteSplineTest.java b/wpimath/src/test/java/edu/wpi/first/wpilibj/spline/CubicHermiteSplineTest.java
deleted file mode 100644
index 710bab5..0000000
--- a/wpimath/src/test/java/edu/wpi/first/wpilibj/spline/CubicHermiteSplineTest.java
+++ /dev/null
@@ -1,162 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.spline;
-
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.List;
-
-import org.junit.jupiter.api.Test;
-
-import edu.wpi.first.wpilibj.geometry.Pose2d;
-import edu.wpi.first.wpilibj.geometry.Rotation2d;
-import edu.wpi.first.wpilibj.geometry.Translation2d;
-import edu.wpi.first.wpilibj.spline.SplineParameterizer.MalformedSplineException;
-
-import static org.junit.jupiter.api.Assertions.assertAll;
-import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertThrows;
-import static org.junit.jupiter.api.Assertions.assertTrue;
-
-
-class CubicHermiteSplineTest {
- private static final double kMaxDx = 0.127;
- private static final double kMaxDy = 0.00127;
- private static final double kMaxDtheta = 0.0872;
-
- @SuppressWarnings({"ParameterName", "PMD.UnusedLocalVariable"})
- private void run(Pose2d a, List<Translation2d> waypoints, Pose2d b) {
- // Start the timer.
- //var start = System.nanoTime();
-
- // Generate and parameterize the spline.
- var controlVectors =
- SplineHelper.getCubicControlVectorsFromWaypoints(a,
- waypoints.toArray(new Translation2d[0]), b);
- var splines
- = SplineHelper.getCubicSplinesFromControlVectors(
- controlVectors[0], waypoints.toArray(new Translation2d[0]), controlVectors[1]);
-
- var poses = new ArrayList<PoseWithCurvature>();
-
- poses.add(splines[0].getPoint(0.0));
-
- for (var spline : splines) {
- poses.addAll(SplineParameterizer.parameterize(spline));
- }
-
- // End the timer.
- //var end = System.nanoTime();
-
- // Calculate the duration (used when benchmarking)
- //var durationMicroseconds = (end - start) / 1000.0;
-
- for (int i = 0; i < poses.size() - 1; i++) {
- var p0 = poses.get(i);
- var p1 = poses.get(i + 1);
-
- // Make sure the twist is under the tolerance defined by the Spline class.
- var twist = p0.poseMeters.log(p1.poseMeters);
- assertAll(
- () -> assertTrue(Math.abs(twist.dx) < kMaxDx),
- () -> assertTrue(Math.abs(twist.dy) < kMaxDy),
- () -> assertTrue(Math.abs(twist.dtheta) < kMaxDtheta)
- );
- }
-
- // Check first point
- assertAll(
- () -> assertEquals(a.getX(),
- poses.get(0).poseMeters.getX(), 1E-9),
- () -> assertEquals(a.getY(),
- poses.get(0).poseMeters.getY(), 1E-9),
- () -> assertEquals(a.getRotation().getRadians(),
- poses.get(0).poseMeters.getRotation().getRadians(), 1E-9)
- );
-
- // Check interior waypoints
- boolean interiorsGood = true;
- for (var waypoint : waypoints) {
- boolean found = false;
- for (var state : poses) {
- if (waypoint.getDistance(state.poseMeters.getTranslation()) == 0) {
- found = true;
- }
- }
- interiorsGood &= found;
- }
-
- assertTrue(interiorsGood);
-
- // Check last point
- assertAll(
- () -> assertEquals(b.getX(),
- poses.get(poses.size() - 1).poseMeters.getX(), 1E-9),
- () -> assertEquals(b.getY(),
- poses.get(poses.size() - 1).poseMeters.getY(), 1E-9),
- () -> assertEquals(b.getRotation().getRadians(),
- poses.get(poses.size() - 1).poseMeters.getRotation().getRadians(), 1E-9)
- );
- }
-
- @SuppressWarnings("PMD.JUnitTestsShouldIncludeAssert")
- @Test
- void testStraightLine() {
- run(new Pose2d(), new ArrayList<>(), new Pose2d(3, 0, new Rotation2d()));
- }
-
- @SuppressWarnings("PMD.JUnitTestsShouldIncludeAssert")
- @Test
- void testSCurve() {
- var start = new Pose2d(0, 0, Rotation2d.fromDegrees(90.0));
- ArrayList<Translation2d> waypoints = new ArrayList<>();
- waypoints.add(new Translation2d(1, 1));
- waypoints.add(new Translation2d(2, -1));
- var end = new Pose2d(3, 0, Rotation2d.fromDegrees(90.0));
-
- run(start, waypoints, end);
- }
-
- @SuppressWarnings("PMD.JUnitTestsShouldIncludeAssert")
- @Test
- void testOneInterior() {
- var start = new Pose2d(0, 0, Rotation2d.fromDegrees(0.0));
- ArrayList<Translation2d> waypoints = new ArrayList<>();
- waypoints.add(new Translation2d(2.0, 0.0));
- var end = new Pose2d(4, 0, Rotation2d.fromDegrees(0.0));
-
- run(start, waypoints, end);
- }
-
- @SuppressWarnings("PMD.JUnitTestsShouldIncludeAssert")
- @Test
- void testWindyPath() {
- final var start = new Pose2d(0, 0, Rotation2d.fromDegrees(0.0));
- final ArrayList<Translation2d> waypoints = new ArrayList<>();
- waypoints.add(new Translation2d(0.5, 0.5));
- waypoints.add(new Translation2d(0.5, 0.5));
- waypoints.add(new Translation2d(1.0, 0.0));
- waypoints.add(new Translation2d(1.5, 0.5));
- waypoints.add(new Translation2d(2.0, 0.0));
- waypoints.add(new Translation2d(2.5, 0.5));
- final var end = new Pose2d(3.0, 0.0, Rotation2d.fromDegrees(0.0));
-
- run(start, waypoints, end);
- }
-
- @Test
- void testMalformed() {
- assertThrows(MalformedSplineException.class, () -> run(
- new Pose2d(0, 0, Rotation2d.fromDegrees(0)),
- new ArrayList<>(), new Pose2d(1, 0, Rotation2d.fromDegrees(180))));
- assertThrows(MalformedSplineException.class, () -> run(
- new Pose2d(10, 10, Rotation2d.fromDegrees(90)),
- Arrays.asList(new Translation2d(10, 10.5)),
- new Pose2d(10, 11, Rotation2d.fromDegrees(-90))));
- }
-}
diff --git a/wpimath/src/test/java/edu/wpi/first/wpilibj/spline/QuinticHermiteSplineTest.java b/wpimath/src/test/java/edu/wpi/first/wpilibj/spline/QuinticHermiteSplineTest.java
deleted file mode 100644
index 3b35db0..0000000
--- a/wpimath/src/test/java/edu/wpi/first/wpilibj/spline/QuinticHermiteSplineTest.java
+++ /dev/null
@@ -1,106 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.spline;
-
-import java.util.List;
-
-import org.junit.jupiter.api.Test;
-
-import edu.wpi.first.wpilibj.geometry.Pose2d;
-import edu.wpi.first.wpilibj.geometry.Rotation2d;
-import edu.wpi.first.wpilibj.spline.SplineParameterizer.MalformedSplineException;
-
-import static org.junit.jupiter.api.Assertions.assertAll;
-import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertThrows;
-import static org.junit.jupiter.api.Assertions.assertTrue;
-
-class QuinticHermiteSplineTest {
- private static final double kMaxDx = 0.127;
- private static final double kMaxDy = 0.00127;
- private static final double kMaxDtheta = 0.0872;
-
- @SuppressWarnings({ "ParameterName", "PMD.UnusedLocalVariable" })
- private void run(Pose2d a, Pose2d b) {
- // Start the timer.
- //var start = System.nanoTime();
-
- // Generate and parameterize the spline.
- var spline = SplineHelper.getQuinticSplinesFromWaypoints(List.of(a, b))[0];
- var poses = SplineParameterizer.parameterize(spline);
-
- // End the timer.
- //var end = System.nanoTime();
-
- // Calculate the duration (used when benchmarking)
- //var durationMicroseconds = (end - start) / 1000.0;
-
- for (int i = 0; i < poses.size() - 1; i++) {
- var p0 = poses.get(i);
- var p1 = poses.get(i + 1);
-
- // Make sure the twist is under the tolerance defined by the Spline class.
- var twist = p0.poseMeters.log(p1.poseMeters);
- assertAll(
- () -> assertTrue(Math.abs(twist.dx) < kMaxDx),
- () -> assertTrue(Math.abs(twist.dy) < kMaxDy),
- () -> assertTrue(Math.abs(twist.dtheta) < kMaxDtheta));
- }
-
- // Check first point
- assertAll(
- () -> assertEquals(
- a.getX(), poses.get(0).poseMeters.getX(), 1E-9),
- () -> assertEquals(
- a.getY(), poses.get(0).poseMeters.getY(), 1E-9),
- () -> assertEquals(
- a.getRotation().getRadians(), poses.get(0).poseMeters.getRotation().getRadians(),
- 1E-9));
-
- // Check last point
- assertAll(
- () -> assertEquals(b.getX(), poses.get(poses.size() - 1)
- .poseMeters.getX(), 1E-9),
- () -> assertEquals(b.getY(), poses.get(poses.size() - 1)
- .poseMeters.getY(), 1E-9),
- () -> assertEquals(b.getRotation().getRadians(),
- poses.get(poses.size() - 1).poseMeters.getRotation().getRadians(), 1E-9));
- }
-
- @SuppressWarnings("PMD.JUnitTestsShouldIncludeAssert")
- @Test
- void testStraightLine() {
- run(new Pose2d(), new Pose2d(3, 0, new Rotation2d()));
- }
-
- @SuppressWarnings("PMD.JUnitTestsShouldIncludeAssert")
- @Test
- void testSimpleSCurve() {
- run(new Pose2d(), new Pose2d(1, 1, new Rotation2d()));
- }
-
- @SuppressWarnings("PMD.JUnitTestsShouldIncludeAssert")
- @Test
- void testSquiggly() {
- run(
- new Pose2d(0, 0, Rotation2d.fromDegrees(90)),
- new Pose2d(-1, 0, Rotation2d.fromDegrees(90)));
- }
-
- @Test
- void testMalformed() {
- assertThrows(MalformedSplineException.class,
- () -> run(
- new Pose2d(0, 0, Rotation2d.fromDegrees(0)),
- new Pose2d(1, 0, Rotation2d.fromDegrees(180))));
- assertThrows(MalformedSplineException.class,
- () -> run(
- new Pose2d(10, 10, Rotation2d.fromDegrees(90)),
- new Pose2d(10, 11, Rotation2d.fromDegrees(-90))));
- }
-}
diff --git a/wpimath/src/test/java/edu/wpi/first/wpilibj/system/LinearSystemIDTest.java b/wpimath/src/test/java/edu/wpi/first/wpilibj/system/LinearSystemIDTest.java
deleted file mode 100644
index 8af7ef0..0000000
--- a/wpimath/src/test/java/edu/wpi/first/wpilibj/system/LinearSystemIDTest.java
+++ /dev/null
@@ -1,91 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.system;
-
-import org.junit.jupiter.api.Test;
-
-import edu.wpi.first.wpilibj.system.plant.DCMotor;
-import edu.wpi.first.wpilibj.system.plant.LinearSystemId;
-import edu.wpi.first.wpiutil.math.Matrix;
-import edu.wpi.first.wpiutil.math.Nat;
-import edu.wpi.first.wpiutil.math.VecBuilder;
-
-import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertTrue;
-
-class LinearSystemIDTest {
- @Test
- public void testDrivetrainVelocitySystem() {
- var model = LinearSystemId.createDrivetrainVelocitySystem(
- DCMotor.getNEO(4), 70, 0.05, 0.4, 6.0, 6
- );
- assertTrue(model.getA().isEqual(Matrix.mat(Nat.N2(),
- Nat.N2()).fill(-10.14132, 3.06598, 3.06598, -10.14132), 0.001));
-
- assertTrue(model.getB().isEqual(Matrix.mat(Nat.N2(),
- Nat.N2()).fill(4.2590, -1.28762, -1.2876, 4.2590), 0.001));
-
- assertTrue(model.getC().isEqual(Matrix.mat(Nat.N2(),
- Nat.N2()).fill(1.0, 0.0, 0.0, 1.0), 0.001));
-
- assertTrue(model.getD().isEqual(Matrix.mat(Nat.N2(),
- Nat.N2()).fill(0.0, 0.0, 0.0, 0.0), 0.001));
- }
-
- @Test
- public void testElevatorSystem() {
-
- var model = LinearSystemId.createElevatorSystem(DCMotor.getNEO(2), 5, 0.05, 12);
- assertTrue(model.getA().isEqual(Matrix.mat(Nat.N2(),
- Nat.N2()).fill(0, 1, 0, -99.05473), 0.001));
-
- assertTrue(model.getB().isEqual(VecBuilder.fill(0, 20.8), 0.001));
-
- assertTrue(model.getC().isEqual(Matrix.mat(Nat.N1(),
- Nat.N2()).fill(1, 0), 0.001));
-
- assertTrue(model.getD().isEqual(VecBuilder.fill(0), 0.001));
- }
-
- @Test
- public void testFlywheelSystem() {
- var model = LinearSystemId.createFlywheelSystem(DCMotor.getNEO(2), 0.00032, 1.0);
- assertTrue(model.getA().isEqual(VecBuilder.fill(-26.87032), 0.001));
-
- assertTrue(model.getB().isEqual(VecBuilder.fill(1354.166667), 0.001));
-
- assertTrue(model.getC().isEqual(VecBuilder.fill(1), 0.001));
-
- assertTrue(model.getD().isEqual(VecBuilder.fill(0), 0.001));
- }
-
- @Test
- public void testIdentifyPositionSystem() {
- // By controls engineering in frc,
- // x-dot = [0 1 | 0 -kv/ka] x = [0 | 1/ka] u
- var kv = 1.0;
- var ka = 0.5;
- var model = LinearSystemId.identifyPositionSystem(kv, ka);
-
- assertEquals(model.getA(), Matrix.mat(Nat.N2(), Nat.N2()).fill(0, 1, 0, -kv / ka));
- assertEquals(model.getB(), VecBuilder.fill(0, 1 / ka));
- }
-
- @Test
- public void testIdentifyVelocitySystem() {
- // By controls engineering in frc,
- // V = kv * velocity + ka * acceleration
- // x-dot = -kv/ka * v + 1/ka \cdot V
- var kv = 1.0;
- var ka = 0.5;
- var model = LinearSystemId.identifyVelocitySystem(kv, ka);
-
- assertEquals(model.getA(), VecBuilder.fill(-kv / ka));
- assertEquals(model.getB(), VecBuilder.fill(1 / ka));
- }
-}
diff --git a/wpimath/src/test/java/edu/wpi/first/wpilibj/system/RungeKuttaTest.java b/wpimath/src/test/java/edu/wpi/first/wpilibj/system/RungeKuttaTest.java
deleted file mode 100644
index de39295..0000000
--- a/wpimath/src/test/java/edu/wpi/first/wpilibj/system/RungeKuttaTest.java
+++ /dev/null
@@ -1,36 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.system;
-
-import org.junit.jupiter.api.Test;
-
-import edu.wpi.first.wpiutil.math.Matrix;
-import edu.wpi.first.wpiutil.math.Nat;
-import edu.wpi.first.wpiutil.math.VecBuilder;
-import edu.wpi.first.wpiutil.math.numbers.N1;
-
-import static org.junit.jupiter.api.Assertions.assertEquals;
-
-public class RungeKuttaTest {
- @Test
- @SuppressWarnings({"ParameterName", "LocalVariableName"})
- public void testExponential() {
-
- Matrix<N1, N1> y0 = VecBuilder.fill(0.0);
-
- //noinspection SuspiciousNameCombination
- var y1 = RungeKutta.rungeKutta((Matrix<N1, N1> x) -> {
- var y = new Matrix<>(Nat.N1(), Nat.N1());
- y.set(0, 0, Math.exp(x.get(0, 0)));
- return y; },
- y0, 0.1
- );
-
- assertEquals(Math.exp(0.1) - Math.exp(0.0), y1.get(0, 0), 1e-3);
- }
-}
diff --git a/wpimath/src/test/java/edu/wpi/first/wpilibj/trajectory/CentripetalAccelerationConstraintTest.java b/wpimath/src/test/java/edu/wpi/first/wpilibj/trajectory/CentripetalAccelerationConstraintTest.java
deleted file mode 100644
index 0c2f4f1..0000000
--- a/wpimath/src/test/java/edu/wpi/first/wpilibj/trajectory/CentripetalAccelerationConstraintTest.java
+++ /dev/null
@@ -1,43 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.trajectory;
-
-import java.util.Collections;
-
-import org.junit.jupiter.api.Test;
-
-import edu.wpi.first.wpilibj.trajectory.constraint.CentripetalAccelerationConstraint;
-import edu.wpi.first.wpilibj.util.Units;
-
-import static org.junit.jupiter.api.Assertions.assertTrue;
-
-class CentripetalAccelerationConstraintTest {
- @SuppressWarnings("LocalVariableName")
- @Test
- void testCentripetalAccelerationConstraint() {
- double maxCentripetalAcceleration = Units.feetToMeters(7.0); // 7 feet per second squared
- var constraint = new CentripetalAccelerationConstraint(maxCentripetalAcceleration);
-
- Trajectory trajectory = TrajectoryGeneratorTest.getTrajectory(
- Collections.singletonList(constraint));
-
- var duration = trajectory.getTotalTimeSeconds();
- var t = 0.0;
- var dt = 0.02;
-
- while (t < duration) {
- var point = trajectory.sample(t);
- var centripetalAcceleration
- = Math.pow(point.velocityMetersPerSecond, 2) * point.curvatureRadPerMeter;
-
- t += dt;
- assertTrue(centripetalAcceleration <= maxCentripetalAcceleration + 0.05);
- }
- }
-
-}
diff --git a/wpimath/src/test/java/edu/wpi/first/wpilibj/trajectory/DifferentialDriveKinematicsConstraintTest.java b/wpimath/src/test/java/edu/wpi/first/wpilibj/trajectory/DifferentialDriveKinematicsConstraintTest.java
deleted file mode 100644
index 6ed9966..0000000
--- a/wpimath/src/test/java/edu/wpi/first/wpilibj/trajectory/DifferentialDriveKinematicsConstraintTest.java
+++ /dev/null
@@ -1,53 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.trajectory;
-
-import java.util.Collections;
-
-import org.junit.jupiter.api.Test;
-
-import edu.wpi.first.wpilibj.kinematics.ChassisSpeeds;
-import edu.wpi.first.wpilibj.kinematics.DifferentialDriveKinematics;
-import edu.wpi.first.wpilibj.trajectory.constraint.DifferentialDriveKinematicsConstraint;
-import edu.wpi.first.wpilibj.util.Units;
-
-import static org.junit.jupiter.api.Assertions.assertAll;
-import static org.junit.jupiter.api.Assertions.assertTrue;
-
-class DifferentialDriveKinematicsConstraintTest {
- @SuppressWarnings({"LocalVariableName", "PMD.AvoidInstantiatingObjectsInLoops"})
- @Test
- void testDifferentialDriveKinematicsConstraint() {
- double maxVelocity = Units.feetToMeters(12.0); // 12 feet per second
- var kinematics = new DifferentialDriveKinematics(Units.inchesToMeters(27));
- var constraint = new DifferentialDriveKinematicsConstraint(kinematics, maxVelocity);
-
- Trajectory trajectory = TrajectoryGeneratorTest.getTrajectory(
- Collections.singletonList(constraint));
-
- var duration = trajectory.getTotalTimeSeconds();
- var t = 0.0;
- var dt = 0.02;
-
- while (t < duration) {
- var point = trajectory.sample(t);
- var chassisSpeeds = new ChassisSpeeds(
- point.velocityMetersPerSecond, 0,
- point.velocityMetersPerSecond * point.curvatureRadPerMeter
- );
-
- var wheelSpeeds = kinematics.toWheelSpeeds(chassisSpeeds);
-
- t += dt;
- assertAll(
- () -> assertTrue(wheelSpeeds.leftMetersPerSecond <= maxVelocity + 0.05),
- () -> assertTrue(wheelSpeeds.rightMetersPerSecond <= maxVelocity + 0.05)
- );
- }
- }
-}
diff --git a/wpimath/src/test/java/edu/wpi/first/wpilibj/trajectory/DifferentialDriveVoltageConstraintTest.java b/wpimath/src/test/java/edu/wpi/first/wpilibj/trajectory/DifferentialDriveVoltageConstraintTest.java
deleted file mode 100644
index fff4c61..0000000
--- a/wpimath/src/test/java/edu/wpi/first/wpilibj/trajectory/DifferentialDriveVoltageConstraintTest.java
+++ /dev/null
@@ -1,103 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.trajectory;
-
-import java.util.ArrayList;
-import java.util.Collections;
-
-import org.junit.jupiter.api.Test;
-
-import edu.wpi.first.wpilibj.controller.SimpleMotorFeedforward;
-import edu.wpi.first.wpilibj.geometry.Pose2d;
-import edu.wpi.first.wpilibj.geometry.Rotation2d;
-import edu.wpi.first.wpilibj.geometry.Translation2d;
-import edu.wpi.first.wpilibj.kinematics.ChassisSpeeds;
-import edu.wpi.first.wpilibj.kinematics.DifferentialDriveKinematics;
-import edu.wpi.first.wpilibj.trajectory.constraint.DifferentialDriveVoltageConstraint;
-
-import static org.junit.jupiter.api.Assertions.assertAll;
-import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
-import static org.junit.jupiter.api.Assertions.assertTrue;
-
-class DifferentialDriveVoltageConstraintTest {
- @SuppressWarnings({"LocalVariableName", "PMD.AvoidInstantiatingObjectsInLoops"})
- @Test
- void testDifferentialDriveVoltageConstraint() {
- // Pick an unreasonably large kA to ensure the constraint has to do some work
- var feedforward = new SimpleMotorFeedforward(1, 1, 3);
- var kinematics = new DifferentialDriveKinematics(0.5);
- double maxVoltage = 10;
- var constraint = new DifferentialDriveVoltageConstraint(feedforward,
- kinematics,
- maxVoltage);
-
- Trajectory trajectory = TrajectoryGeneratorTest.getTrajectory(
- Collections.singletonList(constraint));
-
- var duration = trajectory.getTotalTimeSeconds();
- var t = 0.0;
- var dt = 0.02;
-
- while (t < duration) {
- var point = trajectory.sample(t);
- var chassisSpeeds = new ChassisSpeeds(
- point.velocityMetersPerSecond, 0,
- point.velocityMetersPerSecond * point.curvatureRadPerMeter
- );
-
- var wheelSpeeds = kinematics.toWheelSpeeds(chassisSpeeds);
-
- t += dt;
-
- // Not really a strictly-correct test as we're using the chassis accel instead of the
- // wheel accel, but much easier than doing it "properly" and a reasonable check anyway
- assertAll(
- () -> assertTrue(feedforward.calculate(wheelSpeeds.leftMetersPerSecond,
- point.accelerationMetersPerSecondSq)
- <= maxVoltage + 0.05),
- () -> assertTrue(feedforward.calculate(wheelSpeeds.leftMetersPerSecond,
- point.accelerationMetersPerSecondSq)
- >= -maxVoltage - 0.05),
- () -> assertTrue(feedforward.calculate(wheelSpeeds.rightMetersPerSecond,
- point.accelerationMetersPerSecondSq)
- <= maxVoltage + 0.05),
- () -> assertTrue(feedforward.calculate(wheelSpeeds.rightMetersPerSecond,
- point.accelerationMetersPerSecondSq)
- >= -maxVoltage - 0.05)
- );
- }
- }
-
- @Test
- void testEndpointHighCurvature() {
- var feedforward = new SimpleMotorFeedforward(1, 1, 3);
-
- // Large trackwidth - need to test with radius of curvature less than half of trackwidth
- var kinematics = new DifferentialDriveKinematics(3);
- double maxVoltage = 10;
- var constraint = new DifferentialDriveVoltageConstraint(feedforward,
- kinematics,
- maxVoltage);
-
- var config = new TrajectoryConfig(12, 12).addConstraint(constraint);
-
- // Radius of curvature should be ~1 meter.
- assertDoesNotThrow(() -> TrajectoryGenerator.generateTrajectory(
- new Pose2d(1, 0, Rotation2d.fromDegrees(90)),
- new ArrayList<Translation2d>(),
- new Pose2d(0, 1, Rotation2d.fromDegrees(180)),
- config));
-
- assertDoesNotThrow(() -> TrajectoryGenerator.generateTrajectory(
- new Pose2d(0, 1, Rotation2d.fromDegrees(180)),
- new ArrayList<Translation2d>(),
- new Pose2d(1, 0, Rotation2d.fromDegrees(90)),
- config.setReversed(true)));
-
- }
-}
diff --git a/wpimath/src/test/java/edu/wpi/first/wpilibj/trajectory/EllipticalRegionConstraintTest.java b/wpimath/src/test/java/edu/wpi/first/wpilibj/trajectory/EllipticalRegionConstraintTest.java
deleted file mode 100644
index 513db06..0000000
--- a/wpimath/src/test/java/edu/wpi/first/wpilibj/trajectory/EllipticalRegionConstraintTest.java
+++ /dev/null
@@ -1,77 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.trajectory;
-
-import java.util.List;
-
-import org.junit.jupiter.api.Test;
-
-import edu.wpi.first.wpilibj.geometry.Pose2d;
-import edu.wpi.first.wpilibj.geometry.Rotation2d;
-import edu.wpi.first.wpilibj.geometry.Translation2d;
-import edu.wpi.first.wpilibj.trajectory.constraint.EllipticalRegionConstraint;
-import edu.wpi.first.wpilibj.trajectory.constraint.MaxVelocityConstraint;
-import edu.wpi.first.wpilibj.util.Units;
-
-import static org.junit.jupiter.api.Assertions.assertFalse;
-import static org.junit.jupiter.api.Assertions.assertTrue;
-
-public class EllipticalRegionConstraintTest {
- @Test
- void testConstraint() {
- // Create constraints
- double maxVelocity = Units.feetToMeters(3.0);
- var maxVelocityConstraint = new MaxVelocityConstraint(maxVelocity);
- var regionConstraint = new EllipticalRegionConstraint(
- new Translation2d(Units.feetToMeters(5.0), Units.feetToMeters(5.0)),
- Units.feetToMeters(10.0), Units.feetToMeters(5.0), Rotation2d.fromDegrees(180.0),
- maxVelocityConstraint
- );
-
- // Get trajectory
- var trajectory = TrajectoryGeneratorTest.getTrajectory(List.of(regionConstraint));
-
- // Iterate through trajectory and check constraints
- boolean exceededConstraintOutsideRegion = false;
- for (var point : trajectory.getStates()) {
- var translation = point.poseMeters.getTranslation();
-
- if (translation.getX() < Units.feetToMeters(10)
- && translation.getY() < Units.feetToMeters(5)) {
- assertTrue(Math.abs(point.velocityMetersPerSecond) < maxVelocity + 0.05);
- } else if (Math.abs(point.velocityMetersPerSecond) >= maxVelocity + 0.05) {
- exceededConstraintOutsideRegion = true;
- }
- }
- assertTrue(exceededConstraintOutsideRegion);
- }
-
- @Test
- void testIsPoseWithinRegion() {
- double maxVelocity = Units.feetToMeters(3.0);
- var maxVelocityConstraint = new MaxVelocityConstraint(maxVelocity);
-
- var regionConstraintNoRotation = new EllipticalRegionConstraint(
- new Translation2d(Units.feetToMeters(1.0), Units.feetToMeters(1.0)),
- Units.feetToMeters(2.0), Units.feetToMeters(4.0), new Rotation2d(),
- maxVelocityConstraint);
-
- assertFalse(regionConstraintNoRotation.isPoseInRegion(new Pose2d(
- Units.feetToMeters(2.1), Units.feetToMeters(1.0), new Rotation2d()
- )));
-
- var regionConstraintWithRotation = new EllipticalRegionConstraint(
- new Translation2d(Units.feetToMeters(1.0), Units.feetToMeters(1.0)),
- Units.feetToMeters(2.0), Units.feetToMeters(4.0), Rotation2d.fromDegrees(90.0),
- maxVelocityConstraint);
-
- assertTrue(regionConstraintWithRotation.isPoseInRegion(new Pose2d(
- Units.feetToMeters(2.1), Units.feetToMeters(1.0), new Rotation2d()
- )));
- }
-}
diff --git a/wpimath/src/test/java/edu/wpi/first/wpilibj/trajectory/RectangularRegionConstraintTest.java b/wpimath/src/test/java/edu/wpi/first/wpilibj/trajectory/RectangularRegionConstraintTest.java
deleted file mode 100644
index 94eeb35..0000000
--- a/wpimath/src/test/java/edu/wpi/first/wpilibj/trajectory/RectangularRegionConstraintTest.java
+++ /dev/null
@@ -1,65 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.trajectory;
-
-import java.util.List;
-
-import org.junit.jupiter.api.Test;
-
-import edu.wpi.first.wpilibj.geometry.Pose2d;
-import edu.wpi.first.wpilibj.geometry.Rotation2d;
-import edu.wpi.first.wpilibj.geometry.Translation2d;
-import edu.wpi.first.wpilibj.trajectory.constraint.MaxVelocityConstraint;
-import edu.wpi.first.wpilibj.trajectory.constraint.RectangularRegionConstraint;
-import edu.wpi.first.wpilibj.util.Units;
-
-import static org.junit.jupiter.api.Assertions.assertFalse;
-import static org.junit.jupiter.api.Assertions.assertTrue;
-
-public class RectangularRegionConstraintTest {
- @Test
- void testConstraint() {
- // Create constraints
- double maxVelocity = Units.feetToMeters(3.0);
- var maxVelocityConstraint = new MaxVelocityConstraint(maxVelocity);
- var regionConstraint = new RectangularRegionConstraint(
- new Translation2d(Units.feetToMeters(1.0), Units.feetToMeters(1.0)),
- new Translation2d(Units.feetToMeters(7.0), Units.feetToMeters(27.0)),
- maxVelocityConstraint
- );
-
- // Get trajectory
- var trajectory = TrajectoryGeneratorTest.getTrajectory(List.of(regionConstraint));
-
- // Iterate through trajectory and check constraints
- boolean exceededConstraintOutsideRegion = false;
- for (var point : trajectory.getStates()) {
- if (regionConstraint.isPoseInRegion(point.poseMeters)) {
- assertTrue(Math.abs(point.velocityMetersPerSecond) < maxVelocity + 0.05);
- } else if (Math.abs(point.velocityMetersPerSecond) >= maxVelocity + 0.05) {
- exceededConstraintOutsideRegion = true;
- }
- }
- assertTrue(exceededConstraintOutsideRegion);
- }
-
- @Test
- void testIsPoseWithinRegion() {
- double maxVelocity = Units.feetToMeters(3.0);
- var maxVelocityConstraint = new MaxVelocityConstraint(maxVelocity);
- var regionConstraint = new RectangularRegionConstraint(
- new Translation2d(Units.feetToMeters(1.0), Units.feetToMeters(1.0)),
- new Translation2d(Units.feetToMeters(7.0), Units.feetToMeters(27.0)),
- maxVelocityConstraint
- );
-
- assertFalse(regionConstraint.isPoseInRegion(new Pose2d()));
- assertTrue(regionConstraint.isPoseInRegion(new Pose2d(Units.feetToMeters(3.0),
- Units.feetToMeters(14.5), new Rotation2d())));
- }
-}
diff --git a/wpimath/src/test/java/edu/wpi/first/wpilibj/trajectory/TrajectoryGeneratorTest.java b/wpimath/src/test/java/edu/wpi/first/wpilibj/trajectory/TrajectoryGeneratorTest.java
deleted file mode 100644
index 2bfa972..0000000
--- a/wpimath/src/test/java/edu/wpi/first/wpilibj/trajectory/TrajectoryGeneratorTest.java
+++ /dev/null
@@ -1,89 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.trajectory;
-
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.List;
-
-import org.junit.jupiter.api.Test;
-
-import edu.wpi.first.wpilibj.geometry.Pose2d;
-import edu.wpi.first.wpilibj.geometry.Rotation2d;
-import edu.wpi.first.wpilibj.geometry.Transform2d;
-import edu.wpi.first.wpilibj.geometry.Translation2d;
-import edu.wpi.first.wpilibj.trajectory.constraint.TrajectoryConstraint;
-
-import static edu.wpi.first.wpilibj.util.Units.feetToMeters;
-import static org.junit.jupiter.api.Assertions.assertAll;
-import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertTrue;
-
-class TrajectoryGeneratorTest {
- static Trajectory getTrajectory(List<? extends TrajectoryConstraint> constraints) {
- final double maxVelocity = feetToMeters(12.0);
- final double maxAccel = feetToMeters(12);
-
- // 2018 cross scale auto waypoints.
- var sideStart = new Pose2d(feetToMeters(1.54), feetToMeters(23.23),
- Rotation2d.fromDegrees(-180));
- var crossScale = new Pose2d(feetToMeters(23.7), feetToMeters(6.8),
- Rotation2d.fromDegrees(-160));
-
- var waypoints = new ArrayList<Pose2d>();
- waypoints.add(sideStart);
- waypoints.add(sideStart.plus(
- new Transform2d(new Translation2d(feetToMeters(-13), feetToMeters(0)),
- new Rotation2d())));
- waypoints.add(sideStart.plus(
- new Transform2d(new Translation2d(feetToMeters(-19.5), feetToMeters(5)),
- Rotation2d.fromDegrees(-90))));
- waypoints.add(crossScale);
-
- TrajectoryConfig config = new TrajectoryConfig(maxVelocity, maxAccel)
- .setReversed(true)
- .addConstraints(constraints);
-
- return TrajectoryGenerator.generateTrajectory(waypoints, config);
- }
-
- @Test
- @SuppressWarnings("LocalVariableName")
- void testGenerationAndConstraints() {
- Trajectory trajectory = getTrajectory(new ArrayList<>());
-
- double duration = trajectory.getTotalTimeSeconds();
- double t = 0.0;
- double dt = 0.02;
-
- while (t < duration) {
- var point = trajectory.sample(t);
- t += dt;
- assertAll(
- () -> assertTrue(Math.abs(point.velocityMetersPerSecond) < feetToMeters(12.0) + 0.05),
- () -> assertTrue(Math.abs(point.accelerationMetersPerSecondSq) < feetToMeters(12.0)
- + 0.05)
- );
- }
- }
-
- @Test
- void testMalformedTrajectory() {
- var traj =
- TrajectoryGenerator.generateTrajectory(
- Arrays.asList(
- new Pose2d(0, 0, Rotation2d.fromDegrees(0)),
- new Pose2d(1, 0, Rotation2d.fromDegrees(180))
- ),
- new TrajectoryConfig(feetToMeters(12), feetToMeters(12))
- );
-
- assertEquals(traj.getStates().size(), 1);
- assertEquals(traj.getTotalTimeSeconds(), 0);
- }
-}
diff --git a/wpimath/src/test/java/edu/wpi/first/wpilibj/trajectory/TrajectoryJsonTest.java b/wpimath/src/test/java/edu/wpi/first/wpilibj/trajectory/TrajectoryJsonTest.java
deleted file mode 100644
index d8d59b1..0000000
--- a/wpimath/src/test/java/edu/wpi/first/wpilibj/trajectory/TrajectoryJsonTest.java
+++ /dev/null
@@ -1,33 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.trajectory;
-
-import java.util.List;
-
-import org.junit.jupiter.api.Test;
-
-import edu.wpi.first.wpilibj.kinematics.DifferentialDriveKinematics;
-import edu.wpi.first.wpilibj.trajectory.constraint.DifferentialDriveKinematicsConstraint;
-
-import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
-import static org.junit.jupiter.api.Assertions.assertEquals;
-
-public class TrajectoryJsonTest {
- @Test
- void deserializeMatches() {
- var config = List.of(new DifferentialDriveKinematicsConstraint(
- new DifferentialDriveKinematics(20), 3));
- var trajectory = TrajectoryGeneratorTest.getTrajectory(config);
-
- var deserialized =
- assertDoesNotThrow(() ->
- TrajectoryUtil.deserializeTrajectory(TrajectoryUtil.serializeTrajectory(trajectory)));
-
- assertEquals(trajectory.getStates(), deserialized.getStates());
- }
-}
diff --git a/wpimath/src/test/java/edu/wpi/first/wpilibj/trajectory/TrajectoryTransformTest.java b/wpimath/src/test/java/edu/wpi/first/wpilibj/trajectory/TrajectoryTransformTest.java
deleted file mode 100644
index b17046b..0000000
--- a/wpimath/src/test/java/edu/wpi/first/wpilibj/trajectory/TrajectoryTransformTest.java
+++ /dev/null
@@ -1,68 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpilibj.trajectory;
-
-import java.util.List;
-
-import org.junit.jupiter.api.Test;
-
-import edu.wpi.first.wpilibj.geometry.Pose2d;
-import edu.wpi.first.wpilibj.geometry.Rotation2d;
-import edu.wpi.first.wpilibj.geometry.Transform2d;
-import edu.wpi.first.wpilibj.geometry.Translation2d;
-
-import static org.junit.jupiter.api.Assertions.assertEquals;
-
-class TrajectoryTransformTest {
- @Test
- void testTransformBy() {
- var config = new TrajectoryConfig(3, 3);
- var trajectory = TrajectoryGenerator.generateTrajectory(
- new Pose2d(), List.of(), new Pose2d(1, 1, Rotation2d.fromDegrees(90)),
- config
- );
-
- var transformedTrajectory = trajectory.transformBy(
- new Transform2d(new Translation2d(1, 2), Rotation2d.fromDegrees(30)));
-
- // Test initial pose.
- assertEquals(new Pose2d(1, 2, Rotation2d.fromDegrees(30)),
- transformedTrajectory.sample(0).poseMeters);
-
- testSameShapedTrajectory(trajectory.getStates(), transformedTrajectory.getStates());
- }
-
- @Test
- void testRelativeTo() {
- var config = new TrajectoryConfig(3, 3);
- var trajectory = TrajectoryGenerator.generateTrajectory(
- new Pose2d(1, 2, Rotation2d.fromDegrees(30.0)),
- List.of(), new Pose2d(5, 7, Rotation2d.fromDegrees(90)),
- config
- );
-
- var transformedTrajectory = trajectory.relativeTo(new Pose2d(1, 2, Rotation2d.fromDegrees(30)));
-
- // Test initial pose.
- assertEquals(new Pose2d(), transformedTrajectory.sample(0).poseMeters);
-
- testSameShapedTrajectory(trajectory.getStates(), transformedTrajectory.getStates());
- }
-
- void testSameShapedTrajectory(List<Trajectory.State> statesA, List<Trajectory.State> statesB) {
- for (int i = 0; i < statesA.size() - 1; i++) {
- var a1 = statesA.get(i).poseMeters;
- var a2 = statesA.get(i + 1).poseMeters;
-
- var b1 = statesB.get(i).poseMeters;
- var b2 = statesB.get(i + 1).poseMeters;
-
- assertEquals(a2.relativeTo(a1), b2.relativeTo(b1));
- }
- }
-}
diff --git a/wpimath/src/test/java/edu/wpi/first/wpiutil/math/MathUtilTest.java b/wpimath/src/test/java/edu/wpi/first/wpiutil/math/MathUtilTest.java
deleted file mode 100644
index 14a0f7c..0000000
--- a/wpimath/src/test/java/edu/wpi/first/wpiutil/math/MathUtilTest.java
+++ /dev/null
@@ -1,22 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpiutil.math;
-
-import org.junit.jupiter.api.Test;
-
-import static org.junit.jupiter.api.Assertions.assertEquals;
-
-class MathUtilTest {
- @Test
- void testAngleNormalize() {
- assertEquals(MathUtil.normalizeAngle(5 * Math.PI), Math.PI);
- assertEquals(MathUtil.normalizeAngle(-5 * Math.PI), Math.PI);
- assertEquals(MathUtil.normalizeAngle(Math.PI / 2), Math.PI / 2);
- assertEquals(MathUtil.normalizeAngle(-Math.PI / 2), -Math.PI / 2);
- }
-}
diff --git a/wpimath/src/test/java/edu/wpi/first/wpiutil/math/MatrixTest.java b/wpimath/src/test/java/edu/wpi/first/wpiutil/math/MatrixTest.java
deleted file mode 100644
index 1af0625..0000000
--- a/wpimath/src/test/java/edu/wpi/first/wpiutil/math/MatrixTest.java
+++ /dev/null
@@ -1,174 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-package edu.wpi.first.wpiutil.math;
-
-import org.ejml.data.SingularMatrixException;
-import org.junit.jupiter.api.Test;
-
-import edu.wpi.first.wpiutil.math.numbers.N1;
-import edu.wpi.first.wpiutil.math.numbers.N2;
-import edu.wpi.first.wpiutil.math.numbers.N3;
-import edu.wpi.first.wpiutil.math.numbers.N4;
-
-import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertThrows;
-import static org.junit.jupiter.api.Assertions.assertTrue;
-
-public class MatrixTest {
- @Test
- void testMatrixMultiplication() {
- var mat1 = Matrix.mat(Nat.N2(), Nat.N2())
- .fill(2.0, 1.0,
- 0.0, 1.0);
- var mat2 = Matrix.mat(Nat.N2(), Nat.N2())
- .fill(3.0, 0.0,
- 0.0, 2.5);
-
- Matrix<N2, N2> result = mat1.times(mat2);
-
- assertEquals(result, Matrix.mat(Nat.N2(), Nat.N2()).fill(6.0, 2.5, 0.0, 2.5));
-
- var mat3 = Matrix.mat(Nat.N2(), Nat.N3())
- .fill(1.0, 3.0, 0.5,
- 2.0, 4.3, 1.2);
- var mat4 = Matrix.mat(Nat.N3(), Nat.N4())
- .fill(3.0, 1.5, 2.0, 4.5,
- 2.3, 1.0, 1.6, 3.1,
- 5.2, 2.1, 2.0, 1.0);
-
- Matrix<N2, N4> result2 = mat3.times(mat4);
-
- assertTrue(Matrix.mat(Nat.N2(), Nat.N4())
- .fill(12.5, 5.55, 7.8, 14.3,
- 22.13, 9.82, 13.28, 23.53).isEqual(
- result2,
- 1E-9
- ));
- }
-
- @Test
- void testMatrixVectorMultiplication() {
- var mat = Matrix.mat(Nat.N2(), Nat.N2())
- .fill(1.0, 1.0,
- 0.0, 1.0);
-
- var vec = VecBuilder.fill(3.0, 2.0);
-
- Matrix<N2, N1> result = mat.times(vec);
- assertEquals(VecBuilder.fill(5.0, 2.0), result);
- }
-
- @Test
- void testTranspose() {
- Matrix<N3, N1> vec = VecBuilder
- .fill(1.0,
- 2.0,
- 3.0);
-
- Matrix<N1, N3> transpose = vec.transpose();
-
- assertEquals(Matrix.mat(Nat.N1(), Nat.N3()).fill(1.0, 2.0, 3.0), transpose);
- }
-
- @Test
- void testSolve() {
- var mat1 = Matrix.mat(Nat.N2(), Nat.N2()).fill(1.0, 2.0, 3.0, 4.0);
- var vec1 = VecBuilder.fill(1.0, 2.0);
-
- var solve1 = mat1.solve(vec1);
-
- assertEquals(VecBuilder.fill(0.0, 0.5), solve1);
-
- var mat2 = Matrix.mat(Nat.N3(), Nat.N2()).fill(1.0, 2.0, 3.0, 4.0, 5.0, 6.0);
- var vec2 = VecBuilder.fill(1.0, 2.0, 3.0);
-
- var solve2 = mat2.solve(vec2);
-
- assertEquals(VecBuilder.fill(0.0, 0.5), solve2);
- }
-
- @Test
- void testInverse() {
- var mat = Matrix.mat(Nat.N3(), Nat.N3())
- .fill(1.0, 3.0, 2.0,
- 5.0, 2.0, 1.5,
- 0.0, 1.3, 2.5);
-
- var inv = mat.inv();
-
- assertTrue(Matrix.eye(Nat.N3()).isEqual(
- mat.times(inv),
- 1E-9
- ));
-
- assertTrue(Matrix.eye(Nat.N3()).isEqual(
- inv.times(mat),
- 1E-9
- ));
- }
-
- @Test
- void testUninvertableMatrix() {
- var singularMatrix = Matrix.mat(Nat.N2(), Nat.N2())
- .fill(2.0, 1.0,
- 2.0, 1.0);
-
- assertThrows(SingularMatrixException.class, singularMatrix::inv);
- }
-
- @Test
- void testMatrixScalarArithmetic() {
- var mat = Matrix.mat(Nat.N2(), Nat.N2())
- .fill(1.0, 2.0,
- 3.0, 4.0);
-
- assertEquals(Matrix.mat(Nat.N2(), Nat.N2()).fill(3.0, 4.0, 5.0, 6.0), mat.plus(2.0));
-
- assertEquals(Matrix.mat(Nat.N2(), Nat.N2()).fill(0.0, 1.0, 2.0, 3.0), mat.minus(1.0));
-
- assertEquals(Matrix.mat(Nat.N2(), Nat.N2()).fill(2.0, 4.0, 6.0, 8.0), mat.times(2.0));
-
- assertTrue(Matrix.mat(Nat.N2(), Nat.N2()).fill(0.5, 1.0, 1.5, 2.0).isEqual(
- mat.div(2.0),
- 1E-3
- ));
- }
-
- @Test
- void testMatrixMatrixArithmetic() {
- var mat1 = Matrix.mat(Nat.N2(), Nat.N2())
- .fill(1.0, 2.0,
- 3.0, 4.0);
-
- var mat2 = Matrix.mat(Nat.N2(), Nat.N2())
- .fill(5.0, 6.0,
- 7.0, 8.0);
-
- assertEquals(Matrix.mat(Nat.N2(), Nat.N2()).fill(-4.0, -4.0, -4.0, -4.0),
- mat1.minus(mat2)
- );
-
- assertEquals(Matrix.mat(Nat.N2(), Nat.N2()).fill(6.0, 8.0, 10.0, 12.0),
- mat1.plus(mat2)
- );
- }
-
- @Test
- void testMatrixExponential() {
- var matrix = Matrix.eye(Nat.N2());
- var result = matrix.exp();
-
- assertTrue(result.isEqual(Matrix.mat(Nat.N2(), Nat.N2()).fill(Math.E, 0, 0, Math.E), 1E-9));
-
- matrix = Matrix.mat(Nat.N2(), Nat.N2()).fill(1, 2, 3, 4);
- result = matrix.times(0.01).exp();
-
- assertTrue(result.isEqual(Matrix.mat(Nat.N2(), Nat.N2()).fill(1.01035625, 0.02050912,
- 0.03076368, 1.04111993), 1E-8));
- }
-}
diff --git a/wpimath/src/test/native/cpp/EigenTest.cpp b/wpimath/src/test/native/cpp/EigenTest.cpp
index 2065e26..c1786c3 100644
--- a/wpimath/src/test/native/cpp/EigenTest.cpp
+++ b/wpimath/src/test/native/cpp/EigenTest.cpp
@@ -1,57 +1,46 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include "Eigen/Core"
#include "Eigen/LU"
#include "gtest/gtest.h"
-TEST(EigenTest, MultiplicationTest) {
- Eigen::Matrix<double, 2, 2> m1;
- m1 << 2, 1, 0, 1;
-
- Eigen::Matrix<double, 2, 2> m2;
- m2 << 3, 0, 0, 2.5;
+TEST(EigenTest, Multiplication) {
+ Eigen::Matrix<double, 2, 2> m1{{2, 1}, {0, 1}};
+ Eigen::Matrix<double, 2, 2> m2{{3, 0}, {0, 2.5}};
const auto result = m1 * m2;
- Eigen::Matrix<double, 2, 2> expectedResult;
- expectedResult << 6.0, 2.5, 0.0, 2.5;
+ Eigen::Matrix<double, 2, 2> expectedResult{{6.0, 2.5}, {0.0, 2.5}};
EXPECT_TRUE(expectedResult.isApprox(result));
- Eigen::Matrix<double, 2, 3> m3;
- m3 << 1.0, 3.0, 0.5, 2.0, 4.3, 1.2;
-
- Eigen::Matrix<double, 3, 4> m4;
- m4 << 3.0, 1.5, 2.0, 4.5, 2.3, 1.0, 1.6, 3.1, 5.2, 2.1, 2.0, 1.0;
+ Eigen::Matrix<double, 2, 3> m3{{1.0, 3.0, 0.5}, {2.0, 4.3, 1.2}};
+ Eigen::Matrix<double, 3, 4> m4{
+ {3.0, 1.5, 2.0, 4.5}, {2.3, 1.0, 1.6, 3.1}, {5.2, 2.1, 2.0, 1.0}};
const auto result2 = m3 * m4;
- Eigen::Matrix<double, 2, 4> expectedResult2;
- expectedResult2 << 12.5, 5.55, 7.8, 14.3, 22.13, 9.82, 13.28, 23.53;
+ Eigen::Matrix<double, 2, 4> expectedResult2{{12.5, 5.55, 7.8, 14.3},
+ {22.13, 9.82, 13.28, 23.53}};
EXPECT_TRUE(expectedResult2.isApprox(result2));
}
-TEST(EigenTest, TransposeTest) {
- Eigen::Matrix<double, 3, 1> vec;
- vec << 1, 2, 3;
+TEST(EigenTest, Transpose) {
+ Eigen::Vector<double, 3> vec{1, 2, 3};
const auto transpose = vec.transpose();
- Eigen::Matrix<double, 1, 3> expectedTranspose;
- expectedTranspose << 1, 2, 3;
+ Eigen::RowVector<double, 3> expectedTranspose{1, 2, 3};
EXPECT_TRUE(expectedTranspose.isApprox(transpose));
}
-TEST(EigenTest, InverseTest) {
- Eigen::Matrix<double, 3, 3> mat;
- mat << 1.0, 3.0, 2.0, 5.0, 2.0, 1.5, 0.0, 1.3, 2.5;
+TEST(EigenTest, Inverse) {
+ Eigen::Matrix<double, 3, 3> mat{
+ {1.0, 3.0, 2.0}, {5.0, 2.0, 1.5}, {0.0, 1.3, 2.5}};
const auto inverse = mat.inverse();
const auto identity = Eigen::MatrixXd::Identity(3, 3);
diff --git a/wpimath/src/test/native/cpp/FormatterTest.cpp b/wpimath/src/test/native/cpp/FormatterTest.cpp
new file mode 100644
index 0000000..cd7ef5c
--- /dev/null
+++ b/wpimath/src/test/native/cpp/FormatterTest.cpp
@@ -0,0 +1,23 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+#include <fmt/format.h>
+
+#include "frc/fmt/Eigen.h"
+#include "frc/fmt/Units.h"
+#include "gtest/gtest.h"
+#include "units/velocity.h"
+
+TEST(FormatterTest, Eigen) {
+ Eigen::Matrix<double, 3, 2> A{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}};
+ EXPECT_EQ(
+ " 1.000000 2.000000\n"
+ " 3.000000 4.000000\n"
+ " 5.000000 6.000000",
+ fmt::format("{}", A));
+}
+
+TEST(FormatterTest, Units) {
+ EXPECT_EQ("4 mps", fmt::format("{}", 4_mps));
+}
diff --git a/wpimath/src/test/native/cpp/LinearFilterNoiseTest.cpp b/wpimath/src/test/native/cpp/LinearFilterNoiseTest.cpp
deleted file mode 100644
index 934e14b..0000000
--- a/wpimath/src/test/native/cpp/LinearFilterNoiseTest.cpp
+++ /dev/null
@@ -1,94 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2015-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-#include "frc/LinearFilter.h" // NOLINT(build/include_order)
-
-#include <cmath>
-#include <memory>
-#include <random>
-
-#include <wpi/math>
-
-#include "gtest/gtest.h"
-#include "units/time.h"
-
-// Filter constants
-static constexpr units::second_t kFilterStep = 0.005_s;
-static constexpr units::second_t kFilterTime = 2.0_s;
-static constexpr double kSinglePoleIIRTimeConstant = 0.015915;
-static constexpr int32_t kMovAvgTaps = 6;
-
-enum LinearFilterNoiseTestType { TEST_SINGLE_POLE_IIR, TEST_MOVAVG };
-
-std::ostream& operator<<(std::ostream& os,
- const LinearFilterNoiseTestType& type) {
- switch (type) {
- case TEST_SINGLE_POLE_IIR:
- os << "LinearFilter SinglePoleIIR";
- break;
- case TEST_MOVAVG:
- os << "LinearFilter MovingAverage";
- break;
- }
-
- return os;
-}
-
-static double GetData(double t) {
- return 100.0 * std::sin(2.0 * wpi::math::pi * t);
-}
-
-class LinearFilterNoiseTest
- : public testing::TestWithParam<LinearFilterNoiseTestType> {
- protected:
- std::unique_ptr<frc::LinearFilter<double>> m_filter;
-
- void SetUp() override {
- switch (GetParam()) {
- case TEST_SINGLE_POLE_IIR: {
- m_filter = std::make_unique<frc::LinearFilter<double>>(
- frc::LinearFilter<double>::SinglePoleIIR(kSinglePoleIIRTimeConstant,
- kFilterStep));
- break;
- }
-
- case TEST_MOVAVG: {
- m_filter = std::make_unique<frc::LinearFilter<double>>(
- frc::LinearFilter<double>::MovingAverage(kMovAvgTaps));
- break;
- }
- }
- }
-};
-
-/**
- * Test if the filter reduces the noise produced by a signal generator
- */
-TEST_P(LinearFilterNoiseTest, NoiseReduce) {
- double noiseGenError = 0.0;
- double filterError = 0.0;
-
- std::random_device rd;
- std::mt19937 gen{rd()};
- std::normal_distribution<double> distr{0.0, 10.0};
-
- for (auto t = 0_s; t < kFilterTime; t += kFilterStep) {
- double theory = GetData(t.to<double>());
- double noise = distr(gen);
- filterError += std::abs(m_filter->Calculate(theory + noise) - theory);
- noiseGenError += std::abs(noise - theory);
- }
-
- RecordProperty("FilterError", filterError);
-
- // The filter should have produced values closer to the theory
- EXPECT_GT(noiseGenError, filterError)
- << "Filter should have reduced noise accumulation but failed";
-}
-
-INSTANTIATE_TEST_SUITE_P(Test, LinearFilterNoiseTest,
- testing::Values(TEST_SINGLE_POLE_IIR, TEST_MOVAVG));
diff --git a/wpimath/src/test/native/cpp/LinearFilterOutputTest.cpp b/wpimath/src/test/native/cpp/LinearFilterOutputTest.cpp
deleted file mode 100644
index d321518..0000000
--- a/wpimath/src/test/native/cpp/LinearFilterOutputTest.cpp
+++ /dev/null
@@ -1,136 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2015-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-#include "frc/LinearFilter.h" // NOLINT(build/include_order)
-
-#include <cmath>
-#include <functional>
-#include <memory>
-#include <random>
-
-#include <wpi/math>
-
-#include "gtest/gtest.h"
-#include "units/time.h"
-
-// Filter constants
-static constexpr units::second_t kFilterStep = 0.005_s;
-static constexpr units::second_t kFilterTime = 2.0_s;
-static constexpr double kSinglePoleIIRTimeConstant = 0.015915;
-static constexpr double kSinglePoleIIRExpectedOutput = -3.2172003;
-static constexpr double kHighPassTimeConstant = 0.006631;
-static constexpr double kHighPassExpectedOutput = 10.074717;
-static constexpr int32_t kMovAvgTaps = 6;
-static constexpr double kMovAvgExpectedOutput = -10.191644;
-
-enum LinearFilterOutputTestType {
- TEST_SINGLE_POLE_IIR,
- TEST_HIGH_PASS,
- TEST_MOVAVG,
- TEST_PULSE
-};
-
-std::ostream& operator<<(std::ostream& os,
- const LinearFilterOutputTestType& type) {
- switch (type) {
- case TEST_SINGLE_POLE_IIR:
- os << "LinearFilter SinglePoleIIR";
- break;
- case TEST_HIGH_PASS:
- os << "LinearFilter HighPass";
- break;
- case TEST_MOVAVG:
- os << "LinearFilter MovingAverage";
- break;
- case TEST_PULSE:
- os << "LinearFilter Pulse";
- break;
- }
-
- return os;
-}
-
-static double GetData(double t) {
- return 100.0 * std::sin(2.0 * wpi::math::pi * t) +
- 20.0 * std::cos(50.0 * wpi::math::pi * t);
-}
-
-static double GetPulseData(double t) {
- if (std::abs(t - 1.0) < 0.001) {
- return 1.0;
- } else {
- return 0.0;
- }
-}
-
-/**
- * A fixture that includes a consistent data source wrapped in a filter
- */
-class LinearFilterOutputTest
- : public testing::TestWithParam<LinearFilterOutputTestType> {
- protected:
- std::unique_ptr<frc::LinearFilter<double>> m_filter;
- std::function<double(double)> m_data;
- double m_expectedOutput = 0.0;
-
- void SetUp() override {
- switch (GetParam()) {
- case TEST_SINGLE_POLE_IIR: {
- m_filter = std::make_unique<frc::LinearFilter<double>>(
- frc::LinearFilter<double>::SinglePoleIIR(kSinglePoleIIRTimeConstant,
- kFilterStep));
- m_data = GetData;
- m_expectedOutput = kSinglePoleIIRExpectedOutput;
- break;
- }
-
- case TEST_HIGH_PASS: {
- m_filter = std::make_unique<frc::LinearFilter<double>>(
- frc::LinearFilter<double>::HighPass(kHighPassTimeConstant,
- kFilterStep));
- m_data = GetData;
- m_expectedOutput = kHighPassExpectedOutput;
- break;
- }
-
- case TEST_MOVAVG: {
- m_filter = std::make_unique<frc::LinearFilter<double>>(
- frc::LinearFilter<double>::MovingAverage(kMovAvgTaps));
- m_data = GetData;
- m_expectedOutput = kMovAvgExpectedOutput;
- break;
- }
-
- case TEST_PULSE: {
- m_filter = std::make_unique<frc::LinearFilter<double>>(
- frc::LinearFilter<double>::MovingAverage(kMovAvgTaps));
- m_data = GetPulseData;
- m_expectedOutput = 0.0;
- break;
- }
- }
- }
-};
-
-/**
- * Test if the linear filters produce consistent output for a given data set.
- */
-TEST_P(LinearFilterOutputTest, Output) {
- double filterOutput = 0.0;
- for (auto t = 0_s; t < kFilterTime; t += kFilterStep) {
- filterOutput = m_filter->Calculate(m_data(t.to<double>()));
- }
-
- RecordProperty("LinearFilterOutput", filterOutput);
-
- EXPECT_FLOAT_EQ(m_expectedOutput, filterOutput)
- << "Filter output didn't match expected value";
-}
-
-INSTANTIATE_TEST_SUITE_P(Test, LinearFilterOutputTest,
- testing::Values(TEST_SINGLE_POLE_IIR, TEST_HIGH_PASS,
- TEST_MOVAVG, TEST_PULSE));
diff --git a/wpimath/src/test/native/cpp/MathUtilTest.cpp b/wpimath/src/test/native/cpp/MathUtilTest.cpp
new file mode 100644
index 0000000..6b5af2b
--- /dev/null
+++ b/wpimath/src/test/native/cpp/MathUtilTest.cpp
@@ -0,0 +1,90 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+#include "frc/MathUtil.h"
+#include "gtest/gtest.h"
+#include "units/angle.h"
+
+#define EXPECT_UNITS_EQ(a, b) EXPECT_DOUBLE_EQ((a).value(), (b).value())
+
+#define EXPECT_UNITS_NEAR(a, b, c) EXPECT_NEAR((a).value(), (b).value(), c)
+
+TEST(MathUtilTest, ApplyDeadband) {
+ // < 0
+ EXPECT_DOUBLE_EQ(-1.0, frc::ApplyDeadband(-1.0, 0.02));
+ EXPECT_DOUBLE_EQ((-0.03 + 0.02) / (1.0 - 0.02),
+ frc::ApplyDeadband(-0.03, 0.02));
+ EXPECT_DOUBLE_EQ(0.0, frc::ApplyDeadband(-0.02, 0.02));
+ EXPECT_DOUBLE_EQ(0.0, frc::ApplyDeadband(-0.01, 0.02));
+
+ // == 0
+ EXPECT_DOUBLE_EQ(0.0, frc::ApplyDeadband(0.0, 0.02));
+
+ // > 0
+ EXPECT_DOUBLE_EQ(0.0, frc::ApplyDeadband(0.01, 0.02));
+ EXPECT_DOUBLE_EQ(0.0, frc::ApplyDeadband(0.02, 0.02));
+ EXPECT_DOUBLE_EQ((0.03 - 0.02) / (1.0 - 0.02),
+ frc::ApplyDeadband(0.03, 0.02));
+ EXPECT_DOUBLE_EQ(1.0, frc::ApplyDeadband(1.0, 0.02));
+}
+
+TEST(MathUtilTest, InputModulus) {
+ // These tests check error wrapping. That is, the result of wrapping the
+ // result of an angle reference minus the measurement.
+
+ // Test symmetric range
+ EXPECT_DOUBLE_EQ(-20.0, frc::InputModulus(170.0 - (-170.0), -180.0, 180.0));
+ EXPECT_DOUBLE_EQ(-20.0,
+ frc::InputModulus(170.0 + 360.0 - (-170.0), -180.0, 180.0));
+ EXPECT_DOUBLE_EQ(-20.0,
+ frc::InputModulus(170.0 - (-170.0 + 360.0), -180.0, 180.0));
+ EXPECT_DOUBLE_EQ(20.0, frc::InputModulus(-170.0 - 170.0, -180.0, 180.0));
+ EXPECT_DOUBLE_EQ(20.0,
+ frc::InputModulus(-170.0 + 360.0 - 170.0, -180.0, 180.0));
+ EXPECT_DOUBLE_EQ(20.0,
+ frc::InputModulus(-170.0 - (170.0 + 360.0), -180.0, 180.0));
+
+ // Test range starting at zero
+ EXPECT_DOUBLE_EQ(340.0, frc::InputModulus(170.0 - 190.0, 0.0, 360.0));
+ EXPECT_DOUBLE_EQ(340.0, frc::InputModulus(170.0 + 360.0 - 190.0, 0.0, 360.0));
+ EXPECT_DOUBLE_EQ(340.0,
+ frc::InputModulus(170.0 - (190.0 + 360.0), 0.0, 360.0));
+
+ // Test asymmetric range that doesn't start at zero
+ EXPECT_DOUBLE_EQ(-20.0, frc::InputModulus(170.0 - (-170.0), -170.0, 190.0));
+
+ // Test range with both positive endpoints
+ EXPECT_DOUBLE_EQ(2.0, frc::InputModulus(0.0, 1.0, 3.0));
+ EXPECT_DOUBLE_EQ(3.0, frc::InputModulus(1.0, 1.0, 3.0));
+ EXPECT_DOUBLE_EQ(2.0, frc::InputModulus(2.0, 1.0, 3.0));
+ EXPECT_DOUBLE_EQ(3.0, frc::InputModulus(3.0, 1.0, 3.0));
+ EXPECT_DOUBLE_EQ(2.0, frc::InputModulus(4.0, 1.0, 3.0));
+
+ // Test all supported types
+ EXPECT_DOUBLE_EQ(-20.0,
+ frc::InputModulus<double>(170.0 - (-170.0), -170.0, 190.0));
+ EXPECT_EQ(-20, frc::InputModulus<int>(170 - (-170), -170, 190));
+ EXPECT_EQ(-20_deg, frc::InputModulus<units::degree_t>(170_deg - (-170_deg),
+ -170_deg, 190_deg));
+}
+
+TEST(MathUtilTest, AngleModulus) {
+ EXPECT_UNITS_NEAR(
+ frc::AngleModulus(units::radian_t{-2000 * wpi::numbers::pi / 180}),
+ units::radian_t{160 * wpi::numbers::pi / 180}, 1e-10);
+ EXPECT_UNITS_NEAR(
+ frc::AngleModulus(units::radian_t{358 * wpi::numbers::pi / 180}),
+ units::radian_t{-2 * wpi::numbers::pi / 180}, 1e-10);
+ EXPECT_UNITS_NEAR(frc::AngleModulus(units::radian_t{2.0 * wpi::numbers::pi}),
+ 0_rad, 1e-10);
+
+ EXPECT_UNITS_EQ(frc::AngleModulus(units::radian_t(5 * wpi::numbers::pi)),
+ units::radian_t(wpi::numbers::pi));
+ EXPECT_UNITS_EQ(frc::AngleModulus(units::radian_t(-5 * wpi::numbers::pi)),
+ units::radian_t(wpi::numbers::pi));
+ EXPECT_UNITS_EQ(frc::AngleModulus(units::radian_t(wpi::numbers::pi / 2)),
+ units::radian_t(wpi::numbers::pi / 2));
+ EXPECT_UNITS_EQ(frc::AngleModulus(units::radian_t(-wpi::numbers::pi / 2)),
+ units::radian_t(-wpi::numbers::pi / 2));
+}
diff --git a/wpimath/src/test/native/cpp/StateSpaceTest.cpp b/wpimath/src/test/native/cpp/StateSpaceTest.cpp
index a1600fc..0dc3978 100644
--- a/wpimath/src/test/native/cpp/StateSpaceTest.cpp
+++ b/wpimath/src/test/native/cpp/StateSpaceTest.cpp
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include <gtest/gtest.h>
@@ -25,7 +22,7 @@
constexpr double kPositionStddev = 0.0001;
constexpr auto kDt = 0.00505_s;
-class StateSpace : public testing::Test {
+class StateSpaceTest : public testing::Test {
public:
LinearSystem<2, 1, 1> plant = [] {
auto motors = DCMotor::Vex775Pro(2);
@@ -46,24 +43,23 @@
LinearSystemLoop<2, 1, 1> loop{plant, controller, observer, 12_V, kDt};
};
-void Update(LinearSystemLoop<2, 1, 1>& loop, double noise) {
- Eigen::Matrix<double, 1, 1> y =
- loop.Plant().CalculateY(loop.Xhat(), loop.U()) +
- Eigen::Matrix<double, 1, 1>(noise);
+void Update(const LinearSystem<2, 1, 1>& plant, LinearSystemLoop<2, 1, 1>& loop,
+ double noise) {
+ Eigen::Vector<double, 1> y =
+ plant.CalculateY(loop.Xhat(), loop.U()) + Eigen::Vector<double, 1>{noise};
loop.Correct(y);
loop.Predict(kDt);
}
-TEST_F(StateSpace, CorrectPredictLoop) {
+TEST_F(StateSpaceTest, CorrectPredictLoop) {
std::default_random_engine generator;
std::normal_distribution<double> dist{0.0, kPositionStddev};
- Eigen::Matrix<double, 2, 1> references;
- references << 2.0, 0.0;
+ Eigen::Vector<double, 2> references{2.0, 0.0};
loop.SetNextR(references);
for (int i = 0; i < 1000; i++) {
- Update(loop, dist(generator));
+ Update(plant, loop, dist(generator));
EXPECT_PRED_FORMAT2(testing::DoubleLE, -12.0, loop.U(0));
EXPECT_PRED_FORMAT2(testing::DoubleLE, loop.U(0), 12.0);
}
diff --git a/wpimath/src/test/native/cpp/StateSpaceUtilTest.cpp b/wpimath/src/test/native/cpp/StateSpaceUtilTest.cpp
index 24e9cf2..57b93bb 100644
--- a/wpimath/src/test/native/cpp/StateSpaceUtilTest.cpp
+++ b/wpimath/src/test/native/cpp/StateSpaceUtilTest.cpp
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include <gtest/gtest.h>
@@ -11,16 +8,16 @@
#include "Eigen/Core"
#include "frc/StateSpaceUtil.h"
-#include "frc/system/RungeKutta.h"
+#include "frc/system/NumericalIntegration.h"
TEST(StateSpaceUtilTest, MakeMatrix) {
// Column vector
- Eigen::Matrix<double, 2, 1> mat1 = frc::MakeMatrix<2, 1>(1.0, 2.0);
+ Eigen::Vector<double, 2> mat1 = frc::MakeMatrix<2, 1>(1.0, 2.0);
EXPECT_NEAR(mat1(0), 1.0, 1e-3);
EXPECT_NEAR(mat1(1), 2.0, 1e-3);
// Row vector
- Eigen::Matrix<double, 1, 2> mat2 = frc::MakeMatrix<1, 2>(1.0, 2.0);
+ Eigen::RowVector<double, 2> mat2 = frc::MakeMatrix<1, 2>(1.0, 2.0);
EXPECT_NEAR(mat2(0), 1.0, 1e-3);
EXPECT_NEAR(mat2(1), 2.0, 1e-3);
@@ -105,44 +102,59 @@
}
TEST(StateSpaceUtilTest, WhiteNoiseVectorParameterPack) {
- Eigen::Matrix<double, 2, 1> vec = frc::MakeWhiteNoiseVector(2.0, 3.0);
+ Eigen::Vector<double, 2> vec = frc::MakeWhiteNoiseVector(2.0, 3.0);
static_cast<void>(vec);
}
TEST(StateSpaceUtilTest, WhiteNoiseVectorArray) {
- Eigen::Matrix<double, 2, 1> vec = frc::MakeWhiteNoiseVector<2>({2.0, 3.0});
+ Eigen::Vector<double, 2> vec = frc::MakeWhiteNoiseVector<2>({2.0, 3.0});
static_cast<void>(vec);
}
TEST(StateSpaceUtilTest, IsStabilizable) {
- Eigen::Matrix<double, 2, 2> A;
- Eigen::Matrix<double, 2, 1> B;
- B << 0, 1;
-
- // We separate the result of IsStabilizable from the assertion because
- // templates break gtest.
+ Eigen::Matrix<double, 2, 1> B{0, 1};
// First eigenvalue is uncontrollable and unstable.
// Second eigenvalue is controllable and stable.
- A << 1.2, 0, 0, 0.5;
- bool ret = frc::IsStabilizable<2, 1>(A, B);
- EXPECT_FALSE(ret);
+ EXPECT_FALSE((frc::IsStabilizable<2, 1>(
+ Eigen::Matrix<double, 2, 2>{{1.2, 0}, {0, 0.5}}, B)));
// First eigenvalue is uncontrollable and marginally stable.
// Second eigenvalue is controllable and stable.
- A << 1, 0, 0, 0.5;
- ret = frc::IsStabilizable<2, 1>(A, B);
- EXPECT_FALSE(ret);
+ EXPECT_FALSE((frc::IsStabilizable<2, 1>(
+ Eigen::Matrix<double, 2, 2>{{1, 0}, {0, 0.5}}, B)));
// First eigenvalue is uncontrollable and stable.
// Second eigenvalue is controllable and stable.
- A << 0.2, 0, 0, 0.5;
- ret = frc::IsStabilizable<2, 1>(A, B);
- EXPECT_TRUE(ret);
+ EXPECT_TRUE((frc::IsStabilizable<2, 1>(
+ Eigen::Matrix<double, 2, 2>{{0.2, 0}, {0, 0.5}}, B)));
// First eigenvalue is uncontrollable and stable.
// Second eigenvalue is controllable and unstable.
- A << 0.2, 0, 0, 1.2;
- ret = frc::IsStabilizable<2, 1>(A, B);
- EXPECT_TRUE(ret);
+ EXPECT_TRUE((frc::IsStabilizable<2, 1>(
+ Eigen::Matrix<double, 2, 2>{{0.2, 0}, {0, 1.2}}, B)));
+}
+
+TEST(StateSpaceUtilTest, IsDetectable) {
+ Eigen::Matrix<double, 1, 2> C{0, 1};
+
+ // First eigenvalue is unobservable and unstable.
+ // Second eigenvalue is observable and stable.
+ EXPECT_FALSE((frc::IsDetectable<2, 1>(
+ Eigen::Matrix<double, 2, 2>{{1.2, 0}, {0, 0.5}}, C)));
+
+ // First eigenvalue is unobservable and marginally stable.
+ // Second eigenvalue is observable and stable.
+ EXPECT_FALSE((frc::IsDetectable<2, 1>(
+ Eigen::Matrix<double, 2, 2>{{1, 0}, {0, 0.5}}, C)));
+
+ // First eigenvalue is unobservable and stable.
+ // Second eigenvalue is observable and stable.
+ EXPECT_TRUE((frc::IsDetectable<2, 1>(
+ Eigen::Matrix<double, 2, 2>{{0.2, 0}, {0, 0.5}}, C)));
+
+ // First eigenvalue is unobservable and stable.
+ // Second eigenvalue is observable and unstable.
+ EXPECT_TRUE((frc::IsDetectable<2, 1>(
+ Eigen::Matrix<double, 2, 2>{{0.2, 0}, {0, 1.2}}, C)));
}
diff --git a/wpimath/src/test/native/cpp/UnitsTest.cpp b/wpimath/src/test/native/cpp/UnitsTest.cpp
index 54e8195..ec38a47 100644
--- a/wpimath/src/test/native/cpp/UnitsTest.cpp
+++ b/wpimath/src/test/native/cpp/UnitsTest.cpp
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include <array>
#include <chrono>
@@ -96,66 +93,66 @@
class TypeTraits : public ::testing::Test {
protected:
- TypeTraits() {}
- virtual ~TypeTraits() {}
- virtual void SetUp() {}
- virtual void TearDown() {}
+ TypeTraits() = default;
+ ~TypeTraits() override = default;
+ void SetUp() override {}
+ void TearDown() override {}
};
class UnitManipulators : public ::testing::Test {
protected:
- UnitManipulators() {}
- virtual ~UnitManipulators() {}
- virtual void SetUp() {}
- virtual void TearDown() {}
+ UnitManipulators() = default;
+ ~UnitManipulators() override = default;
+ void SetUp() override {}
+ void TearDown() override {}
};
class UnitContainer : public ::testing::Test {
protected:
- UnitContainer() {}
- virtual ~UnitContainer() {}
- virtual void SetUp() {}
- virtual void TearDown() {}
+ UnitContainer() = default;
+ ~UnitContainer() override = default;
+ void SetUp() override {}
+ void TearDown() override {}
};
class UnitConversion : public ::testing::Test {
protected:
- UnitConversion() {}
- virtual ~UnitConversion() {}
- virtual void SetUp() {}
- virtual void TearDown() {}
+ UnitConversion() = default;
+ ~UnitConversion() override = default;
+ void SetUp() override {}
+ void TearDown() override {}
};
class UnitMath : public ::testing::Test {
protected:
- UnitMath() {}
- virtual ~UnitMath() {}
- virtual void SetUp() {}
- virtual void TearDown() {}
+ UnitMath() = default;
+ ~UnitMath() override = default;
+ void SetUp() override {}
+ void TearDown() override {}
};
class CompileTimeArithmetic : public ::testing::Test {
protected:
- CompileTimeArithmetic() {}
- virtual ~CompileTimeArithmetic() {}
- virtual void SetUp() {}
- virtual void TearDown() {}
+ CompileTimeArithmetic() = default;
+ ~CompileTimeArithmetic() override = default;
+ void SetUp() override {}
+ void TearDown() override {}
};
class Constexpr : public ::testing::Test {
protected:
- Constexpr() {}
- virtual ~Constexpr() {}
- virtual void SetUp() {}
- virtual void TearDown() {}
+ Constexpr() = default;
+ ~Constexpr() override = default;
+ void SetUp() override {}
+ void TearDown() override {}
};
class CaseStudies : public ::testing::Test {
protected:
- CaseStudies() {}
- virtual ~CaseStudies() {}
- virtual void SetUp() {}
- virtual void TearDown() {}
+ CaseStudies() = default;
+ ~CaseStudies() override = default;
+ void SetUp() override {}
+ void TearDown() override {}
struct RightTriangle {
using a = unit_value_t<meters, 3>;
@@ -1327,7 +1324,7 @@
}
TEST_F(UnitContainer, valueMethod) {
- double test = meter_t(3.0).to<double>();
+ double test = meter_t(3.0).value();
EXPECT_DOUBLE_EQ(3.0, test);
auto test2 = meter_t(4.0).value();
@@ -1336,11 +1333,11 @@
}
TEST_F(UnitContainer, convertMethod) {
- double test = meter_t(3.0).convert<feet>().to<double>();
+ double test = meter_t(3.0).convert<feet>().value();
EXPECT_NEAR(9.84252, test, 5.0e-6);
}
-#ifndef UNIT_LIB_DISABLE_IOSTREAM
+#ifdef UNIT_LIB_ENABLE_IOSTREAM
TEST_F(UnitContainer, cout) {
testing::internal::CaptureStdout();
std::cout << degree_t(349.87);
@@ -1421,6 +1418,88 @@
EXPECT_STREQ("5.670367e-08 kg s^-3 K^-4", output.c_str());
#endif
}
+#endif
+
+TEST_F(UnitContainer, fmtlib) {
+ testing::internal::CaptureStdout();
+ fmt::print("{}", degree_t(349.87));
+ std::string output = testing::internal::GetCapturedStdout();
+ EXPECT_STREQ("349.87 deg", output.c_str());
+
+ testing::internal::CaptureStdout();
+ fmt::print("{}", meter_t(1.0));
+ output = testing::internal::GetCapturedStdout();
+ EXPECT_STREQ("1 m", output.c_str());
+
+ testing::internal::CaptureStdout();
+ fmt::print("{}", dB_t(31.0));
+ output = testing::internal::GetCapturedStdout();
+ EXPECT_STREQ("31 dB", output.c_str());
+
+ testing::internal::CaptureStdout();
+ fmt::print("{}", volt_t(21.79));
+ output = testing::internal::GetCapturedStdout();
+ EXPECT_STREQ("21.79 V", output.c_str());
+
+ testing::internal::CaptureStdout();
+ fmt::print("{}", dBW_t(12.0));
+ output = testing::internal::GetCapturedStdout();
+ EXPECT_STREQ("12 dBW", output.c_str());
+
+ testing::internal::CaptureStdout();
+ fmt::print("{}", dBm_t(120.0));
+ output = testing::internal::GetCapturedStdout();
+ EXPECT_STREQ("120 dBm", output.c_str());
+
+ testing::internal::CaptureStdout();
+ fmt::print("{}", miles_per_hour_t(72.1));
+ output = testing::internal::GetCapturedStdout();
+ EXPECT_STREQ("72.1 mph", output.c_str());
+
+ // undefined unit
+ testing::internal::CaptureStdout();
+ fmt::print("{}", units::math::cpow<4>(meter_t(2)));
+ output = testing::internal::GetCapturedStdout();
+ EXPECT_STREQ("16 m^4", output.c_str());
+
+ testing::internal::CaptureStdout();
+ fmt::print("{}", units::math::cpow<3>(foot_t(2)));
+ output = testing::internal::GetCapturedStdout();
+ EXPECT_STREQ("8 cu_ft", output.c_str());
+
+ testing::internal::CaptureStdout();
+ fmt::print("{:.9}", units::math::cpow<4>(foot_t(2)));
+ output = testing::internal::GetCapturedStdout();
+ EXPECT_STREQ("0.138095597 m^4", output.c_str());
+
+ // constants
+ testing::internal::CaptureStdout();
+ fmt::print("{:.8}", constants::k_B);
+ output = testing::internal::GetCapturedStdout();
+#if defined(_MSC_VER) && (_MSC_VER <= 1800)
+ EXPECT_STREQ("1.3806485e-023 m^2 kg s^-2 K^-1", output.c_str());
+#else
+ EXPECT_STREQ("1.3806485e-23 m^2 kg s^-2 K^-1", output.c_str());
+#endif
+
+ testing::internal::CaptureStdout();
+ fmt::print("{:.9}", constants::mu_B);
+ output = testing::internal::GetCapturedStdout();
+#if defined(_MSC_VER) && (_MSC_VER <= 1800)
+ EXPECT_STREQ("9.27400999e-024 m^2 A", output.c_str());
+#else
+ EXPECT_STREQ("9.27400999e-24 m^2 A", output.c_str());
+#endif
+
+ testing::internal::CaptureStdout();
+ fmt::print("{:.7}", constants::sigma);
+ output = testing::internal::GetCapturedStdout();
+#if defined(_MSC_VER) && (_MSC_VER <= 1800)
+ EXPECT_STREQ("5.670367e-008 kg s^-3 K^-4", output.c_str());
+#else
+ EXPECT_STREQ("5.670367e-08 kg s^-3 K^-4", output.c_str());
+#endif
+}
TEST_F(UnitContainer, to_string) {
foot_t a(3.5);
@@ -1479,18 +1558,17 @@
EXPECT_STREQ("m", b.abbreviation());
EXPECT_STREQ("meter", b.name());
}
-#endif
TEST_F(UnitContainer, negative) {
meter_t a(5.3);
meter_t b(-5.3);
- EXPECT_NEAR(a.to<double>(), -b.to<double>(), 5.0e-320);
- EXPECT_NEAR(b.to<double>(), -a.to<double>(), 5.0e-320);
+ EXPECT_NEAR(a.value(), -b.value(), 5.0e-320);
+ EXPECT_NEAR(b.value(), -a.value(), 5.0e-320);
dB_t c(2.87);
dB_t d(-2.87);
- EXPECT_NEAR(c.to<double>(), -d.to<double>(), 5.0e-320);
- EXPECT_NEAR(d.to<double>(), -c.to<double>(), 5.0e-320);
+ EXPECT_NEAR(c.value(), -d.value(), 5.0e-320);
+ EXPECT_NEAR(d.value(), -c.value(), 5.0e-320);
ppm_t e = -1 * ppm_t(10);
EXPECT_EQ(e, -ppm_t(10));
@@ -1501,7 +1579,7 @@
ppb_t a(ppm_t(1));
EXPECT_EQ(ppb_t(1000), a);
EXPECT_EQ(0.000001, a);
- EXPECT_EQ(0.000001, a.to<double>());
+ EXPECT_EQ(0.000001, a.value());
scalar_t b(ppm_t(1));
EXPECT_EQ(0.000001, b);
@@ -1711,7 +1789,7 @@
year_t twoYears(2.0);
week_t twoYearsInWeeks = twoYears;
- EXPECT_NEAR(week_t(104.286).to<double>(), twoYearsInWeeks.to<double>(),
+ EXPECT_NEAR(week_t(104.286).value(), twoYearsInWeeks.value(),
5.0e-4);
double test;
@@ -1748,8 +1826,8 @@
TEST_F(UnitConversion, angle) {
angle::degree_t quarterCircleDeg(90.0);
angle::radian_t quarterCircleRad = quarterCircleDeg;
- EXPECT_NEAR(angle::radian_t(constants::detail::PI_VAL / 2.0).to<double>(),
- quarterCircleRad.to<double>(), 5.0e-12);
+ EXPECT_NEAR(angle::radian_t(constants::detail::PI_VAL / 2.0).value(),
+ quarterCircleRad.value(), 5.0e-12);
double test;
@@ -2553,7 +2631,7 @@
EXPECT_TRUE(constants::pi < 4.0);
// explicit conversion
- EXPECT_NEAR(3.14159, constants::pi.to<double>(), 5.0e-6);
+ EXPECT_NEAR(3.14159, constants::pi.value(), 5.0e-6);
// auto multiplication
EXPECT_TRUE(
@@ -2562,16 +2640,16 @@
(std::is_same<meter_t, decltype(meter_t(1) * constants::pi)>::value));
EXPECT_NEAR(constants::detail::PI_VAL,
- (constants::pi * meter_t(1)).to<double>(), 5.0e-10);
+ (constants::pi * meter_t(1)).value(), 5.0e-10);
EXPECT_NEAR(constants::detail::PI_VAL,
- (meter_t(1) * constants::pi).to<double>(), 5.0e-10);
+ (meter_t(1) * constants::pi).value(), 5.0e-10);
// explicit multiplication
meter_t a = constants::pi * meter_t(1);
meter_t b = meter_t(1) * constants::pi;
- EXPECT_NEAR(constants::detail::PI_VAL, a.to<double>(), 5.0e-10);
- EXPECT_NEAR(constants::detail::PI_VAL, b.to<double>(), 5.0e-10);
+ EXPECT_NEAR(constants::detail::PI_VAL, a.value(), 5.0e-10);
+ EXPECT_NEAR(constants::detail::PI_VAL, b.value(), 5.0e-10);
// auto division
EXPECT_TRUE(
@@ -2580,16 +2658,16 @@
(std::is_same<second_t, decltype(second_t(1) / constants::pi)>::value));
EXPECT_NEAR(constants::detail::PI_VAL,
- (constants::pi / second_t(1)).to<double>(), 5.0e-10);
+ (constants::pi / second_t(1)).value(), 5.0e-10);
EXPECT_NEAR(1.0 / constants::detail::PI_VAL,
- (second_t(1) / constants::pi).to<double>(), 5.0e-10);
+ (second_t(1) / constants::pi).value(), 5.0e-10);
// explicit
hertz_t c = constants::pi / second_t(1);
second_t d = second_t(1) / constants::pi;
- EXPECT_NEAR(constants::detail::PI_VAL, c.to<double>(), 5.0e-10);
- EXPECT_NEAR(1.0 / constants::detail::PI_VAL, d.to<double>(), 5.0e-10);
+ EXPECT_NEAR(constants::detail::PI_VAL, c.value(), 5.0e-10);
+ EXPECT_NEAR(1.0 / constants::detail::PI_VAL, d.value(), 5.0e-10);
}
TEST_F(UnitConversion, constants) {
@@ -2696,12 +2774,12 @@
(std::is_same<
typename std::decay<angle::radian_t>::type,
typename std::decay<decltype(acos(scalar_t(0)))>::type>::value));
- EXPECT_NEAR(angle::radian_t(2).to<double>(),
- acos(scalar_t(-0.41614683654)).to<double>(), 5.0e-11);
+ EXPECT_NEAR(angle::radian_t(2).value(),
+ acos(scalar_t(-0.41614683654)).value(), 5.0e-11);
EXPECT_NEAR(
- angle::degree_t(135).to<double>(),
+ angle::degree_t(135).value(),
angle::degree_t(acos(scalar_t(-0.70710678118654752440084436210485)))
- .to<double>(),
+ .value(),
5.0e-12);
}
@@ -2710,12 +2788,12 @@
(std::is_same<
typename std::decay<angle::radian_t>::type,
typename std::decay<decltype(asin(scalar_t(0)))>::type>::value));
- EXPECT_NEAR(angle::radian_t(1.14159265).to<double>(),
- asin(scalar_t(0.90929742682)).to<double>(), 5.0e-9);
+ EXPECT_NEAR(angle::radian_t(1.14159265).value(),
+ asin(scalar_t(0.90929742682)).value(), 5.0e-9);
EXPECT_NEAR(
- angle::degree_t(45).to<double>(),
+ angle::degree_t(45).value(),
angle::degree_t(asin(scalar_t(0.70710678118654752440084436210485)))
- .to<double>(),
+ .value(),
5.0e-12);
}
@@ -2724,32 +2802,32 @@
(std::is_same<
typename std::decay<angle::radian_t>::type,
typename std::decay<decltype(atan(scalar_t(0)))>::type>::value));
- EXPECT_NEAR(angle::radian_t(-1.14159265).to<double>(),
- atan(scalar_t(-2.18503986326)).to<double>(), 5.0e-9);
- EXPECT_NEAR(angle::degree_t(-45).to<double>(),
- angle::degree_t(atan(scalar_t(-1.0))).to<double>(), 5.0e-12);
+ EXPECT_NEAR(angle::radian_t(-1.14159265).value(),
+ atan(scalar_t(-2.18503986326)).value(), 5.0e-9);
+ EXPECT_NEAR(angle::degree_t(-45).value(),
+ angle::degree_t(atan(scalar_t(-1.0))).value(), 5.0e-12);
}
TEST_F(UnitMath, atan2) {
EXPECT_TRUE((std::is_same<typename std::decay<angle::radian_t>::type,
typename std::decay<decltype(atan2(
scalar_t(1), scalar_t(1)))>::type>::value));
- EXPECT_NEAR(angle::radian_t(constants::detail::PI_VAL / 4).to<double>(),
- atan2(scalar_t(2), scalar_t(2)).to<double>(), 5.0e-12);
+ EXPECT_NEAR(angle::radian_t(constants::detail::PI_VAL / 4).value(),
+ atan2(scalar_t(2), scalar_t(2)).value(), 5.0e-12);
EXPECT_NEAR(
- angle::degree_t(45).to<double>(),
- angle::degree_t(atan2(scalar_t(2), scalar_t(2))).to<double>(),
+ angle::degree_t(45).value(),
+ angle::degree_t(atan2(scalar_t(2), scalar_t(2))).value(),
5.0e-12);
EXPECT_TRUE((std::is_same<typename std::decay<angle::radian_t>::type,
typename std::decay<decltype(atan2(
scalar_t(1), scalar_t(1)))>::type>::value));
- EXPECT_NEAR(angle::radian_t(constants::detail::PI_VAL / 6).to<double>(),
- atan2(scalar_t(1), scalar_t(std::sqrt(3))).to<double>(),
+ EXPECT_NEAR(angle::radian_t(constants::detail::PI_VAL / 6).value(),
+ atan2(scalar_t(1), scalar_t(std::sqrt(3))).value(),
5.0e-12);
- EXPECT_NEAR(angle::degree_t(30).to<double>(),
+ EXPECT_NEAR(angle::degree_t(30).value(),
angle::degree_t(atan2(scalar_t(1), scalar_t(std::sqrt(3))))
- .to<double>(),
+ .value(),
5.0e-12);
}
@@ -2781,30 +2859,30 @@
EXPECT_TRUE((std::is_same<typename std::decay<angle::radian_t>::type,
typename std::decay<decltype(
acosh(scalar_t(0)))>::type>::value));
- EXPECT_NEAR(angle::radian_t(1.316957896924817).to<double>(),
- acosh(scalar_t(2.0)).to<double>(), 5.0e-11);
- EXPECT_NEAR(angle::degree_t(75.456129290216893).to<double>(),
- angle::degree_t(acosh(scalar_t(2.0))).to<double>(), 5.0e-12);
+ EXPECT_NEAR(angle::radian_t(1.316957896924817).value(),
+ acosh(scalar_t(2.0)).value(), 5.0e-11);
+ EXPECT_NEAR(angle::degree_t(75.456129290216893).value(),
+ angle::degree_t(acosh(scalar_t(2.0))).value(), 5.0e-12);
}
TEST_F(UnitMath, asinh) {
EXPECT_TRUE((std::is_same<typename std::decay<angle::radian_t>::type,
typename std::decay<decltype(
asinh(scalar_t(0)))>::type>::value));
- EXPECT_NEAR(angle::radian_t(1.443635475178810).to<double>(),
- asinh(scalar_t(2)).to<double>(), 5.0e-9);
- EXPECT_NEAR(angle::degree_t(82.714219883108939).to<double>(),
- angle::degree_t(asinh(scalar_t(2))).to<double>(), 5.0e-12);
+ EXPECT_NEAR(angle::radian_t(1.443635475178810).value(),
+ asinh(scalar_t(2)).value(), 5.0e-9);
+ EXPECT_NEAR(angle::degree_t(82.714219883108939).value(),
+ angle::degree_t(asinh(scalar_t(2))).value(), 5.0e-12);
}
TEST_F(UnitMath, atanh) {
EXPECT_TRUE((std::is_same<typename std::decay<angle::radian_t>::type,
typename std::decay<decltype(
atanh(scalar_t(0)))>::type>::value));
- EXPECT_NEAR(angle::radian_t(0.549306144334055).to<double>(),
- atanh(scalar_t(0.5)).to<double>(), 5.0e-9);
- EXPECT_NEAR(angle::degree_t(31.472923730945389).to<double>(),
- angle::degree_t(atanh(scalar_t(0.5))).to<double>(), 5.0e-12);
+ EXPECT_NEAR(angle::radian_t(0.549306144334055).value(),
+ atanh(scalar_t(0.5)).value(), 5.0e-9);
+ EXPECT_NEAR(angle::degree_t(31.472923730945389).value(),
+ angle::degree_t(atanh(scalar_t(0.5))).value(), 5.0e-12);
}
TEST_F(UnitMath, exp) {
@@ -2876,14 +2954,14 @@
EXPECT_TRUE((std::is_same<typename std::decay<meter_t>::type,
typename std::decay<decltype(sqrt(
square_meter_t(4.0)))>::type>::value));
- EXPECT_NEAR(meter_t(2.0).to<double>(),
- sqrt(square_meter_t(4.0)).to<double>(), 5.0e-9);
+ EXPECT_NEAR(meter_t(2.0).value(),
+ sqrt(square_meter_t(4.0)).value(), 5.0e-9);
EXPECT_TRUE((std::is_same<typename std::decay<angle::radian_t>::type,
typename std::decay<decltype(
sqrt(steradian_t(16.0)))>::type>::value));
- EXPECT_NEAR(angle::radian_t(4.0).to<double>(),
- sqrt(steradian_t(16.0)).to<double>(), 5.0e-9);
+ EXPECT_NEAR(angle::radian_t(4.0).value(),
+ sqrt(steradian_t(16.0)).value(), 5.0e-9);
EXPECT_TRUE((std::is_convertible<typename std::decay<foot_t>::type,
typename std::decay<decltype(sqrt(
@@ -2892,9 +2970,9 @@
// for rational conversion (i.e. no integral root) let's check a bunch of
// different ways this could go wrong
foot_t resultFt = sqrt(square_foot_t(10.0));
- EXPECT_NEAR(foot_t(3.16227766017).to<double>(),
- sqrt(square_foot_t(10.0)).to<double>(), 5.0e-9);
- EXPECT_NEAR(foot_t(3.16227766017).to<double>(), resultFt.to<double>(),
+ EXPECT_NEAR(foot_t(3.16227766017).value(),
+ sqrt(square_foot_t(10.0)).value(), 5.0e-9);
+ EXPECT_NEAR(foot_t(3.16227766017).value(), resultFt.value(),
5.0e-9);
EXPECT_EQ(resultFt, sqrt(square_foot_t(10.0)));
}
@@ -2903,19 +2981,19 @@
EXPECT_TRUE((std::is_same<typename std::decay<meter_t>::type,
typename std::decay<decltype(hypot(
meter_t(3.0), meter_t(4.0)))>::type>::value));
- EXPECT_NEAR(meter_t(5.0).to<double>(),
- (hypot(meter_t(3.0), meter_t(4.0))).to<double>(), 5.0e-9);
+ EXPECT_NEAR(meter_t(5.0).value(),
+ (hypot(meter_t(3.0), meter_t(4.0))).value(), 5.0e-9);
EXPECT_TRUE((std::is_same<typename std::decay<foot_t>::type,
typename std::decay<decltype(hypot(
foot_t(3.0), meter_t(1.2192)))>::type>::value));
- EXPECT_NEAR(foot_t(5.0).to<double>(),
- (hypot(foot_t(3.0), meter_t(1.2192))).to<double>(), 5.0e-9);
+ EXPECT_NEAR(foot_t(5.0).value(),
+ (hypot(foot_t(3.0), meter_t(1.2192))).value(), 5.0e-9);
}
TEST_F(UnitMath, ceil) {
double val = 101.1;
- EXPECT_EQ(std::ceil(val), ceil(meter_t(val)).to<double>());
+ EXPECT_EQ(std::ceil(val), ceil(meter_t(val)).value());
EXPECT_TRUE((std::is_same<typename std::decay<meter_t>::type,
typename std::decay<decltype(
ceil(meter_t(val)))>::type>::value));
@@ -2928,7 +3006,7 @@
TEST_F(UnitMath, fmod) {
EXPECT_EQ(std::fmod(100.0, 101.2),
- fmod(meter_t(100.0), meter_t(101.2)).to<double>());
+ fmod(meter_t(100.0), meter_t(101.2)).value());
}
TEST_F(UnitMath, trunc) {
@@ -2951,8 +3029,8 @@
TEST_F(UnitMath, fdim) {
EXPECT_EQ(meter_t(0.0), fdim(meter_t(8.0), meter_t(10.0)));
EXPECT_EQ(meter_t(2.0), fdim(meter_t(10.0), meter_t(8.0)));
- EXPECT_NEAR(meter_t(9.3904).to<double>(),
- fdim(meter_t(10.0), foot_t(2.0)).to<double>(),
+ EXPECT_NEAR(meter_t(9.3904).value(),
+ fdim(meter_t(10.0), foot_t(2.0)).value(),
5.0e-320); // not sure why they aren't comparing exactly equal,
// but clearly they are.
}
@@ -2979,17 +3057,6 @@
EXPECT_EQ(meter_t(10.0), abs(meter_t(10.0)));
}
-TEST_F(UnitMath, normalize) {
- EXPECT_EQ(NormalizeAngle(radian_t(5 * wpi::math::pi)),
- radian_t(wpi::math::pi));
- EXPECT_EQ(NormalizeAngle(radian_t(-5 * wpi::math::pi)),
- radian_t(wpi::math::pi));
- EXPECT_EQ(NormalizeAngle(radian_t(wpi::math::pi / 2)),
- radian_t(wpi::math::pi / 2));
- EXPECT_EQ(NormalizeAngle(radian_t(-wpi::math::pi / 2)),
- radian_t(-wpi::math::pi / 2));
-}
-
// Constexpr
#if !defined(_MSC_VER) || _MSC_VER > 1800
TEST_F(Constexpr, construction) {
@@ -3096,7 +3163,7 @@
}
TEST_F(CompileTimeArithmetic, is_unit_value_t) {
- typedef unit_value_t<meters, 3, 2> mRatio;
+ using mRatio = unit_value_t<meters, 3, 2>;
EXPECT_TRUE((traits::is_unit_value_t<mRatio>::value));
EXPECT_FALSE((traits::is_unit_value_t<meter_t>::value));
@@ -3108,7 +3175,7 @@
}
TEST_F(CompileTimeArithmetic, is_unit_value_t_category) {
- typedef unit_value_t<feet, 3, 2> mRatio;
+ using mRatio = unit_value_t<feet, 3, 2>;
EXPECT_TRUE(
(traits::is_unit_value_t_category<category::length_unit, mRatio>::value));
EXPECT_FALSE(
@@ -3120,90 +3187,90 @@
}
TEST_F(CompileTimeArithmetic, unit_value_add) {
- typedef unit_value_t<meters, 3, 2> mRatio;
+ using mRatio = unit_value_t<meters, 3, 2>;
using sum = unit_value_add<mRatio, mRatio>;
EXPECT_EQ(meter_t(3.0), sum::value());
EXPECT_TRUE(
(traits::is_unit_value_t_category<category::length_unit, sum>::value));
- typedef unit_value_t<feet, 1> ftRatio;
+ using ftRatio = unit_value_t<feet, 1>;
using sumf = unit_value_add<ftRatio, mRatio>;
EXPECT_TRUE((
std::is_same<typename std::decay<foot_t>::type,
typename std::decay<decltype(sumf::value())>::type>::value));
- EXPECT_NEAR(5.92125984, sumf::value().to<double>(), 5.0e-8);
+ EXPECT_NEAR(5.92125984, sumf::value().value(), 5.0e-8);
EXPECT_TRUE(
(traits::is_unit_value_t_category<category::length_unit, sumf>::value));
- typedef unit_value_t<celsius, 1> cRatio;
- typedef unit_value_t<fahrenheit, 2> fRatio;
+ using cRatio = unit_value_t<celsius, 1>;
+ using fRatio = unit_value_t<fahrenheit, 2>;
using sumc = unit_value_add<cRatio, fRatio>;
EXPECT_TRUE((
std::is_same<typename std::decay<celsius_t>::type,
typename std::decay<decltype(sumc::value())>::type>::value));
- EXPECT_NEAR(2.11111111111, sumc::value().to<double>(), 5.0e-8);
+ EXPECT_NEAR(2.11111111111, sumc::value().value(), 5.0e-8);
EXPECT_TRUE((traits::is_unit_value_t_category<category::temperature_unit,
sumc>::value));
- typedef unit_value_t<angle::radian, 1> rRatio;
- typedef unit_value_t<angle::degree, 3> dRatio;
+ using rRatio = unit_value_t<angle::radian, 1>;
+ using dRatio = unit_value_t<angle::degree, 3>;
using sumr = unit_value_add<rRatio, dRatio>;
EXPECT_TRUE((
std::is_same<typename std::decay<angle::radian_t>::type,
typename std::decay<decltype(sumr::value())>::type>::value));
- EXPECT_NEAR(1.05235988, sumr::value().to<double>(), 5.0e-8);
+ EXPECT_NEAR(1.05235988, sumr::value().value(), 5.0e-8);
EXPECT_TRUE(
(traits::is_unit_value_t_category<category::angle_unit, sumr>::value));
}
TEST_F(CompileTimeArithmetic, unit_value_subtract) {
- typedef unit_value_t<meters, 3, 2> mRatio;
+ using mRatio = unit_value_t<meters, 3, 2>;
using diff = unit_value_subtract<mRatio, mRatio>;
EXPECT_EQ(meter_t(0), diff::value());
EXPECT_TRUE(
(traits::is_unit_value_t_category<category::length_unit, diff>::value));
- typedef unit_value_t<feet, 1> ftRatio;
+ using ftRatio = unit_value_t<feet, 1>;
using difff = unit_value_subtract<ftRatio, mRatio>;
EXPECT_TRUE((std::is_same<
typename std::decay<foot_t>::type,
typename std::decay<decltype(difff::value())>::type>::value));
- EXPECT_NEAR(-3.92125984, difff::value().to<double>(), 5.0e-8);
+ EXPECT_NEAR(-3.92125984, difff::value().value(), 5.0e-8);
EXPECT_TRUE(
(traits::is_unit_value_t_category<category::length_unit, difff>::value));
- typedef unit_value_t<celsius, 1> cRatio;
- typedef unit_value_t<fahrenheit, 2> fRatio;
+ using cRatio = unit_value_t<celsius, 1>;
+ using fRatio = unit_value_t<fahrenheit, 2>;
using diffc = unit_value_subtract<cRatio, fRatio>;
EXPECT_TRUE((std::is_same<
typename std::decay<celsius_t>::type,
typename std::decay<decltype(diffc::value())>::type>::value));
- EXPECT_NEAR(-0.11111111111, diffc::value().to<double>(), 5.0e-8);
+ EXPECT_NEAR(-0.11111111111, diffc::value().value(), 5.0e-8);
EXPECT_TRUE((traits::is_unit_value_t_category<category::temperature_unit,
diffc>::value));
- typedef unit_value_t<angle::radian, 1> rRatio;
- typedef unit_value_t<angle::degree, 3> dRatio;
+ using rRatio = unit_value_t<angle::radian, 1>;
+ using dRatio = unit_value_t<angle::degree, 3>;
using diffr = unit_value_subtract<rRatio, dRatio>;
EXPECT_TRUE((std::is_same<
typename std::decay<angle::radian_t>::type,
typename std::decay<decltype(diffr::value())>::type>::value));
- EXPECT_NEAR(0.947640122, diffr::value().to<double>(), 5.0e-8);
+ EXPECT_NEAR(0.947640122, diffr::value().value(), 5.0e-8);
EXPECT_TRUE(
(traits::is_unit_value_t_category<category::angle_unit, diffr>::value));
}
TEST_F(CompileTimeArithmetic, unit_value_multiply) {
- typedef unit_value_t<meters, 2> mRatio;
- typedef unit_value_t<feet, 656168, 100000> ftRatio; // 2 meter
+ using mRatio = unit_value_t<meters, 2>;
+ using ftRatio = unit_value_t<feet, 656168, 100000>; // 2 meter
using product = unit_value_multiply<mRatio, mRatio>;
EXPECT_EQ(square_meter_t(4), product::value());
@@ -3215,7 +3282,7 @@
EXPECT_TRUE((std::is_same<
typename std::decay<square_meter_t>::type,
typename std::decay<decltype(productM::value())>::type>::value));
- EXPECT_NEAR(4.0, productM::value().to<double>(), 5.0e-7);
+ EXPECT_NEAR(4.0, productM::value().value(), 5.0e-7);
EXPECT_TRUE(
(traits::is_unit_value_t_category<category::area_unit, productM>::value));
@@ -3223,7 +3290,7 @@
EXPECT_TRUE((std::is_same<
typename std::decay<square_foot_t>::type,
typename std::decay<decltype(productF::value())>::type>::value));
- EXPECT_NEAR(43.0556444224, productF::value().to<double>(), 5.0e-6);
+ EXPECT_NEAR(43.0556444224, productF::value().value(), 5.0e-6);
EXPECT_TRUE(
(traits::is_unit_value_t_category<category::area_unit, productF>::value));
@@ -3232,11 +3299,11 @@
(std::is_same<
typename std::decay<square_foot_t>::type,
typename std::decay<decltype(productF2::value())>::type>::value));
- EXPECT_NEAR(43.0556444224, productF2::value().to<double>(), 5.0e-8);
+ EXPECT_NEAR(43.0556444224, productF2::value().value(), 5.0e-8);
EXPECT_TRUE((
traits::is_unit_value_t_category<category::area_unit, productF2>::value));
- typedef unit_value_t<units::force::newton, 5> nRatio;
+ using nRatio = unit_value_t<units::force::newton, 5>;
using productN = unit_value_multiply<nRatio, ftRatio>;
EXPECT_FALSE(
@@ -3246,30 +3313,30 @@
EXPECT_TRUE((std::is_convertible<
typename std::decay<torque::newton_meter_t>::type,
typename std::decay<decltype(productN::value())>::type>::value));
- EXPECT_NEAR(32.8084, productN::value().to<double>(), 5.0e-8);
- EXPECT_NEAR(10.0, (productN::value().convert<newton_meter>().to<double>()),
+ EXPECT_NEAR(32.8084, productN::value().value(), 5.0e-8);
+ EXPECT_NEAR(10.0, (productN::value().convert<newton_meter>().value()),
5.0e-7);
EXPECT_TRUE((traits::is_unit_value_t_category<category::torque_unit,
productN>::value));
- typedef unit_value_t<angle::radian, 11, 10> r1Ratio;
- typedef unit_value_t<angle::radian, 22, 10> r2Ratio;
+ using r1Ratio = unit_value_t<angle::radian, 11, 10>;
+ using r2Ratio = unit_value_t<angle::radian, 22, 10>;
using productR = unit_value_multiply<r1Ratio, r2Ratio>;
EXPECT_TRUE((std::is_same<
typename std::decay<steradian_t>::type,
typename std::decay<decltype(productR::value())>::type>::value));
- EXPECT_NEAR(2.42, productR::value().to<double>(), 5.0e-8);
+ EXPECT_NEAR(2.42, productR::value().value(), 5.0e-8);
EXPECT_NEAR(7944.39137,
- (productR::value().convert<degrees_squared>().to<double>()),
+ (productR::value().convert<degrees_squared>().value()),
5.0e-6);
EXPECT_TRUE((traits::is_unit_value_t_category<category::solid_angle_unit,
productR>::value));
}
TEST_F(CompileTimeArithmetic, unit_value_divide) {
- typedef unit_value_t<meters, 2> mRatio;
- typedef unit_value_t<feet, 656168, 100000> ftRatio; // 2 meter
+ using mRatio = unit_value_t<meters, 2>;
+ using ftRatio = unit_value_t<feet, 656168, 100000>; // 2 meter
using product = unit_value_divide<mRatio, mRatio>;
EXPECT_EQ(scalar_t(1), product::value());
@@ -3281,7 +3348,7 @@
EXPECT_TRUE((std::is_same<
typename std::decay<scalar_t>::type,
typename std::decay<decltype(productM::value())>::type>::value));
- EXPECT_NEAR(1, productM::value().to<double>(), 5.0e-7);
+ EXPECT_NEAR(1, productM::value().value(), 5.0e-7);
EXPECT_TRUE((traits::is_unit_value_t_category<category::scalar_unit,
productM>::value));
@@ -3289,7 +3356,7 @@
EXPECT_TRUE((std::is_same<
typename std::decay<scalar_t>::type,
typename std::decay<decltype(productF::value())>::type>::value));
- EXPECT_NEAR(1.0, productF::value().to<double>(), 5.0e-6);
+ EXPECT_NEAR(1.0, productF::value().value(), 5.0e-6);
EXPECT_TRUE((traits::is_unit_value_t_category<category::scalar_unit,
productF>::value));
@@ -3298,90 +3365,90 @@
(std::is_same<
typename std::decay<scalar_t>::type,
typename std::decay<decltype(productF2::value())>::type>::value));
- EXPECT_NEAR(1.0, productF2::value().to<double>(), 5.0e-8);
+ EXPECT_NEAR(1.0, productF2::value().value(), 5.0e-8);
EXPECT_TRUE((traits::is_unit_value_t_category<category::scalar_unit,
productF2>::value));
- typedef unit_value_t<seconds, 10> sRatio;
+ using sRatio = unit_value_t<seconds, 10>;
using productMS = unit_value_divide<mRatio, sRatio>;
EXPECT_TRUE(
(std::is_same<
typename std::decay<meters_per_second_t>::type,
typename std::decay<decltype(productMS::value())>::type>::value));
- EXPECT_NEAR(0.2, productMS::value().to<double>(), 5.0e-8);
+ EXPECT_NEAR(0.2, productMS::value().value(), 5.0e-8);
EXPECT_TRUE((traits::is_unit_value_t_category<category::velocity_unit,
productMS>::value));
- typedef unit_value_t<angle::radian, 20> rRatio;
+ using rRatio = unit_value_t<angle::radian, 20>;
using productRS = unit_value_divide<rRatio, sRatio>;
EXPECT_TRUE(
(std::is_same<
typename std::decay<radians_per_second_t>::type,
typename std::decay<decltype(productRS::value())>::type>::value));
- EXPECT_NEAR(2, productRS::value().to<double>(), 5.0e-8);
+ EXPECT_NEAR(2, productRS::value().value(), 5.0e-8);
EXPECT_NEAR(114.592,
- (productRS::value().convert<degrees_per_second>().to<double>()),
+ (productRS::value().convert<degrees_per_second>().value()),
5.0e-4);
EXPECT_TRUE((traits::is_unit_value_t_category<category::angular_velocity_unit,
productRS>::value));
}
TEST_F(CompileTimeArithmetic, unit_value_power) {
- typedef unit_value_t<meters, 2> mRatio;
+ using mRatio = unit_value_t<meters, 2>;
using sq = unit_value_power<mRatio, 2>;
EXPECT_TRUE((std::is_convertible<
typename std::decay<square_meter_t>::type,
typename std::decay<decltype(sq::value())>::type>::value));
- EXPECT_NEAR(4, sq::value().to<double>(), 5.0e-8);
+ EXPECT_NEAR(4, sq::value().value(), 5.0e-8);
EXPECT_TRUE(
(traits::is_unit_value_t_category<category::area_unit, sq>::value));
- typedef unit_value_t<angle::radian, 18, 10> rRatio;
+ using rRatio = unit_value_t<angle::radian, 18, 10>;
using sqr = unit_value_power<rRatio, 2>;
EXPECT_TRUE((std::is_convertible<
typename std::decay<steradian_t>::type,
typename std::decay<decltype(sqr::value())>::type>::value));
- EXPECT_NEAR(3.24, sqr::value().to<double>(), 5.0e-8);
+ EXPECT_NEAR(3.24, sqr::value().value(), 5.0e-8);
EXPECT_NEAR(10636.292574038049895092690529904,
- (sqr::value().convert<degrees_squared>().to<double>()), 5.0e-10);
+ (sqr::value().convert<degrees_squared>().value()), 5.0e-10);
EXPECT_TRUE((traits::is_unit_value_t_category<category::solid_angle_unit,
sqr>::value));
}
TEST_F(CompileTimeArithmetic, unit_value_sqrt) {
- typedef unit_value_t<square_meters, 10> mRatio;
+ using mRatio = unit_value_t<square_meters, 10>;
using root = unit_value_sqrt<mRatio>;
EXPECT_TRUE((std::is_convertible<
typename std::decay<meter_t>::type,
typename std::decay<decltype(root::value())>::type>::value));
- EXPECT_NEAR(3.16227766017, root::value().to<double>(), 5.0e-9);
+ EXPECT_NEAR(3.16227766017, root::value().value(), 5.0e-9);
EXPECT_TRUE(
(traits::is_unit_value_t_category<category::length_unit, root>::value));
- typedef unit_value_t<hectare, 51, 7> hRatio;
+ using hRatio = unit_value_t<hectare, 51, 7>;
using rooth = unit_value_sqrt<hRatio, 100000000>;
EXPECT_TRUE((std::is_convertible<
typename std::decay<mile_t>::type,
typename std::decay<decltype(rooth::value())>::type>::value));
- EXPECT_NEAR(2.69920623253, rooth::value().to<double>(), 5.0e-8);
- EXPECT_NEAR(269.920623, rooth::value().convert<meters>().to<double>(),
+ EXPECT_NEAR(2.69920623253, rooth::value().value(), 5.0e-8);
+ EXPECT_NEAR(269.920623, rooth::value().convert<meters>().value(),
5.0e-6);
EXPECT_TRUE(
(traits::is_unit_value_t_category<category::length_unit, rooth>::value));
- typedef unit_value_t<steradian, 18, 10> rRatio;
+ using rRatio = unit_value_t<steradian, 18, 10>;
using rootr = unit_value_sqrt<rRatio>;
EXPECT_TRUE((traits::is_angle_unit<decltype(rootr::value())>::value));
- EXPECT_NEAR(1.3416407865, rootr::value().to<double>(), 5.0e-8);
+ EXPECT_NEAR(1.3416407865, rootr::value().value(), 5.0e-8);
EXPECT_NEAR(76.870352574,
- rootr::value().convert<angle::degrees>().to<double>(), 5.0e-6);
+ rootr::value().convert<angle::degrees>().value(), 5.0e-6);
EXPECT_TRUE(
(traits::is_unit_value_t_category<category::angle_unit, rootr>::value));
}
diff --git a/wpimath/src/test/native/cpp/controller/ControlAffinePlantInversionFeedforwardTest.cpp b/wpimath/src/test/native/cpp/controller/ControlAffinePlantInversionFeedforwardTest.cpp
index 79f1a6d..354ed18 100644
--- a/wpimath/src/test/native/cpp/controller/ControlAffinePlantInversionFeedforwardTest.cpp
+++ b/wpimath/src/test/native/cpp/controller/ControlAffinePlantInversionFeedforwardTest.cpp
@@ -1,70 +1,52 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include <gtest/gtest.h>
#include <cmath>
#include "Eigen/Core"
-#include "frc/StateSpaceUtil.h"
#include "frc/controller/ControlAffinePlantInversionFeedforward.h"
#include "units/time.h"
namespace frc {
-Eigen::Matrix<double, 2, 1> Dynamics(const Eigen::Matrix<double, 2, 1>& x,
- const Eigen::Matrix<double, 1, 1>& u) {
- Eigen::Matrix<double, 2, 1> result;
-
- result = (frc::MakeMatrix<2, 2>(1.0, 0.0, 0.0, 1.0) * x) +
- (frc::MakeMatrix<2, 1>(0.0, 1.0) * u);
-
- return result;
+Eigen::Vector<double, 2> Dynamics(const Eigen::Vector<double, 2>& x,
+ const Eigen::Vector<double, 1>& u) {
+ return Eigen::Matrix<double, 2, 2>{{1.0, 0.0}, {0.0, 1.0}} * x +
+ Eigen::Matrix<double, 2, 1>{0.0, 1.0} * u;
}
-Eigen::Matrix<double, 2, 1> StateDynamics(
- const Eigen::Matrix<double, 2, 1>& x) {
- Eigen::Matrix<double, 2, 1> result;
-
- result = (frc::MakeMatrix<2, 2>(1.0, 0.0, 0.0, 1.0) * x);
-
- return result;
+Eigen::Vector<double, 2> StateDynamics(const Eigen::Vector<double, 2>& x) {
+ return Eigen::Matrix<double, 2, 2>{{1.0, 0.0}, {0.0, 1.0}} * x;
}
TEST(ControlAffinePlantInversionFeedforwardTest, Calculate) {
- std::function<Eigen::Matrix<double, 2, 1>(const Eigen::Matrix<double, 2, 1>&,
- const Eigen::Matrix<double, 1, 1>&)>
+ std::function<Eigen::Vector<double, 2>(const Eigen::Vector<double, 2>&,
+ const Eigen::Vector<double, 1>&)>
modelDynamics = [](auto& x, auto& u) { return Dynamics(x, u); };
frc::ControlAffinePlantInversionFeedforward<2, 1> feedforward{
- modelDynamics, units::second_t(0.02)};
+ modelDynamics, units::second_t{0.02}};
- Eigen::Matrix<double, 2, 1> r;
- r << 2, 2;
- Eigen::Matrix<double, 2, 1> nextR;
- nextR << 3, 3;
+ Eigen::Vector<double, 2> r{2, 2};
+ Eigen::Vector<double, 2> nextR{3, 3};
EXPECT_NEAR(48, feedforward.Calculate(r, nextR)(0, 0), 1e-6);
}
TEST(ControlAffinePlantInversionFeedforwardTest, CalculateState) {
- std::function<Eigen::Matrix<double, 2, 1>(const Eigen::Matrix<double, 2, 1>&)>
+ std::function<Eigen::Vector<double, 2>(const Eigen::Vector<double, 2>&)>
modelDynamics = [](auto& x) { return StateDynamics(x); };
- Eigen::Matrix<double, 2, 1> B;
- B << 0, 1;
+ Eigen::Matrix<double, 2, 1> B{0, 1};
frc::ControlAffinePlantInversionFeedforward<2, 1> feedforward{
modelDynamics, B, units::second_t(0.02)};
- Eigen::Matrix<double, 2, 1> r;
- r << 2, 2;
- Eigen::Matrix<double, 2, 1> nextR;
- nextR << 3, 3;
+ Eigen::Vector<double, 2> r{2, 2};
+ Eigen::Vector<double, 2> nextR{3, 3};
EXPECT_NEAR(48, feedforward.Calculate(r, nextR)(0, 0), 1e-6);
}
diff --git a/wpimath/src/test/native/cpp/controller/HolonomicDriveControllerTest.cpp b/wpimath/src/test/native/cpp/controller/HolonomicDriveControllerTest.cpp
new file mode 100644
index 0000000..c6b669c
--- /dev/null
+++ b/wpimath/src/test/native/cpp/controller/HolonomicDriveControllerTest.cpp
@@ -0,0 +1,68 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+#include <wpi/numbers>
+
+#include "frc/MathUtil.h"
+#include "frc/controller/HolonomicDriveController.h"
+#include "frc/trajectory/TrajectoryGenerator.h"
+#include "gtest/gtest.h"
+#include "units/angular_acceleration.h"
+#include "units/math.h"
+#include "units/time.h"
+
+// Expects |val1 - val2| <= eps for unit-typed values. Each argument is
+// parenthesized so an argument like `a - b` is subtracted as a whole.
+#define EXPECT_NEAR_UNITS(val1, val2, eps) \
+  EXPECT_LE(units::math::abs((val1) - (val2)), (eps))
+
+static constexpr units::meter_t kTolerance{1 / 12.0};
+static constexpr units::radian_t kAngularTolerance{2.0 * wpi::numbers::pi /
+ 180.0};
+
+TEST(HolonomicDriveControllerTest, ReachesReference) {
+ frc::HolonomicDriveController controller{
+ frc2::PIDController{1.0, 0.0, 0.0}, frc2::PIDController{1.0, 0.0, 0.0},
+ frc::ProfiledPIDController<units::radian>{
+ 1.0, 0.0, 0.0,
+ frc::TrapezoidProfile<units::radian>::Constraints{
+ units::radians_per_second_t{2.0 * wpi::numbers::pi},
+ units::radians_per_second_squared_t{wpi::numbers::pi}}}};
+
+ frc::Pose2d robotPose{2.7_m, 23_m, frc::Rotation2d{0_deg}};
+
+ auto waypoints = std::vector{frc::Pose2d{2.75_m, 22.521_m, 0_rad},
+ frc::Pose2d{24.73_m, 19.68_m, 5.846_rad}};
+ auto trajectory = frc::TrajectoryGenerator::GenerateTrajectory(
+ waypoints, {8.0_mps, 4.0_mps_sq});
+
+ constexpr auto kDt = 0.02_s;
+ auto totalTime = trajectory.TotalTime();
+ for (size_t i = 0; i < (totalTime / kDt).value(); ++i) {
+ auto state = trajectory.Sample(kDt * i);
+ auto [vx, vy, omega] = controller.Calculate(robotPose, state, 0_rad);
+
+ robotPose = robotPose.Exp(frc::Twist2d{vx * kDt, vy * kDt, omega * kDt});
+ }
+
+ auto& endPose = trajectory.States().back().pose;
+ EXPECT_NEAR_UNITS(endPose.X(), robotPose.X(), kTolerance);
+ EXPECT_NEAR_UNITS(endPose.Y(), robotPose.Y(), kTolerance);
+ EXPECT_NEAR_UNITS(frc::AngleModulus(robotPose.Rotation().Radians()), 0_rad,
+ kAngularTolerance);
+}
+
+TEST(HolonomicDriveControllerTest, DoesNotRotateUnnecessarily) {
+ frc::HolonomicDriveController controller{
+ frc2::PIDController{1, 0, 0}, frc2::PIDController{1, 0, 0},
+ frc::ProfiledPIDController<units::radian>{
+ 1, 0, 0,
+ frc::TrapezoidProfile<units::radian>::Constraints{
+ 4_rad_per_s, 2_rad_per_s / 1_s}}};
+
+ frc::ChassisSpeeds speeds = controller.Calculate(
+ frc::Pose2d(0_m, 0_m, 1.57_rad), frc::Pose2d(), 0_mps, 1.57_rad);
+
+ EXPECT_EQ(0, speeds.omega.value());
+}
diff --git a/wpimath/src/test/native/cpp/controller/LinearPlantInversionFeedforwardTest.cpp b/wpimath/src/test/native/cpp/controller/LinearPlantInversionFeedforwardTest.cpp
index abc22d9..6e61706 100644
--- a/wpimath/src/test/native/cpp/controller/LinearPlantInversionFeedforwardTest.cpp
+++ b/wpimath/src/test/native/cpp/controller/LinearPlantInversionFeedforwardTest.cpp
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include <gtest/gtest.h>
@@ -16,19 +13,14 @@
namespace frc {
TEST(LinearPlantInversionFeedforwardTest, Calculate) {
- Eigen::Matrix<double, 2, 2> A;
- A << 1, 0, 0, 1;
-
- Eigen::Matrix<double, 2, 1> B;
- B << 0, 1;
+ Eigen::Matrix<double, 2, 2> A{{1, 0}, {0, 1}};
+ Eigen::Matrix<double, 2, 1> B{0, 1};
frc::LinearPlantInversionFeedforward<2, 1> feedforward{A, B,
units::second_t(0.02)};
- Eigen::Matrix<double, 2, 1> r;
- r << 2, 2;
- Eigen::Matrix<double, 2, 1> nextR;
- nextR << 3, 3;
+ Eigen::Vector<double, 2> r{2, 2};
+ Eigen::Vector<double, 2> nextR{3, 3};
EXPECT_NEAR(47.502599, feedforward.Calculate(r, nextR)(0, 0), 0.002);
}
diff --git a/wpimath/src/test/native/cpp/controller/LinearQuadraticRegulatorTest.cpp b/wpimath/src/test/native/cpp/controller/LinearQuadraticRegulatorTest.cpp
index eecaf34..8c52cd0 100644
--- a/wpimath/src/test/native/cpp/controller/LinearQuadraticRegulatorTest.cpp
+++ b/wpimath/src/test/native/cpp/controller/LinearQuadraticRegulatorTest.cpp
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include <gtest/gtest.h>
@@ -85,4 +82,27 @@
EXPECT_NEAR(0.69, controller.K(0, 1), 1e-1);
}
+TEST(LinearQuadraticRegulatorTest, LatencyCompensate) {
+ LinearSystem<2, 1, 1> plant = [] {
+ auto motors = DCMotor::Vex775Pro(4);
+
+ // Carriage mass
+ constexpr auto m = 8_kg;
+
+ // Radius of pulley
+ constexpr auto r = 0.75_in;
+
+ // Gear ratio
+ constexpr double G = 14.67;
+
+ return frc::LinearSystemId::ElevatorSystem(motors, m, r, G);
+ }();
+ LinearQuadraticRegulator<2, 1> controller{plant, {0.1, 0.2}, {12.0}, 0.02_s};
+
+ controller.LatencyCompensate(plant, 0.02_s, 0.01_s);
+
+ EXPECT_NEAR(8.97115941, controller.K(0, 0), 1e-3);
+ EXPECT_NEAR(0.07904881, controller.K(0, 1), 1e-3);
+}
+
} // namespace frc
diff --git a/wpimath/src/test/native/cpp/controller/PIDInputOutputTest.cpp b/wpimath/src/test/native/cpp/controller/PIDInputOutputTest.cpp
new file mode 100644
index 0000000..379db9e
--- /dev/null
+++ b/wpimath/src/test/native/cpp/controller/PIDInputOutputTest.cpp
@@ -0,0 +1,51 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+#include "frc/controller/PIDController.h"
+#include "gtest/gtest.h"
+#include "units/time.h"
+
+// gtest constructs a new fixture object for every TEST_F, so holding the
+// controller by value gives each test its own PIDController{0, 0, 0}
+// without manual new/delete.
+class PIDInputOutputTest : public testing::Test {
+ protected:
+  frc2::PIDController controller{0, 0, 0};
+};
+
+TEST_F(PIDInputOutputTest, ContinuousInput) {
+  controller.SetP(1);
+  controller.EnableContinuousInput(-180, 180);
+  EXPECT_DOUBLE_EQ(controller.Calculate(-179, 179), -2);
+
+  controller.EnableContinuousInput(0, 360);
+  EXPECT_DOUBLE_EQ(controller.Calculate(1, 359), -2);
+}
+
+TEST_F(PIDInputOutputTest, ProportionalGainOutput) {
+  controller.SetP(4);
+
+  EXPECT_DOUBLE_EQ(-0.1, controller.Calculate(0.025, 0));
+}
+
+TEST_F(PIDInputOutputTest, IntegralGainOutput) {
+  controller.SetI(4);
+
+  double out = 0;
+
+  for (int i = 0; i < 5; i++) {
+    out = controller.Calculate(0.025, 0);
+  }
+
+  EXPECT_DOUBLE_EQ(-0.5 * controller.GetPeriod().value(), out);
+}
+
+TEST_F(PIDInputOutputTest, DerivativeGainOutput) {
+  controller.SetD(4);
+
+  controller.Calculate(0, 0);
+
+  EXPECT_DOUBLE_EQ(-10_ms / controller.GetPeriod(),
+                   controller.Calculate(0.0025, 0));
+}
diff --git a/wpimath/src/test/native/cpp/controller/PIDToleranceTest.cpp b/wpimath/src/test/native/cpp/controller/PIDToleranceTest.cpp
new file mode 100644
index 0000000..0aec438
--- /dev/null
+++ b/wpimath/src/test/native/cpp/controller/PIDToleranceTest.cpp
@@ -0,0 +1,51 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+#include "frc/controller/PIDController.h"
+#include "gtest/gtest.h"
+
+static constexpr double kSetpoint = 50.0;
+static constexpr double kRange = 200;
+static constexpr double kTolerance = 10.0;
+
+TEST(PIDToleranceTest, InitialTolerance) {
+ frc2::PIDController controller{0.5, 0.0, 0.0};
+ controller.EnableContinuousInput(-kRange / 2, kRange / 2);
+
+ EXPECT_TRUE(controller.AtSetpoint());
+}
+
+TEST(PIDToleranceTest, AbsoluteTolerance) {
+ frc2::PIDController controller{0.5, 0.0, 0.0};
+ controller.EnableContinuousInput(-kRange / 2, kRange / 2);
+
+ EXPECT_TRUE(controller.AtSetpoint())
+ << "Error was not in tolerance when it should have been. Error was "
+ << controller.GetPositionError();
+
+ controller.SetTolerance(kTolerance);
+ controller.SetSetpoint(kSetpoint);
+
+ EXPECT_FALSE(controller.AtSetpoint())
+ << "Error was in tolerance when it should not have been. Error was "
+ << controller.GetPositionError();
+
+ controller.Calculate(0.0);
+
+ EXPECT_FALSE(controller.AtSetpoint())
+ << "Error was in tolerance when it should not have been. Error was "
+ << controller.GetPositionError();
+
+ controller.Calculate(kSetpoint + kTolerance / 2);
+
+ EXPECT_TRUE(controller.AtSetpoint())
+ << "Error was not in tolerance when it should have been. Error was "
+ << controller.GetPositionError();
+
+ controller.Calculate(kSetpoint + 10 * kTolerance);
+
+ EXPECT_FALSE(controller.AtSetpoint())
+ << "Error was in tolerance when it should not have been. Error was "
+ << controller.GetPositionError();
+}
diff --git a/wpimath/src/test/native/cpp/controller/ProfiledPIDInputOutputTest.cpp b/wpimath/src/test/native/cpp/controller/ProfiledPIDInputOutputTest.cpp
new file mode 100644
index 0000000..da402ae
--- /dev/null
+++ b/wpimath/src/test/native/cpp/controller/ProfiledPIDInputOutputTest.cpp
@@ -0,0 +1,115 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+#include <wpi/numbers>
+
+#include "frc/controller/ProfiledPIDController.h"
+#include "gtest/gtest.h"
+#include "units/angle.h"
+#include "units/angular_acceleration.h"
+#include "units/angular_velocity.h"
+#include "units/math.h"
+
+// gtest constructs a new fixture object for every TEST_F, so holding the
+// controller by value gives each test its own freshly constructed instance
+// without manual new/delete.
+class ProfiledPIDInputOutputTest : public testing::Test {
+ protected:
+  frc::ProfiledPIDController<units::degrees> controller{
+      0, 0, 0, {360_deg_per_s, 180_deg_per_s_sq}};
+};
+
+TEST_F(ProfiledPIDInputOutputTest, ContinuousInput1) {
+  controller.SetP(1);
+  controller.EnableContinuousInput(-180_deg, 180_deg);
+
+  static constexpr units::degree_t kSetpoint{-179.0};
+  static constexpr units::degree_t kMeasurement{-179.0};
+  static constexpr units::degree_t kGoal{179.0};
+
+  controller.Reset(kSetpoint);
+  EXPECT_LT(controller.Calculate(kMeasurement, kGoal), 0.0);
+
+  // Error must be less than half the input range at all times
+  EXPECT_LT(units::math::abs(controller.GetSetpoint().position - kMeasurement),
+            180_deg);
+}
+
+TEST_F(ProfiledPIDInputOutputTest, ContinuousInput2) {
+  controller.SetP(1);
+  controller.EnableContinuousInput(-units::radian_t{wpi::numbers::pi},
+                                   units::radian_t{wpi::numbers::pi});
+
+  static constexpr units::radian_t kSetpoint{-3.4826633343199735};
+  static constexpr units::radian_t kMeasurement{-3.1352207333939606};
+  static constexpr units::radian_t kGoal{-3.534162788601621};
+
+  controller.Reset(kSetpoint);
+  EXPECT_LT(controller.Calculate(kMeasurement, kGoal), 0.0);
+
+  // Error must be less than half the input range at all times
+  EXPECT_LT(units::math::abs(controller.GetSetpoint().position - kMeasurement),
+            units::radian_t{wpi::numbers::pi});
+}
+
+TEST_F(ProfiledPIDInputOutputTest, ContinuousInput3) {
+  controller.SetP(1);
+  controller.EnableContinuousInput(-units::radian_t{wpi::numbers::pi},
+                                   units::radian_t{wpi::numbers::pi});
+
+  static constexpr units::radian_t kSetpoint{-3.5176604690006377};
+  static constexpr units::radian_t kMeasurement{3.1191729343822456};
+  static constexpr units::radian_t kGoal{2.709680418117445};
+
+  controller.Reset(kSetpoint);
+  EXPECT_LT(controller.Calculate(kMeasurement, kGoal), 0.0);
+
+  // Error must be less than half the input range at all times
+  EXPECT_LT(units::math::abs(controller.GetSetpoint().position - kMeasurement),
+            units::radian_t{wpi::numbers::pi});
+}
+
+TEST_F(ProfiledPIDInputOutputTest, ContinuousInput4) {
+  controller.SetP(1);
+  controller.EnableContinuousInput(0_rad,
+                                   units::radian_t{2.0 * wpi::numbers::pi});
+
+  static constexpr units::radian_t kSetpoint{2.78};
+  static constexpr units::radian_t kMeasurement{3.12};
+  static constexpr units::radian_t kGoal{2.71};
+
+  controller.Reset(kSetpoint);
+  EXPECT_LT(controller.Calculate(kMeasurement, kGoal), 0.0);
+
+  // Error must be less than half the input range at all times
+  EXPECT_LT(units::math::abs(controller.GetSetpoint().position - kMeasurement),
+            units::radian_t{wpi::numbers::pi});
+}
+
+TEST_F(ProfiledPIDInputOutputTest, ProportionalGainOutput) {
+  controller.SetP(4);
+
+  EXPECT_DOUBLE_EQ(-0.1, controller.Calculate(0.025_deg, 0_deg));
+}
+
+TEST_F(ProfiledPIDInputOutputTest, IntegralGainOutput) {
+  controller.SetI(4);
+
+  double out = 0;
+
+  for (int i = 0; i < 5; i++) {
+    out = controller.Calculate(0.025_deg, 0_deg);
+  }
+
+  EXPECT_DOUBLE_EQ(-0.5 * controller.GetPeriod().value(), out);
+}
+
+TEST_F(ProfiledPIDInputOutputTest, DerivativeGainOutput) {
+  controller.SetD(4);
+
+  controller.Calculate(0_deg, 0_deg);
+
+  EXPECT_DOUBLE_EQ(-10_ms / controller.GetPeriod(),
+                   controller.Calculate(0.0025_deg, 0_deg));
+}
diff --git a/wpimath/src/test/native/cpp/controller/RamseteControllerTest.cpp b/wpimath/src/test/native/cpp/controller/RamseteControllerTest.cpp
new file mode 100644
index 0000000..5e297f4
--- /dev/null
+++ b/wpimath/src/test/native/cpp/controller/RamseteControllerTest.cpp
@@ -0,0 +1,47 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+#include <wpi/numbers>
+
+#include "frc/MathUtil.h"
+#include "frc/controller/RamseteController.h"
+#include "frc/trajectory/TrajectoryGenerator.h"
+#include "gtest/gtest.h"
+#include "units/math.h"
+
+// Expects |val1 - val2| <= eps for unit-typed values. Each argument is
+// parenthesized so an argument like `a - b` is subtracted as a whole.
+#define EXPECT_NEAR_UNITS(val1, val2, eps) \
+  EXPECT_LE(units::math::abs((val1) - (val2)), (eps))
+
+static constexpr units::meter_t kTolerance{1 / 12.0};
+static constexpr units::radian_t kAngularTolerance{2.0 * wpi::numbers::pi /
+ 180.0};
+
+TEST(RamseteControllerTest, ReachesReference) {
+ frc::RamseteController controller{2.0, 0.7};
+ frc::Pose2d robotPose{2.7_m, 23_m, frc::Rotation2d{0_deg}};
+
+ auto waypoints = std::vector{frc::Pose2d{2.75_m, 22.521_m, 0_rad},
+ frc::Pose2d{24.73_m, 19.68_m, 5.846_rad}};
+ auto trajectory = frc::TrajectoryGenerator::GenerateTrajectory(
+ waypoints, {8.8_mps, 0.1_mps_sq});
+
+ constexpr auto kDt = 0.02_s;
+ auto totalTime = trajectory.TotalTime();
+ for (size_t i = 0; i < (totalTime / kDt).value(); ++i) {
+ auto state = trajectory.Sample(kDt * i);
+ auto [vx, vy, omega] = controller.Calculate(robotPose, state);
+ static_cast<void>(vy);
+
+ robotPose = robotPose.Exp(frc::Twist2d{vx * kDt, 0_m, omega * kDt});
+ }
+
+ auto& endPose = trajectory.States().back().pose;
+ EXPECT_NEAR_UNITS(endPose.X(), robotPose.X(), kTolerance);
+ EXPECT_NEAR_UNITS(endPose.Y(), robotPose.Y(), kTolerance);
+ EXPECT_NEAR_UNITS(frc::AngleModulus(endPose.Rotation().Radians() -
+ robotPose.Rotation().Radians()),
+ 0_rad, kAngularTolerance);
+}
diff --git a/wpimath/src/test/native/cpp/controller/SimpleMotorFeedforwardTest.cpp b/wpimath/src/test/native/cpp/controller/SimpleMotorFeedforwardTest.cpp
new file mode 100644
index 0000000..3cf944e
--- /dev/null
+++ b/wpimath/src/test/native/cpp/controller/SimpleMotorFeedforwardTest.cpp
@@ -0,0 +1,46 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+#include <gtest/gtest.h>
+
+#include <cmath>
+
+#include "Eigen/Core"
+#include "frc/controller/LinearPlantInversionFeedforward.h"
+#include "frc/controller/SimpleMotorFeedforward.h"
+#include "units/acceleration.h"
+#include "units/length.h"
+#include "units/time.h"
+
+namespace frc {
+
+TEST(SimpleMotorFeedforwardTest, Calculate) {
+ double Ks = 0.5;
+ double Kv = 3.0;
+ double Ka = 0.6;
+ auto dt = 0.02_s;
+
+ Eigen::Matrix<double, 1, 1> A{-Kv / Ka};
+ Eigen::Matrix<double, 1, 1> B{1.0 / Ka};
+
+ frc::LinearPlantInversionFeedforward<1, 1> plantInversion{A, B, dt};
+ frc::SimpleMotorFeedforward<units::meter> simpleMotor{
+ units::volt_t{Ks}, units::volt_t{Kv} / 1_mps,
+ units::volt_t{Ka} / 1_mps_sq};
+
+ Eigen::Vector<double, 1> r{2.0};
+ Eigen::Vector<double, 1> nextR{3.0};
+
+ EXPECT_NEAR(37.524995834325161 + Ks,
+ simpleMotor.Calculate(2_mps, 3_mps, dt).value(), 0.002);
+ EXPECT_NEAR(plantInversion.Calculate(r, nextR)(0) + Ks,
+ simpleMotor.Calculate(2_mps, 3_mps, dt).value(), 0.002);
+
+ // These won't match exactly. It's just an approximation to make sure they're
+ // in the same ballpark.
+ EXPECT_NEAR(plantInversion.Calculate(r, nextR)(0) + Ks,
+ simpleMotor.Calculate(2_mps, 1_mps / dt).value(), 2.0);
+}
+
+} // namespace frc
diff --git a/wpimath/src/test/native/cpp/drake/discrete_algebraic_riccati_equation_test.cpp b/wpimath/src/test/native/cpp/drake/discrete_algebraic_riccati_equation_test.cpp
new file mode 100644
index 0000000..8631d6f
--- /dev/null
+++ b/wpimath/src/test/native/cpp/drake/discrete_algebraic_riccati_equation_test.cpp
@@ -0,0 +1,124 @@
+#include "drake/math/discrete_algebraic_riccati_equation.h"
+
+#include <Eigen/Eigenvalues>
+#include <gtest/gtest.h>
+
+#include "drake/common/test_utilities/eigen_matrix_compare.h"
+// #include "drake/math/autodiff.h"
+
+using Eigen::MatrixXd;
+
+namespace drake {
+namespace math {
+namespace {
+void SolveDAREandVerify(const Eigen::Ref<const MatrixXd>& A,
+ const Eigen::Ref<const MatrixXd>& B,
+ const Eigen::Ref<const MatrixXd>& Q,
+ const Eigen::Ref<const MatrixXd>& R) {
+ MatrixXd X = DiscreteAlgebraicRiccatiEquation(A, B, Q, R);
+ // Check that X is positive semi-definite.
+ EXPECT_TRUE(
+ CompareMatrices(X, X.transpose(), 1E-10, MatrixCompareType::absolute));
+ int n = X.rows();
+ Eigen::SelfAdjointEigenSolver<MatrixXd> es(X);
+ for (int i = 0; i < n; i++) {
+ EXPECT_GE(es.eigenvalues()[i], 0);
+ }
+ // Check that X is the solution to the discrete time ARE.
+ // clang-format off
+ MatrixXd Y =
+ A.transpose() * X * A
+ - X
+ - (A.transpose() * X * B * (B.transpose() * X * B + R).inverse()
+ * B.transpose() * X * A)
+ + Q;
+ // clang-format on
+ EXPECT_TRUE(CompareMatrices(Y, MatrixXd::Zero(n, n), 1E-10,
+ MatrixCompareType::absolute));
+}
+
+void SolveDAREandVerify(const Eigen::Ref<const MatrixXd>& A,
+ const Eigen::Ref<const MatrixXd>& B,
+ const Eigen::Ref<const MatrixXd>& Q,
+ const Eigen::Ref<const MatrixXd>& R,
+ const Eigen::Ref<const MatrixXd>& N) {
+ MatrixXd X = DiscreteAlgebraicRiccatiEquation(A, B, Q, R, N);
+ // Check that X is positive semi-definite.
+ EXPECT_TRUE(
+ CompareMatrices(X, X.transpose(), 1E-10, MatrixCompareType::absolute));
+ int n = X.rows();
+ Eigen::SelfAdjointEigenSolver<MatrixXd> es(X);
+ for (int i = 0; i < n; i++) {
+ EXPECT_GE(es.eigenvalues()[i], 0);
+ }
+ // Check that X is the solution to the discrete time ARE.
+ // clang-format off
+ MatrixXd Y =
+ A.transpose() * X * A
+ - X
+ - ((A.transpose() * X * B + N) * (B.transpose() * X * B + R).inverse()
+ * (B.transpose() * X * A + N.transpose()))
+ + Q;
+ // clang-format on
+ EXPECT_TRUE(CompareMatrices(Y, MatrixXd::Zero(n, n), 1E-10,
+ MatrixCompareType::absolute));
+}
+
+GTEST_TEST(DARE, SolveDAREandVerify) {
+ // Test 1: non-invertible A
+ // Example 2 of "On the Numerical Solution of the Discrete-Time Algebraic
+ // Riccati Equation"
+ int n1 = 4, m1 = 1;
+ MatrixXd A1(n1, n1), B1(n1, m1), Q1(n1, n1), R1(m1, m1);
+ A1 << 0.5, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0;
+ B1 << 0, 0, 0, 1;
+ Q1 << 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0;
+ R1 << 0.25;
+ SolveDAREandVerify(A1, B1, Q1, R1);
+
+ MatrixXd Aref1(n1, n1);
+ Aref1 << 0.25, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0;
+ SolveDAREandVerify(A1, B1, (A1 - Aref1).transpose() * Q1 * (A1 - Aref1),
+ B1.transpose() * Q1 * B1 + R1, (A1 - Aref1).transpose() * Q1 * B1);
+
+ // Test 2: invertible A
+ int n2 = 2, m2 = 1;
+ MatrixXd A2(n2, n2), B2(n2, m2), Q2(n2, n2), R2(m2, m2);
+ A2 << 1, 1, 0, 1;
+ B2 << 0, 1;
+ Q2 << 1, 0, 0, 0;
+ R2 << 0.3;
+ SolveDAREandVerify(A2, B2, Q2, R2);
+
+ MatrixXd Aref2(n2, n2);
+ Aref2 << 0.5, 1, 0, 1;
+ SolveDAREandVerify(A2, B2, (A2 - Aref2).transpose() * Q2 * (A2 - Aref2),
+ B2.transpose() * Q2 * B2 + R2, (A2 - Aref2).transpose() * Q2 * B2);
+
+ // Test 3: the first generalized eigenvalue of (S,T) is stable
+ int n3 = 2, m3 = 1;
+ MatrixXd A3(n3, n3), B3(n3, m3), Q3(n3, n3), R3(m3, m3);
+ A3 << 0, 1, 0, 0;
+ B3 << 0, 1;
+ Q3 << 1, 0, 0, 1;
+ R3 << 1;
+ SolveDAREandVerify(A3, B3, Q3, R3);
+
+ MatrixXd Aref3(n3, n3);
+ Aref3 << 0, 0.5, 0, 0;
+ SolveDAREandVerify(A3, B3, (A3 - Aref3).transpose() * Q3 * (A3 - Aref3),
+ B3.transpose() * Q3 * B3 + R3, (A3 - Aref3).transpose() * Q3 * B3);
+
+ // Test 4: A = B = Q = R = I_2 (2-by-2 identity matrix)
+ const Eigen::MatrixXd A4{Eigen::Matrix2d::Identity()};
+ const Eigen::MatrixXd B4{Eigen::Matrix2d::Identity()};
+ const Eigen::MatrixXd Q4{Eigen::Matrix2d::Identity()};
+ const Eigen::MatrixXd R4{Eigen::Matrix2d::Identity()};
+ SolveDAREandVerify(A4, B4, Q4, R4);
+
+ const Eigen::MatrixXd N4{Eigen::Matrix2d::Identity()};
+ SolveDAREandVerify(A4, B4, Q4, R4, N4);
+}
+} // namespace
+} // namespace math
+} // namespace drake
diff --git a/wpimath/src/test/native/cpp/drake/math/discrete_algebraic_riccati_equation_test.cpp b/wpimath/src/test/native/cpp/drake/math/discrete_algebraic_riccati_equation_test.cpp
deleted file mode 100644
index edcb772..0000000
--- a/wpimath/src/test/native/cpp/drake/math/discrete_algebraic_riccati_equation_test.cpp
+++ /dev/null
@@ -1,76 +0,0 @@
-#include "drake/math/discrete_algebraic_riccati_equation.h"
-
-#include <Eigen/Core>
-#include <Eigen/Eigenvalues>
-
-#include <gtest/gtest.h>
-
-#include "drake/common/test_utilities/eigen_matrix_compare.h"
-#include "drake/math/autodiff.h"
-
-using Eigen::MatrixXd;
-
-namespace drake {
-namespace math {
-namespace {
-void SolveDAREandVerify(const Eigen::Ref<const MatrixXd>& A,
- const Eigen::Ref<const MatrixXd>& B,
- const Eigen::Ref<const MatrixXd>& Q,
- const Eigen::Ref<const MatrixXd>& R) {
- MatrixXd X = DiscreteAlgebraicRiccatiEquation(A, B, Q, R);
- // Check that X is positive semi-definite.
- EXPECT_TRUE(
- CompareMatrices(X, X.transpose(), 1E-10, MatrixCompareType::absolute));
- int n = X.rows();
- Eigen::SelfAdjointEigenSolver<MatrixXd> es(X);
- for (int i = 0; i < n; i++) {
- EXPECT_GE(es.eigenvalues()[i], 0);
- }
- // Check that X is the solution to the discrete time ARE.
- MatrixXd Y = A.transpose() * X * A - X -
- A.transpose() * X * B * (B.transpose() * X * B + R).inverse() *
- B.transpose() * X * A +
- Q;
- EXPECT_TRUE(CompareMatrices(Y, MatrixXd::Zero(n, n), 1E-10,
- MatrixCompareType::absolute));
-}
-
-GTEST_TEST(DARE, SolveDAREandVerify) {
- // Test 1: non-invertible A
- // Example 2 of "On the Numerical Solution of the Discrete-Time Algebraic
- // Riccati Equation"
- int n1 = 4, m1 = 1;
- MatrixXd A1(n1, n1), B1(n1, m1), Q1(n1, n1), R1(m1, m1);
- A1 << 0.5, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0;
- B1 << 0, 0, 0, 1;
- Q1 << 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0;
- R1 << 0.25;
- SolveDAREandVerify(A1, B1, Q1, R1);
- // Test 2: invertible A
- int n2 = 2, m2 = 1;
- MatrixXd A2(n2, n2), B2(n2, m2), Q2(n2, n2), R2(m2, m2);
- A2 << 1, 1, 0, 1;
- B2 << 0, 1;
- Q2 << 1, 0, 0, 0;
- R2 << 0.3;
- SolveDAREandVerify(A2, B2, Q2, R2);
- // Test 3: the first generalized eigenvalue of (S,T) is stable
- int n3 = 2, m3 = 1;
- MatrixXd A3(n3, n3), B3(n3, m3), Q3(n3, n3), R3(m3, m3);
- A3 << 0, 1, 0, 0;
- B3 << 0, 1;
- Q3 << 1, 0, 0, 1;
- R3 << 1;
- SolveDAREandVerify(A3, B3, Q3, R3);
- // Test 4: A = B = Q = R = I_2 (2-by-2 identity matrix)
- int n4 = 2, m4 = 2;
- MatrixXd A4(n4, n4), B4(n4, m4), Q4(n4, n4), R4(m4, m4);
- A4 << 1, 0, 0, 1;
- B4 << 1, 0, 0, 1;
- Q4 << 1, 0, 0, 1;
- R4 << 1, 0, 0, 1;
- SolveDAREandVerify(A4, B4, Q4, R4);
-}
-} // namespace
-} // namespace math
-} // namespace drake
diff --git a/wpimath/src/test/native/cpp/estimator/AngleStatisticsTest.cpp b/wpimath/src/test/native/cpp/estimator/AngleStatisticsTest.cpp
new file mode 100644
index 0000000..ee1da7f
--- /dev/null
+++ b/wpimath/src/test/native/cpp/estimator/AngleStatisticsTest.cpp
@@ -0,0 +1,53 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+#include <gtest/gtest.h>
+
+#include <wpi/numbers>
+
+#include "Eigen/Core"
+#include "frc/estimator/AngleStatistics.h"
+
+namespace {
+
+// Converts an angle in degrees to radians.
+constexpr double DegToRad(double degrees) {
+  return degrees * wpi::numbers::pi / 180;
+}
+
+}  // namespace
+
+// Row 1 of the sigma-point matrix holds angles on both sides of the 0/2pi
+// wrap (359, 3 and 0 degrees), so its expected mean is near 0 rad; rows 0 and
+// 2 are expected to average as ordinary values.
+TEST(AngleStatisticsTest, Mean) {
+  const Eigen::Matrix<double, 3, 3> sigmas{
+      {1, 1.2, 0}, {DegToRad(359), DegToRad(3), 0}, {1, 2, 0}};
+
+  // Equal weights, so the result is the per-row mean of the sigma points.
+  const Eigen::Vector3d weights =
+      Eigen::Vector3d::Constant(1.0 / sigmas.cols());
+  const Eigen::Vector3d expected{0.7333333, 0.01163323, 1};
+
+  EXPECT_TRUE(
+      expected.isApprox(frc::AngleMean<3, 1>(sigmas, weights, 1), 1e-3));
+}
+
+TEST(AngleStatisticsTest, Residual) {
+  const Eigen::Vector3d a{1, DegToRad(1), 2};
+  const Eigen::Vector3d b{1, DegToRad(359), 1};
+  const Eigen::Vector3d expected{0, DegToRad(2), 1};
+
+  // 1 deg - 359 deg is expected to wrap to +2 deg.
+  EXPECT_TRUE(frc::AngleResidual<3>(a, b, 1).isApprox(expected));
+}
+
+TEST(AngleStatisticsTest, Add) {
+  const Eigen::Vector3d a{1, DegToRad(1), 2};
+  const Eigen::Vector3d b{1, DegToRad(359), 1};
+  const Eigen::Vector3d expected{2, 0, 3};
+
+  // 1 deg + 359 deg is expected to wrap to 0 rad.
+  EXPECT_TRUE(frc::AngleAdd<3>(a, b, 1).isApprox(expected));
+}
diff --git a/wpimath/src/test/native/cpp/estimator/DifferentialDrivePoseEstimatorTest.cpp b/wpimath/src/test/native/cpp/estimator/DifferentialDrivePoseEstimatorTest.cpp
new file mode 100644
index 0000000..4a854fd
--- /dev/null
+++ b/wpimath/src/test/native/cpp/estimator/DifferentialDrivePoseEstimatorTest.cpp
@@ -0,0 +1,110 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+#include <limits>
+#include <random>
+#include <vector>
+
+#include "frc/StateSpaceUtil.h"
+#include "frc/estimator/DifferentialDrivePoseEstimator.h"
+#include "frc/geometry/Pose2d.h"
+#include "frc/geometry/Rotation2d.h"
+#include "frc/kinematics/DifferentialDriveKinematics.h"
+#include "frc/kinematics/DifferentialDriveOdometry.h"
+#include "frc/trajectory/TrajectoryGenerator.h"
+#include "gtest/gtest.h"
+#include "units/angle.h"
+#include "units/length.h"
+#include "units/time.h"
+
+// Simulates a differential drive following a trajectory, feeding the
+// estimator noisy gyro headings, noisy integrated wheel distances, and noisy
+// vision poses delivered one update period late, then checks the mean and
+// maximum translation error against the ground-truth trajectory.
+TEST(DifferentialDrivePoseEstimatorTest, Accuracy) {
+  frc::DifferentialDrivePoseEstimator estimator{frc::Rotation2d(),
+                                                frc::Pose2d(),
+                                                {0.02, 0.02, 0.01, 0.02, 0.02},
+                                                {0.01, 0.01, 0.001},
+                                                {0.1, 0.1, 0.01}};
+
+  frc::Trajectory trajectory = frc::TrajectoryGenerator::GenerateTrajectory(
+      std::vector{frc::Pose2d(0_m, 0_m, frc::Rotation2d(45_deg)),
+                  frc::Pose2d(3_m, 0_m, frc::Rotation2d(-90_deg)),
+                  frc::Pose2d(0_m, 0_m, frc::Rotation2d(135_deg)),
+                  frc::Pose2d(-3_m, 0_m, frc::Rotation2d(-90_deg)),
+                  frc::Pose2d(0_m, 0_m, frc::Rotation2d(45_deg))},
+      frc::TrajectoryConfig(10_mps, 5.0_mps_sq));
+
+  frc::DifferentialDriveKinematics kinematics{1.0_m};
+
+  std::default_random_engine generator;
+  std::normal_distribution<double> distribution(0.0, 1.0);
+
+  units::second_t dt = 0.02_s;
+  units::second_t t = 0.0_s;
+
+  units::meter_t leftDistance = 0_m;
+  units::meter_t rightDistance = 0_m;
+
+  units::second_t kVisionUpdateRate = 0.1_s;
+  frc::Pose2d lastVisionPose;
+  units::second_t lastVisionUpdateTime{-std::numeric_limits<double>::max()};
+
+  double maxError = -std::numeric_limits<double>::max();
+  double errorSum = 0;
+  int sampleCount = 0;
+
+  while (t <= trajectory.TotalTime()) {
+    auto groundTruthState = trajectory.Sample(t);
+    auto input = kinematics.ToWheelSpeeds(
+        {groundTruthState.velocity, 0_mps,
+         groundTruthState.velocity * groundTruthState.curvature});
+
+    // Each vision pose is captured at one update and reported to the
+    // estimator at the next, stamped with its capture time.
+    if (lastVisionUpdateTime + kVisionUpdateRate < t) {
+      if (lastVisionPose != frc::Pose2d()) {
+        estimator.AddVisionMeasurement(lastVisionPose, lastVisionUpdateTime);
+      }
+      lastVisionPose =
+          groundTruthState.pose +
+          frc::Transform2d(
+              frc::Translation2d(distribution(generator) * 0.1 * 1_m,
+                                 distribution(generator) * 0.1 * 1_m),
+              frc::Rotation2d(distribution(generator) * 0.01 * 1_rad));
+
+      lastVisionUpdateTime = t;
+    }
+
+    // Simulated encoders: integrate the true wheel speeds perturbed by 1%
+    // Gaussian noise. Accumulating only the noise term would leave the
+    // distances near zero, so they would never reflect the robot's travel.
+    leftDistance += input.left * (1.0 + distribution(generator) * 0.01) * dt;
+    rightDistance += input.right * (1.0 + distribution(generator) * 0.01) * dt;
+
+    auto xhat = estimator.UpdateWithTime(
+        t,
+        groundTruthState.pose.Rotation() +
+            frc::Rotation2d(units::radian_t(distribution(generator) * 0.001)),
+        input, leftDistance, rightDistance);
+
+    double error = groundTruthState.pose.Translation()
+                       .Distance(xhat.Translation())
+                       .value();
+
+    if (error > maxError) {
+      maxError = error;
+    }
+    errorSum += error;
+    ++sampleCount;
+
+    t += dt;
+  }
+
+  // Average over the samples actually taken; the `<=` loop runs about
+  // TotalTime / dt + 1 times, not TotalTime / dt.
+  EXPECT_NEAR(0.0, errorSum / sampleCount, 0.2);
+  EXPECT_NEAR(0.0, maxError, 0.4);
+}
diff --git a/wpimath/src/test/native/cpp/estimator/ExtendedKalmanFilterTest.cpp b/wpimath/src/test/native/cpp/estimator/ExtendedKalmanFilterTest.cpp
index 387593b..6d51185 100644
--- a/wpimath/src/test/native/cpp/estimator/ExtendedKalmanFilterTest.cpp
+++ b/wpimath/src/test/native/cpp/estimator/ExtendedKalmanFilterTest.cpp
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include <gtest/gtest.h>
@@ -21,8 +18,8 @@
namespace {
-Eigen::Matrix<double, 5, 1> Dynamics(const Eigen::Matrix<double, 5, 1>& x,
- const Eigen::Matrix<double, 2, 1>& u) {
+Eigen::Vector<double, 5> Dynamics(const Eigen::Vector<double, 5>& x,
+ const Eigen::Vector<double, 2>& u) {
auto motors = frc::DCMotor::CIM(2);
// constexpr double Glow = 15.32; // Low gear ratio
@@ -43,36 +40,26 @@
units::volt_t Vl{u(0)};
units::volt_t Vr{u(1)};
- Eigen::Matrix<double, 5, 1> result;
auto v = 0.5 * (vl + vr);
- result(0) = v.to<double>() * std::cos(x(2));
- result(1) = v.to<double>() * std::sin(x(2));
- result(2) = ((vr - vl) / (2.0 * rb)).to<double>();
- result(3) =
- k1.to<double>() * ((C1 * vl).to<double>() + (C2 * Vl).to<double>()) +
- k2.to<double>() * ((C1 * vr).to<double>() + (C2 * Vr).to<double>());
- result(4) =
- k2.to<double>() * ((C1 * vl).to<double>() + (C2 * Vl).to<double>()) +
- k1.to<double>() * ((C1 * vr).to<double>() + (C2 * Vr).to<double>());
- return result;
+ return Eigen::Vector<double, 5>{
+ v.value() * std::cos(x(2)), v.value() * std::sin(x(2)),
+ ((vr - vl) / (2.0 * rb)).value(),
+ k1.value() * ((C1 * vl).value() + (C2 * Vl).value()) +
+ k2.value() * ((C1 * vr).value() + (C2 * Vr).value()),
+ k2.value() * ((C1 * vl).value() + (C2 * Vl).value()) +
+ k1.value() * ((C1 * vr).value() + (C2 * Vr).value())};
}
-Eigen::Matrix<double, 3, 1> LocalMeasurementModel(
- const Eigen::Matrix<double, 5, 1>& x,
- const Eigen::Matrix<double, 2, 1>& u) {
+Eigen::Vector<double, 3> LocalMeasurementModel(
+ const Eigen::Vector<double, 5>& x, const Eigen::Vector<double, 2>& u) {
static_cast<void>(u);
- Eigen::Matrix<double, 3, 1> y;
- y << x(2), x(3), x(4);
- return y;
+ return Eigen::Vector<double, 3>{x(2), x(3), x(4)};
}
-Eigen::Matrix<double, 5, 1> GlobalMeasurementModel(
- const Eigen::Matrix<double, 5, 1>& x,
- const Eigen::Matrix<double, 2, 1>& u) {
+Eigen::Vector<double, 5> GlobalMeasurementModel(
+ const Eigen::Vector<double, 5>& x, const Eigen::Vector<double, 2>& u) {
static_cast<void>(u);
- Eigen::Matrix<double, 5, 1> y;
- y << x(0), x(1), x(2), x(3), x(4);
- return y;
+ return Eigen::Vector<double, 5>{x(0), x(1), x(2), x(3), x(4)};
}
} // namespace
@@ -84,8 +71,7 @@
{0.5, 0.5, 10.0, 1.0, 1.0},
{0.0001, 0.01, 0.01},
dt};
- Eigen::Matrix<double, 2, 1> u;
- u << 12.0, 12.0;
+ Eigen::Vector<double, 2> u{12.0, 12.0};
observer.Predict(u, dt);
auto localY = LocalMeasurementModel(observer.Xhat(), u);
@@ -112,41 +98,37 @@
auto trajectory = frc::TrajectoryGenerator::GenerateTrajectory(
waypoints, {8.8_mps, 0.1_mps_sq});
- Eigen::Matrix<double, 5, 1> r = Eigen::Matrix<double, 5, 1>::Zero();
+ Eigen::Vector<double, 5> r = Eigen::Vector<double, 5>::Zero();
+ Eigen::Vector<double, 2> u = Eigen::Vector<double, 2>::Zero();
- Eigen::Matrix<double, 5, 1> nextR;
- Eigen::Matrix<double, 2, 1> u = Eigen::Matrix<double, 2, 1>::Zero();
+ auto B = frc::NumericalJacobianU<5, 5, 2>(Dynamics,
+ Eigen::Vector<double, 5>::Zero(),
+ Eigen::Vector<double, 2>::Zero());
- auto B = frc::NumericalJacobianU<5, 5, 2>(
- Dynamics, Eigen::Matrix<double, 5, 1>::Zero(),
- Eigen::Matrix<double, 2, 1>::Zero());
-
- observer.SetXhat(frc::MakeMatrix<5, 1>(
- trajectory.InitialPose().Translation().X().to<double>(),
- trajectory.InitialPose().Translation().Y().to<double>(),
- trajectory.InitialPose().Rotation().Radians().to<double>(), 0.0, 0.0));
+ observer.SetXhat(Eigen::Vector<double, 5>{
+ trajectory.InitialPose().Translation().X().value(),
+ trajectory.InitialPose().Translation().Y().value(),
+ trajectory.InitialPose().Rotation().Radians().value(), 0.0, 0.0});
auto totalTime = trajectory.TotalTime();
- for (size_t i = 0; i < (totalTime / dt).to<double>(); ++i) {
+ for (size_t i = 0; i < (totalTime / dt).value(); ++i) {
auto ref = trajectory.Sample(dt * i);
units::meters_per_second_t vl =
- ref.velocity * (1 - (ref.curvature * rb).to<double>());
+ ref.velocity * (1 - (ref.curvature * rb).value());
units::meters_per_second_t vr =
- ref.velocity * (1 + (ref.curvature * rb).to<double>());
+ ref.velocity * (1 + (ref.curvature * rb).value());
- nextR(0) = ref.pose.Translation().X().to<double>();
- nextR(1) = ref.pose.Translation().Y().to<double>();
- nextR(2) = ref.pose.Rotation().Radians().to<double>();
- nextR(3) = vl.to<double>();
- nextR(4) = vr.to<double>();
+ Eigen::Vector<double, 5> nextR{
+ ref.pose.Translation().X().value(), ref.pose.Translation().Y().value(),
+ ref.pose.Rotation().Radians().value(), vl.value(), vr.value()};
auto localY =
- LocalMeasurementModel(nextR, Eigen::Matrix<double, 2, 1>::Zero());
+ LocalMeasurementModel(nextR, Eigen::Vector<double, 2>::Zero());
observer.Correct(u, localY + frc::MakeWhiteNoiseVector(0.0001, 0.5, 0.5));
- Eigen::Matrix<double, 5, 1> rdot = (nextR - r) / dt.to<double>();
- u = B.householderQr().solve(
- rdot - Dynamics(r, Eigen::Matrix<double, 2, 1>::Zero()));
+ Eigen::Vector<double, 5> rdot = (nextR - r) / dt.value();
+ u = B.householderQr().solve(rdot -
+ Dynamics(r, Eigen::Vector<double, 2>::Zero()));
observer.Predict(u, dt);
@@ -161,12 +143,12 @@
observer.Correct<5>(u, globalY, GlobalMeasurementModel, R);
auto finalPosition = trajectory.Sample(trajectory.TotalTime());
- ASSERT_NEAR(finalPosition.pose.Translation().X().template to<double>(),
- observer.Xhat(0), 1.0);
- ASSERT_NEAR(finalPosition.pose.Translation().Y().template to<double>(),
- observer.Xhat(1), 1.0);
- ASSERT_NEAR(finalPosition.pose.Rotation().Radians().template to<double>(),
- observer.Xhat(2), 1.0);
+ ASSERT_NEAR(finalPosition.pose.Translation().X().value(), observer.Xhat(0),
+ 1.0);
+ ASSERT_NEAR(finalPosition.pose.Translation().Y().value(), observer.Xhat(1),
+ 1.0);
+ ASSERT_NEAR(finalPosition.pose.Rotation().Radians().value(), observer.Xhat(2),
+ 1.0);
ASSERT_NEAR(0.0, observer.Xhat(3), 1.0);
ASSERT_NEAR(0.0, observer.Xhat(4), 1.0);
}
diff --git a/wpimath/src/test/native/cpp/estimator/KalmanFilterTest.cpp b/wpimath/src/test/native/cpp/estimator/KalmanFilterTest.cpp
index 53c54ef..fc373ca 100644
--- a/wpimath/src/test/native/cpp/estimator/KalmanFilterTest.cpp
+++ b/wpimath/src/test/native/cpp/estimator/KalmanFilterTest.cpp
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include <gtest/gtest.h>
diff --git a/wpimath/src/test/native/cpp/estimator/MecanumDrivePoseEstimatorTest.cpp b/wpimath/src/test/native/cpp/estimator/MecanumDrivePoseEstimatorTest.cpp
new file mode 100644
index 0000000..881d4e8
--- /dev/null
+++ b/wpimath/src/test/native/cpp/estimator/MecanumDrivePoseEstimatorTest.cpp
@@ -0,0 +1,94 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+#include <limits>
+#include <random>
+#include <vector>
+
+#include "frc/estimator/MecanumDrivePoseEstimator.h"
+#include "frc/geometry/Pose2d.h"
+#include "frc/kinematics/MecanumDriveKinematics.h"
+#include "frc/kinematics/MecanumDriveOdometry.h"
+#include "frc/trajectory/TrajectoryGenerator.h"
+#include "gtest/gtest.h"
+
+// Simulates a mecanum drive following a trajectory, feeding the estimator
+// ideal wheel speeds, noisy gyro headings, and noisy vision poses delivered
+// one update period late, then checks the mean and maximum translation error
+// against the ground-truth trajectory.
+TEST(MecanumDrivePoseEstimatorTest, Accuracy) {
+  frc::MecanumDriveKinematics kinematics{
+      frc::Translation2d{1_m, 1_m}, frc::Translation2d{1_m, -1_m},
+      frc::Translation2d{-1_m, -1_m}, frc::Translation2d{-1_m, 1_m}};
+
+  frc::MecanumDrivePoseEstimator estimator{
+      frc::Rotation2d(), frc::Pose2d(), kinematics,
+      {0.1, 0.1, 0.1},   {0.05},        {0.1, 0.1, 0.1}};
+
+  frc::Trajectory trajectory = frc::TrajectoryGenerator::GenerateTrajectory(
+      std::vector{frc::Pose2d(0_m, 0_m, frc::Rotation2d(45_deg)),
+                  frc::Pose2d(3_m, 0_m, frc::Rotation2d(-90_deg)),
+                  frc::Pose2d(0_m, 0_m, frc::Rotation2d(135_deg)),
+                  frc::Pose2d(-3_m, 0_m, frc::Rotation2d(-90_deg)),
+                  frc::Pose2d(0_m, 0_m, frc::Rotation2d(45_deg))},
+      frc::TrajectoryConfig(5.0_mps, 2.0_mps_sq));
+
+  std::default_random_engine generator;
+  std::normal_distribution<double> distribution(0.0, 1.0);
+
+  units::second_t dt = 0.02_s;
+  units::second_t t = 0_s;
+
+  units::second_t kVisionUpdateRate = 0.1_s;
+  frc::Pose2d lastVisionPose;
+  units::second_t lastVisionUpdateTime{-std::numeric_limits<double>::max()};
+
+  double maxError = -std::numeric_limits<double>::max();
+  double errorSum = 0;
+  int sampleCount = 0;
+
+  while (t < trajectory.TotalTime()) {
+    frc::Trajectory::State groundTruthState = trajectory.Sample(t);
+
+    // Each vision pose is captured at one update and reported to the
+    // estimator at the next, stamped with its capture time.
+    if (lastVisionUpdateTime + kVisionUpdateRate < t) {
+      if (lastVisionPose != frc::Pose2d()) {
+        estimator.AddVisionMeasurement(lastVisionPose, lastVisionUpdateTime);
+      }
+      lastVisionPose =
+          groundTruthState.pose +
+          frc::Transform2d(
+              frc::Translation2d(distribution(generator) * 0.1_m,
+                                 distribution(generator) * 0.1_m),
+              frc::Rotation2d(distribution(generator) * 0.1 * 1_rad));
+      lastVisionUpdateTime = t;
+    }
+
+    auto wheelSpeeds = kinematics.ToWheelSpeeds(
+        {groundTruthState.velocity, 0_mps,
+         groundTruthState.velocity * groundTruthState.curvature});
+
+    auto xhat = estimator.UpdateWithTime(
+        t,
+        groundTruthState.pose.Rotation() +
+            frc::Rotation2d(distribution(generator) * 0.05_rad),
+        wheelSpeeds);
+    double error = groundTruthState.pose.Translation()
+                       .Distance(xhat.Translation())
+                       .value();
+
+    if (error > maxError) {
+      maxError = error;
+    }
+    errorSum += error;
+    ++sampleCount;
+
+    t += dt;
+  }
+
+  // Average over the samples actually taken rather than TotalTime / dt.
+  EXPECT_LT(errorSum / sampleCount, 0.2);
+  EXPECT_LT(maxError, 0.4);
+}
diff --git a/wpimath/src/test/native/cpp/estimator/MerweScaledSigmaPointsTest.cpp b/wpimath/src/test/native/cpp/estimator/MerweScaledSigmaPointsTest.cpp
index ace5e79..c012435 100644
--- a/wpimath/src/test/native/cpp/estimator/MerweScaledSigmaPointsTest.cpp
+++ b/wpimath/src/test/native/cpp/estimator/MerweScaledSigmaPointsTest.cpp
@@ -1,41 +1,37 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include <gtest/gtest.h>
-#include "frc/StateSpaceUtil.h"
#include "frc/estimator/MerweScaledSigmaPoints.h"
-namespace drake {
-namespace math {
+namespace drake::math {
namespace {
-TEST(MerweScaledSigmaPointsTest, TestZeroMean) {
+TEST(MerweScaledSigmaPointsTest, ZeroMean) {
frc::MerweScaledSigmaPoints<2> sigmaPoints;
- auto points =
- sigmaPoints.SigmaPoints(frc::MakeMatrix<2, 1>(0.0, 0.0),
- frc::MakeMatrix<2, 2>(1.0, 0.0, 0.0, 1.0));
+ auto points = sigmaPoints.SigmaPoints(
+ Eigen::Vector<double, 2>{0.0, 0.0},
+ Eigen::Matrix<double, 2, 2>{{1.0, 0.0}, {0.0, 1.0}});
EXPECT_TRUE(
- (points - frc::MakeMatrix<2, 5>(0.0, 0.00173205, 0.0, -0.00173205, 0.0,
- 0.0, 0.0, 0.00173205, 0.0, -0.00173205))
+ (points -
+ Eigen::Matrix<double, 2, 5>{{0.0, 0.00173205, 0.0, -0.00173205, 0.0},
+ {0.0, 0.0, 0.00173205, 0.0, -0.00173205}})
.norm() < 1e-3);
}
-TEST(MerweScaledSigmaPointsTest, TestNonzeroMean) {
+TEST(MerweScaledSigmaPointsTest, NonzeroMean) {
frc::MerweScaledSigmaPoints<2> sigmaPoints;
- auto points =
- sigmaPoints.SigmaPoints(frc::MakeMatrix<2, 1>(1.0, 2.0),
- frc::MakeMatrix<2, 2>(1.0, 0.0, 0.0, 10.0));
+ auto points = sigmaPoints.SigmaPoints(
+ Eigen::Vector<double, 2>{1.0, 2.0},
+ Eigen::Matrix<double, 2, 2>{{1.0, 0.0}, {0.0, 10.0}});
EXPECT_TRUE(
- (points - frc::MakeMatrix<2, 5>(1.0, 1.00173205, 1.0, 0.998268, 1.0, 2.0,
- 2.0, 2.00548, 2.0, 1.99452))
+ (points -
+ Eigen::Matrix<double, 2, 5>{{1.0, 1.00173205, 1.0, 0.998268, 1.0},
+ {2.0, 2.0, 2.00548, 2.0, 1.99452}})
.norm() < 1e-3);
}
} // namespace
-} // namespace math
-} // namespace drake
+} // namespace drake::math
diff --git a/wpimath/src/test/native/cpp/estimator/SwerveDrivePoseEstimatorTest.cpp b/wpimath/src/test/native/cpp/estimator/SwerveDrivePoseEstimatorTest.cpp
new file mode 100644
index 0000000..ee01f6f
--- /dev/null
+++ b/wpimath/src/test/native/cpp/estimator/SwerveDrivePoseEstimatorTest.cpp
@@ -0,0 +1,94 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+#include <limits>
+#include <random>
+#include <vector>
+
+#include "frc/estimator/SwerveDrivePoseEstimator.h"
+#include "frc/geometry/Pose2d.h"
+#include "frc/kinematics/SwerveDriveKinematics.h"
+#include "frc/kinematics/SwerveDriveOdometry.h"
+#include "frc/trajectory/TrajectoryGenerator.h"
+#include "gtest/gtest.h"
+
+// Simulates a four-module swerve drive following a trajectory, feeding the
+// estimator ideal module states, noisy gyro headings, and noisy vision poses
+// delivered one update period late, then checks the mean and maximum
+// translation error against the ground-truth trajectory.
+TEST(SwerveDrivePoseEstimatorTest, Accuracy) {
+  frc::SwerveDriveKinematics<4> kinematics{
+      frc::Translation2d{1_m, 1_m}, frc::Translation2d{1_m, -1_m},
+      frc::Translation2d{-1_m, -1_m}, frc::Translation2d{-1_m, 1_m}};
+
+  frc::SwerveDrivePoseEstimator<4> estimator{
+      frc::Rotation2d(), frc::Pose2d(), kinematics,
+      {0.1, 0.1, 0.1},   {0.05},        {0.1, 0.1, 0.1}};
+
+  frc::Trajectory trajectory = frc::TrajectoryGenerator::GenerateTrajectory(
+      std::vector{frc::Pose2d(0_m, 0_m, frc::Rotation2d(45_deg)),
+                  frc::Pose2d(3_m, 0_m, frc::Rotation2d(-90_deg)),
+                  frc::Pose2d(0_m, 0_m, frc::Rotation2d(135_deg)),
+                  frc::Pose2d(-3_m, 0_m, frc::Rotation2d(-90_deg)),
+                  frc::Pose2d(0_m, 0_m, frc::Rotation2d(45_deg))},
+      frc::TrajectoryConfig(5.0_mps, 2.0_mps_sq));
+
+  std::default_random_engine generator;
+  std::normal_distribution<double> distribution(0.0, 1.0);
+
+  units::second_t dt = 0.02_s;
+  units::second_t t = 0_s;
+
+  units::second_t kVisionUpdateRate = 0.1_s;
+  frc::Pose2d lastVisionPose;
+  units::second_t lastVisionUpdateTime{-std::numeric_limits<double>::max()};
+
+  double maxError = -std::numeric_limits<double>::max();
+  double errorSum = 0;
+  int sampleCount = 0;
+
+  while (t < trajectory.TotalTime()) {
+    frc::Trajectory::State groundTruthState = trajectory.Sample(t);
+
+    // Each vision pose is captured at one update and reported to the
+    // estimator at the next, stamped with its capture time.
+    if (lastVisionUpdateTime + kVisionUpdateRate < t) {
+      if (lastVisionPose != frc::Pose2d()) {
+        estimator.AddVisionMeasurement(lastVisionPose, lastVisionUpdateTime);
+      }
+      lastVisionPose =
+          groundTruthState.pose +
+          frc::Transform2d(
+              frc::Translation2d(distribution(generator) * 0.1_m,
+                                 distribution(generator) * 0.1_m),
+              frc::Rotation2d(distribution(generator) * 0.1 * 1_rad));
+      lastVisionUpdateTime = t;
+    }
+
+    auto moduleStates = kinematics.ToSwerveModuleStates(
+        {groundTruthState.velocity, 0_mps,
+         groundTruthState.velocity * groundTruthState.curvature});
+
+    auto xhat = estimator.UpdateWithTime(
+        t,
+        groundTruthState.pose.Rotation() +
+            frc::Rotation2d(distribution(generator) * 0.05_rad),
+        moduleStates[0], moduleStates[1], moduleStates[2], moduleStates[3]);
+    double error = groundTruthState.pose.Translation()
+                       .Distance(xhat.Translation())
+                       .value();
+
+    if (error > maxError) {
+      maxError = error;
+    }
+    errorSum += error;
+    ++sampleCount;
+
+    t += dt;
+  }
+
+  // Average over the samples actually taken rather than TotalTime / dt.
+  EXPECT_LT(errorSum / sampleCount, 0.2);
+  EXPECT_LT(maxError, 0.4);
+}
diff --git a/wpimath/src/test/native/cpp/estimator/UnscentedKalmanFilterTest.cpp b/wpimath/src/test/native/cpp/estimator/UnscentedKalmanFilterTest.cpp
index 0017b42..9665442 100644
--- a/wpimath/src/test/native/cpp/estimator/UnscentedKalmanFilterTest.cpp
+++ b/wpimath/src/test/native/cpp/estimator/UnscentedKalmanFilterTest.cpp
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include <gtest/gtest.h>
@@ -13,17 +10,18 @@
#include "Eigen/Core"
#include "Eigen/QR"
#include "frc/StateSpaceUtil.h"
+#include "frc/estimator/AngleStatistics.h"
#include "frc/estimator/UnscentedKalmanFilter.h"
+#include "frc/system/NumericalIntegration.h"
#include "frc/system/NumericalJacobian.h"
-#include "frc/system/RungeKutta.h"
#include "frc/system/plant/DCMotor.h"
#include "frc/trajectory/TrajectoryGenerator.h"
#include "units/moment_of_inertia.h"
namespace {
-Eigen::Matrix<double, 5, 1> Dynamics(const Eigen::Matrix<double, 5, 1>& x,
- const Eigen::Matrix<double, 2, 1>& u) {
+Eigen::Vector<double, 5> Dynamics(const Eigen::Vector<double, 5>& x,
+ const Eigen::Vector<double, 2>& u) {
auto motors = frc::DCMotor::CIM(2);
// constexpr double Glow = 15.32; // Low gear ratio
@@ -44,49 +42,38 @@
units::volt_t Vl{u(0)};
units::volt_t Vr{u(1)};
- Eigen::Matrix<double, 5, 1> result;
auto v = 0.5 * (vl + vr);
- result(0) = v.to<double>() * std::cos(x(2));
- result(1) = v.to<double>() * std::sin(x(2));
- result(2) = ((vr - vl) / (2.0 * rb)).to<double>();
- result(3) =
- k1.to<double>() * ((C1 * vl).to<double>() + (C2 * Vl).to<double>()) +
- k2.to<double>() * ((C1 * vr).to<double>() + (C2 * Vr).to<double>());
- result(4) =
- k2.to<double>() * ((C1 * vl).to<double>() + (C2 * Vl).to<double>()) +
- k1.to<double>() * ((C1 * vr).to<double>() + (C2 * Vr).to<double>());
- return result;
+ return Eigen::Vector<double, 5>{
+ v.value() * std::cos(x(2)), v.value() * std::sin(x(2)),
+ ((vr - vl) / (2.0 * rb)).value(),
+ k1.value() * ((C1 * vl).value() + (C2 * Vl).value()) +
+ k2.value() * ((C1 * vr).value() + (C2 * Vr).value()),
+ k2.value() * ((C1 * vl).value() + (C2 * Vl).value()) +
+ k1.value() * ((C1 * vr).value() + (C2 * Vr).value())};
}
-Eigen::Matrix<double, 3, 1> LocalMeasurementModel(
- const Eigen::Matrix<double, 5, 1>& x,
- const Eigen::Matrix<double, 2, 1>& u) {
+Eigen::Vector<double, 3> LocalMeasurementModel(
+ const Eigen::Vector<double, 5>& x, const Eigen::Vector<double, 2>& u) {
static_cast<void>(u);
- Eigen::Matrix<double, 3, 1> y;
- y << x(2), x(3), x(4);
- return y;
+ return Eigen::Vector<double, 3>{x(2), x(3), x(4)};
}
-Eigen::Matrix<double, 5, 1> GlobalMeasurementModel(
- const Eigen::Matrix<double, 5, 1>& x,
- const Eigen::Matrix<double, 2, 1>& u) {
+Eigen::Vector<double, 5> GlobalMeasurementModel(
+ const Eigen::Vector<double, 5>& x, const Eigen::Vector<double, 2>& u) {
static_cast<void>(u);
- Eigen::Matrix<double, 5, 1> y;
- y << x(0), x(1), x(2), x(3), x(4);
- return y;
+ return Eigen::Vector<double, 5>{x(0), x(1), x(2), x(3), x(4)};
}
} // namespace
TEST(UnscentedKalmanFilterTest, Init) {
- constexpr auto dt = 0.00505_s;
+ constexpr auto dt = 5_ms;
frc::UnscentedKalmanFilter<5, 2, 3> observer{Dynamics,
LocalMeasurementModel,
{0.5, 0.5, 10.0, 1.0, 1.0},
{0.0001, 0.01, 0.01},
dt};
- Eigen::Matrix<double, 2, 1> u;
- u << 12.0, 12.0;
+ Eigen::Vector<double, 2> u{12.0, 12.0};
observer.Predict(u, dt);
auto localY = LocalMeasurementModel(observer.Xhat(), u);
@@ -94,11 +81,13 @@
auto globalY = GlobalMeasurementModel(observer.Xhat(), u);
auto R = frc::MakeCovMatrix(0.01, 0.01, 0.0001, 0.01, 0.01);
- observer.Correct<5>(u, globalY, GlobalMeasurementModel, R);
+ observer.Correct<5>(u, globalY, GlobalMeasurementModel, R,
+ frc::AngleMean<5, 5>(2), frc::AngleResidual<5>(2),
+ frc::AngleResidual<5>(2), frc::AngleAdd<5>(2));
}
TEST(UnscentedKalmanFilterTest, Convergence) {
- constexpr auto dt = 0.00505_s;
+ constexpr auto dt = 5_ms;
constexpr auto rb = 0.8382_m / 2.0; // Robot radius
frc::UnscentedKalmanFilter<5, 2, 3> observer{Dynamics,
@@ -113,48 +102,44 @@
auto trajectory = frc::TrajectoryGenerator::GenerateTrajectory(
waypoints, {8.8_mps, 0.1_mps_sq});
- Eigen::Matrix<double, 5, 1> r = Eigen::Matrix<double, 5, 1>::Zero();
+ Eigen::Vector<double, 5> r = Eigen::Vector<double, 5>::Zero();
+ Eigen::Vector<double, 2> u = Eigen::Vector<double, 2>::Zero();
- Eigen::Matrix<double, 5, 1> nextR;
- Eigen::Matrix<double, 2, 1> u = Eigen::Matrix<double, 2, 1>::Zero();
+ auto B = frc::NumericalJacobianU<5, 5, 2>(Dynamics,
+ Eigen::Vector<double, 5>::Zero(),
+ Eigen::Vector<double, 2>::Zero());
- auto B = frc::NumericalJacobianU<5, 5, 2>(
- Dynamics, Eigen::Matrix<double, 5, 1>::Zero(),
- Eigen::Matrix<double, 2, 1>::Zero());
-
- observer.SetXhat(frc::MakeMatrix<5, 1>(
- trajectory.InitialPose().Translation().X().to<double>(),
- trajectory.InitialPose().Translation().Y().to<double>(),
- trajectory.InitialPose().Rotation().Radians().to<double>(), 0.0, 0.0));
+ observer.SetXhat(Eigen::Vector<double, 5>{
+ trajectory.InitialPose().Translation().X().value(),
+ trajectory.InitialPose().Translation().Y().value(),
+ trajectory.InitialPose().Rotation().Radians().value(), 0.0, 0.0});
auto trueXhat = observer.Xhat();
auto totalTime = trajectory.TotalTime();
- for (size_t i = 0; i < (totalTime / dt).to<double>(); ++i) {
+ for (size_t i = 0; i < (totalTime / dt).value(); ++i) {
auto ref = trajectory.Sample(dt * i);
units::meters_per_second_t vl =
- ref.velocity * (1 - (ref.curvature * rb).to<double>());
+ ref.velocity * (1 - (ref.curvature * rb).value());
units::meters_per_second_t vr =
- ref.velocity * (1 + (ref.curvature * rb).to<double>());
+ ref.velocity * (1 + (ref.curvature * rb).value());
- nextR(0) = ref.pose.Translation().X().to<double>();
- nextR(1) = ref.pose.Translation().Y().to<double>();
- nextR(2) = ref.pose.Rotation().Radians().to<double>();
- nextR(3) = vl.to<double>();
- nextR(4) = vr.to<double>();
+ Eigen::Vector<double, 5> nextR{
+ ref.pose.Translation().X().value(), ref.pose.Translation().Y().value(),
+ ref.pose.Rotation().Radians().value(), vl.value(), vr.value()};
auto localY =
- LocalMeasurementModel(trueXhat, Eigen::Matrix<double, 2, 1>::Zero());
+ LocalMeasurementModel(trueXhat, Eigen::Vector<double, 2>::Zero());
observer.Correct(u, localY + frc::MakeWhiteNoiseVector(0.0001, 0.5, 0.5));
- Eigen::Matrix<double, 5, 1> rdot = (nextR - r) / dt.to<double>();
- u = B.householderQr().solve(
- rdot - Dynamics(r, Eigen::Matrix<double, 2, 1>::Zero()));
+ Eigen::Vector<double, 5> rdot = (nextR - r) / dt.value();
+ u = B.householderQr().solve(rdot -
+ Dynamics(r, Eigen::Vector<double, 2>::Zero()));
observer.Predict(u, dt);
r = nextR;
- trueXhat = frc::RungeKutta(Dynamics, trueXhat, u, dt);
+ trueXhat = frc::RK4(Dynamics, trueXhat, u, dt);
}
auto localY = LocalMeasurementModel(trueXhat, u);
@@ -162,15 +147,19 @@
auto globalY = GlobalMeasurementModel(trueXhat, u);
auto R = frc::MakeCovMatrix(0.01, 0.01, 0.0001, 0.5, 0.5);
- observer.Correct<5>(u, globalY, GlobalMeasurementModel, R);
+ observer.Correct<5>(u, globalY, GlobalMeasurementModel, R,
+ frc::AngleMean<5, 5>(2), frc::AngleResidual<5>(2),
+ frc::AngleResidual<5>(2), frc::AngleAdd<5>(2)
+
+ );
auto finalPosition = trajectory.Sample(trajectory.TotalTime());
- ASSERT_NEAR(finalPosition.pose.Translation().X().template to<double>(),
- observer.Xhat(0), 1.0);
- ASSERT_NEAR(finalPosition.pose.Translation().Y().template to<double>(),
- observer.Xhat(1), 1.0);
- ASSERT_NEAR(finalPosition.pose.Rotation().Radians().template to<double>(),
- observer.Xhat(2), 1.0);
+ ASSERT_NEAR(finalPosition.pose.Translation().X().value(), observer.Xhat(0),
+ 1.0);
+ ASSERT_NEAR(finalPosition.pose.Translation().Y().value(), observer.Xhat(1),
+ 1.0);
+ ASSERT_NEAR(finalPosition.pose.Rotation().Radians().value(), observer.Xhat(2),
+ 1.0);
ASSERT_NEAR(0.0, observer.Xhat(3), 1.0);
ASSERT_NEAR(0.0, observer.Xhat(4), 1.0);
}
diff --git a/wpimath/src/test/native/cpp/filter/LinearFilterNoiseTest.cpp b/wpimath/src/test/native/cpp/filter/LinearFilterNoiseTest.cpp
new file mode 100644
index 0000000..5ccd829
--- /dev/null
+++ b/wpimath/src/test/native/cpp/filter/LinearFilterNoiseTest.cpp
@@ -0,0 +1,82 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+#include "frc/filter/LinearFilter.h"  // NOLINT(build/include_order)
+
+#include <cmath>
+#include <cstdint>
+#include <random>
+
+#include <wpi/numbers>
+
+#include "gtest/gtest.h"
+#include "units/time.h"
+
+// Filter constants
+static constexpr auto kFilterStep = 5_ms;
+static constexpr auto kFilterTime = 2_s;
+static constexpr double kSinglePoleIIRTimeConstant = 0.015915;
+static constexpr int32_t kMovAvgTaps = 6;
+
+enum LinearFilterNoiseTestType { kTestSinglePoleIIR, kTestMovAvg };
+
+/**
+ * Noise-free reference signal: a 1 Hz sine wave of amplitude 100.
+ *
+ * @param t Time in seconds.
+ */
+static double GetData(double t) {
+  return 100.0 * std::sin(2.0 * wpi::numbers::pi * t);
+}
+
+/**
+ * A fixture that builds the filter under test from the test parameter.
+ */
+class LinearFilterNoiseTest
+    : public testing::TestWithParam<LinearFilterNoiseTestType> {
+ protected:
+  frc::LinearFilter<double> m_filter = [=] {
+    switch (GetParam()) {
+      case kTestSinglePoleIIR:
+        return frc::LinearFilter<double>::SinglePoleIIR(
+            kSinglePoleIIRTimeConstant, kFilterStep);
+      default:
+        return frc::LinearFilter<double>::MovingAverage(kMovAvgTaps);
+    }
+  }();
+};
+
+/**
+ * Test if the filter reduces the noise produced by a signal generator.
+ *
+ * The filter is fed the reference signal plus zero-mean Gaussian noise; its
+ * accumulated absolute error against the reference must be smaller than that
+ * of the unfiltered noisy signal.
+ */
+TEST_P(LinearFilterNoiseTest, NoiseReduce) {
+  double noiseGenError = 0.0;
+  double filterError = 0.0;
+
+  std::random_device rd;
+  std::mt19937 gen{rd()};
+  std::normal_distribution<double> distr{0.0, 10.0};
+
+  for (auto t = 0_s; t < kFilterTime; t += kFilterStep) {
+    double theory = GetData(t.value());
+    double noise = distr(gen);
+    filterError += std::abs(m_filter.Calculate(theory + noise) - theory);
+    // The unfiltered measurement is theory + noise, so its error against the
+    // reference is exactly the noise magnitude.
+    noiseGenError += std::abs(noise);
+  }
+
+  RecordProperty("FilterError", filterError);
+
+  // The filter should have produced values closer to the theory
+  EXPECT_GT(noiseGenError, filterError)
+      << "Filter should have reduced noise accumulation but failed";
+}
+
+INSTANTIATE_TEST_SUITE_P(Tests, LinearFilterNoiseTest,
+                         testing::Values(kTestSinglePoleIIR, kTestMovAvg));
diff --git a/wpimath/src/test/native/cpp/filter/LinearFilterOutputTest.cpp b/wpimath/src/test/native/cpp/filter/LinearFilterOutputTest.cpp
new file mode 100644
index 0000000..bca3f9d
--- /dev/null
+++ b/wpimath/src/test/native/cpp/filter/LinearFilterOutputTest.cpp
@@ -0,0 +1,236 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+#include "frc/filter/LinearFilter.h"  // NOLINT(build/include_order)
+
+#include <cmath>
+#include <cstdint>
+#include <functional>
+
+#include <wpi/numbers>
+
+#include "gtest/gtest.h"
+#include "units/time.h"
+
+// Filter constants
+static constexpr auto kFilterStep = 5_ms;
+static constexpr auto kFilterTime = 2_s;
+static constexpr double kSinglePoleIIRTimeConstant = 0.015915;
+static constexpr double kSinglePoleIIRExpectedOutput = -3.2172003;
+static constexpr double kHighPassTimeConstant = 0.006631;
+static constexpr double kHighPassExpectedOutput = 10.074717;
+static constexpr int32_t kMovAvgTaps = 6;
+static constexpr double kMovAvgExpectedOutput = -10.191644;
+
+enum LinearFilterOutputTestType {
+  kTestSinglePoleIIR,
+  kTestHighPass,
+  kTestMovAvg,
+  kTestPulse
+};
+
+/**
+ * Test signal: a 1 Hz sine wave of amplitude 100 plus a 25 Hz cosine wave of
+ * amplitude 20.
+ *
+ * @param t Time in seconds.
+ */
+static double GetData(double t) {
+  return 100.0 * std::sin(2.0 * wpi::numbers::pi * t) +
+         20.0 * std::cos(50.0 * wpi::numbers::pi * t);
+}
+
+/**
+ * Pulse signal: 1.0 when t is within 1 ms of 1 s, 0.0 otherwise.
+ *
+ * @param t Time in seconds.
+ */
+static double GetPulseData(double t) {
+  if (std::abs(t - 1.0) < 0.001) {
+    return 1.0;
+  } else {
+    return 0.0;
+  }
+}
+
+/**
+ * A fixture that includes a consistent data source wrapped in a filter
+ */
+class LinearFilterOutputTest
+    : public testing::TestWithParam<LinearFilterOutputTestType> {
+ protected:
+  frc::LinearFilter<double> m_filter = [=] {
+    switch (GetParam()) {
+      case kTestSinglePoleIIR:
+        return frc::LinearFilter<double>::SinglePoleIIR(
+            kSinglePoleIIRTimeConstant, kFilterStep);
+      case kTestHighPass:
+        return frc::LinearFilter<double>::HighPass(kHighPassTimeConstant,
+                                                   kFilterStep);
+      case kTestMovAvg:
+      default:
+        // kTestPulse is also run through the moving average filter.
+        return frc::LinearFilter<double>::MovingAverage(kMovAvgTaps);
+    }
+  }();
+  std::function<double(double)> m_data;
+  double m_expectedOutput = 0.0;
+
+  LinearFilterOutputTest() {
+    switch (GetParam()) {
+      case kTestSinglePoleIIR: {
+        m_data = GetData;
+        m_expectedOutput = kSinglePoleIIRExpectedOutput;
+        break;
+      }
+
+      case kTestHighPass: {
+        m_data = GetData;
+        m_expectedOutput = kHighPassExpectedOutput;
+        break;
+      }
+
+      case kTestMovAvg: {
+        m_data = GetData;
+        m_expectedOutput = kMovAvgExpectedOutput;
+        break;
+      }
+
+      case kTestPulse: {
+        m_data = GetPulseData;
+        m_expectedOutput = 0.0;
+        break;
+      }
+    }
+  }
+};
+
+/**
+ * Test if the linear filters produce consistent output for a given data set.
+ */
+TEST_P(LinearFilterOutputTest, Output) {
+  double filterOutput = 0.0;
+  for (auto t = 0_s; t < kFilterTime; t += kFilterStep) {
+    filterOutput = m_filter.Calculate(m_data(t.value()));
+  }
+
+  RecordProperty("LinearFilterOutput", filterOutput);
+
+  EXPECT_FLOAT_EQ(m_expectedOutput, filterOutput)
+      << "Filter output didn't match expected value";
+}
+
+INSTANTIATE_TEST_SUITE_P(Tests, LinearFilterOutputTest,
+                         testing::Values(kTestSinglePoleIIR, kTestHighPass,
+                                         kTestMovAvg, kTestPulse));
+
+/**
+ * Checks a backward finite difference filter against an analytic derivative
+ * at every sample in [min, max) once the filter has been primed.
+ *
+ * @tparam Derivative Order of the derivative being estimated.
+ * @tparam Samples Number of stencil points used by the filter.
+ * @param f Function whose derivative is estimated.
+ * @param dfdx Analytic derivative of f of order Derivative.
+ * @param h Sample period.
+ * @param min Lower bound of the sampled x range.
+ * @param max Upper bound of the sampled x range (exclusive).
+ */
+template <int Derivative, int Samples, typename F, typename DfDx>
+void AssertResults(F&& f, DfDx&& dfdx, units::second_t h, double min,
+                   double max) {
+  auto filter =
+      frc::LinearFilter<double>::BackwardFiniteDifference<Derivative, Samples>(
+          h);
+
+  // Sample index range; the start index is truncated toward zero.
+  const int first = static_cast<int>(min / h.value());
+  const double last = max / h.value();
+
+  for (int i = first; i < last; ++i) {
+    // Let filter initialize
+    if (i < first + Samples) {
+      filter.Calculate(f(i * h.value()));
+      continue;
+    }
+
+    // The order of accuracy is O(h^(N - d)) where N is number of stencil
+    // points and d is order of derivative
+    EXPECT_NEAR(dfdx(i * h.value()), filter.Calculate(f(i * h.value())),
+                10.0 * std::pow(h.value(), Samples - Derivative));
+  }
+}
+
+/**
+ * Test backward finite difference.
+ */
+TEST(LinearFilterOutputTest, BackwardFiniteDifference) {
+  constexpr auto h = 5_ms;
+
+  AssertResults<1, 2>(
+      [](double x) {
+        // f(x) = x²
+        return x * x;
+      },
+      [](double x) {
+        // df/dx = 2x
+        return 2.0 * x;
+      },
+      h, -20.0, 20.0);
+
+  AssertResults<1, 2>(
+      [](double x) {
+        // f(x) = std::sin(x)
+        return std::sin(x);
+      },
+      [](double x) {
+        // df/dx = std::cos(x)
+        return std::cos(x);
+      },
+      h, -20.0, 20.0);
+
+  AssertResults<1, 2>(
+      [](double x) {
+        // f(x) = ln(x)
+        return std::log(x);
+      },
+      [](double x) {
+        // df/dx = 1 / x
+        return 1.0 / x;
+      },
+      h, 1.0, 20.0);
+
+  AssertResults<2, 4>(
+      [](double x) {
+        // f(x) = x^2
+        return x * x;
+      },
+      [](double x) {
+        // d²f/dx² = 2
+        return 2.0;
+      },
+      h, -20.0, 20.0);
+
+  AssertResults<2, 4>(
+      [](double x) {
+        // f(x) = std::sin(x)
+        return std::sin(x);
+      },
+      [](double x) {
+        // d²f/dx² = -std::sin(x)
+        return -std::sin(x);
+      },
+      h, -20.0, 20.0);
+
+  AssertResults<2, 4>(
+      [](double x) {
+        // f(x) = ln(x)
+        return std::log(x);
+      },
+      [](double x) {
+        // d²f/dx² = -1 / x²
+        return -1.0 / (x * x);
+      },
+      h, 1.0, 20.0);
+}
diff --git a/wpimath/src/test/native/cpp/MedianFilterTest.cpp b/wpimath/src/test/native/cpp/filter/MedianFilterTest.cpp
similarity index 65%
rename from wpimath/src/test/native/cpp/MedianFilterTest.cpp
rename to wpimath/src/test/native/cpp/filter/MedianFilterTest.cpp
index 2a02e1c..8151a45 100644
--- a/wpimath/src/test/native/cpp/MedianFilterTest.cpp
+++ b/wpimath/src/test/native/cpp/filter/MedianFilterTest.cpp
@@ -1,11 +1,8 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
-#include "frc/MedianFilter.h"
+#include "frc/filter/MedianFilter.h"
#include "gtest/gtest.h"
TEST(MedianFilterTest, MedianFilterNotFullTestEven) {
diff --git a/wpimath/src/test/native/cpp/filter/SlewRateLimiterTest.cpp b/wpimath/src/test/native/cpp/filter/SlewRateLimiterTest.cpp
new file mode 100644
index 0000000..d2c0bae
--- /dev/null
+++ b/wpimath/src/test/native/cpp/filter/SlewRateLimiterTest.cpp
@@ -0,0 +1,45 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+#include <wpi/timestamp.h>
+
+#include "frc/filter/SlewRateLimiter.h"
+#include "gtest/gtest.h"
+#include "units/length.h"
+#include "units/time.h"
+#include "units/velocity.h"
+
+// Simulated clock for these tests; each test advances it manually.
+static units::second_t now = 0_s;
+
+/**
+ * Installs the simulated clock before each test and removes it afterwards so
+ * the fake time source does not leak into other tests in the same binary.
+ */
+class SlewRateLimiterTest : public ::testing::Test {
+ protected:
+  void SetUp() override {
+    now = 0_s;
+    WPI_SetNowImpl([] { return units::microsecond_t{now}.to<uint64_t>(); });
+  }
+
+  // Passing null restores wpiutil's default time source.
+  void TearDown() override { WPI_SetNowImpl(nullptr); }
+};
+
+TEST_F(SlewRateLimiterTest, SlewRateLimit) {
+  frc::SlewRateLimiter<units::meters> limiter(1_mps);
+
+  now += 1_s;
+
+  EXPECT_LT(limiter.Calculate(2_m), 2_m);
+}
+
+TEST_F(SlewRateLimiterTest, SlewRateNoLimit) {
+  frc::SlewRateLimiter<units::meters> limiter(1_mps);
+
+  now += 1_s;
+
+  EXPECT_EQ(limiter.Calculate(0.5_m), 0.5_m);
+}
diff --git a/wpimath/src/test/native/cpp/geometry/Pose2dTest.cpp b/wpimath/src/test/native/cpp/geometry/Pose2dTest.cpp
index 620c9d3..cd5b127 100644
--- a/wpimath/src/test/native/cpp/geometry/Pose2dTest.cpp
+++ b/wpimath/src/test/native/cpp/geometry/Pose2dTest.cpp
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include <cmath>
@@ -20,9 +17,9 @@
const auto transformed = initial + transform;
- EXPECT_NEAR(transformed.X().to<double>(), 1 + 5 / std::sqrt(2.0), kEpsilon);
- EXPECT_NEAR(transformed.Y().to<double>(), 2 + 5 / std::sqrt(2.0), kEpsilon);
- EXPECT_NEAR(transformed.Rotation().Degrees().to<double>(), 50.0, kEpsilon);
+ EXPECT_NEAR(transformed.X().value(), 1 + 5 / std::sqrt(2.0), kEpsilon);
+ EXPECT_NEAR(transformed.Y().value(), 2 + 5 / std::sqrt(2.0), kEpsilon);
+ EXPECT_NEAR(transformed.Rotation().Degrees().value(), 50.0, kEpsilon);
}
TEST(Pose2dTest, RelativeTo) {
@@ -31,10 +28,10 @@
const auto finalRelativeToInitial = final.RelativeTo(initial);
- EXPECT_NEAR(finalRelativeToInitial.X().to<double>(), 5.0 * std::sqrt(2.0),
+ EXPECT_NEAR(finalRelativeToInitial.X().value(), 5.0 * std::sqrt(2.0),
kEpsilon);
- EXPECT_NEAR(finalRelativeToInitial.Y().to<double>(), 0.0, kEpsilon);
- EXPECT_NEAR(finalRelativeToInitial.Rotation().Degrees().to<double>(), 0.0,
+ EXPECT_NEAR(finalRelativeToInitial.Y().value(), 0.0, kEpsilon);
+ EXPECT_NEAR(finalRelativeToInitial.Rotation().Degrees().value(), 0.0,
kEpsilon);
}
@@ -56,7 +53,7 @@
const auto transform = final - initial;
- EXPECT_NEAR(transform.X().to<double>(), 5.0 * std::sqrt(2.0), kEpsilon);
- EXPECT_NEAR(transform.Y().to<double>(), 0.0, kEpsilon);
- EXPECT_NEAR(transform.Rotation().Degrees().to<double>(), 0.0, kEpsilon);
+ EXPECT_NEAR(transform.X().value(), 5.0 * std::sqrt(2.0), kEpsilon);
+ EXPECT_NEAR(transform.Y().value(), 0.0, kEpsilon);
+ EXPECT_NEAR(transform.Rotation().Degrees().value(), 0.0, kEpsilon);
}
diff --git a/wpimath/src/test/native/cpp/geometry/Rotation2dTest.cpp b/wpimath/src/test/native/cpp/geometry/Rotation2dTest.cpp
index a29371f..ed3b6b5 100644
--- a/wpimath/src/test/native/cpp/geometry/Rotation2dTest.cpp
+++ b/wpimath/src/test/native/cpp/geometry/Rotation2dTest.cpp
@@ -1,13 +1,10 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include <cmath>
-#include <wpi/math>
+#include <wpi/numbers>
#include "frc/geometry/Rotation2d.h"
#include "gtest/gtest.h"
@@ -17,51 +14,55 @@
static constexpr double kEpsilon = 1E-9;
TEST(Rotation2dTest, RadiansToDegrees) {
- const Rotation2d one{units::radian_t(wpi::math::pi / 3)};
- const Rotation2d two{units::radian_t(wpi::math::pi / 4)};
+ const Rotation2d rot1{units::radian_t(wpi::numbers::pi / 3)};
+ const Rotation2d rot2{units::radian_t(wpi::numbers::pi / 4)};
- EXPECT_NEAR(one.Degrees().to<double>(), 60.0, kEpsilon);
- EXPECT_NEAR(two.Degrees().to<double>(), 45.0, kEpsilon);
+ EXPECT_NEAR(rot1.Degrees().value(), 60.0, kEpsilon);
+ EXPECT_NEAR(rot2.Degrees().value(), 45.0, kEpsilon);
}
TEST(Rotation2dTest, DegreesToRadians) {
- const auto one = Rotation2d(45.0_deg);
- const auto two = Rotation2d(30.0_deg);
+ const auto rot1 = Rotation2d(45.0_deg);
+ const auto rot2 = Rotation2d(30.0_deg);
- EXPECT_NEAR(one.Radians().to<double>(), wpi::math::pi / 4.0, kEpsilon);
- EXPECT_NEAR(two.Radians().to<double>(), wpi::math::pi / 6.0, kEpsilon);
+ EXPECT_NEAR(rot1.Radians().value(), wpi::numbers::pi / 4.0, kEpsilon);
+ EXPECT_NEAR(rot2.Radians().value(), wpi::numbers::pi / 6.0, kEpsilon);
}
TEST(Rotation2dTest, RotateByFromZero) {
const Rotation2d zero;
auto sum = zero + Rotation2d(90.0_deg);
- EXPECT_NEAR(sum.Radians().to<double>(), wpi::math::pi / 2.0, kEpsilon);
- EXPECT_NEAR(sum.Degrees().to<double>(), 90.0, kEpsilon);
+ EXPECT_NEAR(sum.Radians().value(), wpi::numbers::pi / 2.0, kEpsilon);
+ EXPECT_NEAR(sum.Degrees().value(), 90.0, kEpsilon);
}
TEST(Rotation2dTest, RotateByNonZero) {
auto rot = Rotation2d(90.0_deg);
- rot += Rotation2d(30.0_deg);
+ rot = rot + Rotation2d(30.0_deg);
- EXPECT_NEAR(rot.Degrees().to<double>(), 120.0, kEpsilon);
+ EXPECT_NEAR(rot.Degrees().value(), 120.0, kEpsilon);
}
TEST(Rotation2dTest, Minus) {
- const auto one = Rotation2d(70.0_deg);
- const auto two = Rotation2d(30.0_deg);
+ const auto rot1 = Rotation2d(70.0_deg);
+ const auto rot2 = Rotation2d(30.0_deg);
- EXPECT_NEAR((one - two).Degrees().to<double>(), 40.0, kEpsilon);
+ EXPECT_NEAR((rot1 - rot2).Degrees().value(), 40.0, kEpsilon);
}
TEST(Rotation2dTest, Equality) {
- const auto one = Rotation2d(43_deg);
- const auto two = Rotation2d(43_deg);
- EXPECT_TRUE(one == two);
+ const auto rot1 = Rotation2d(43_deg);
+ const auto rot2 = Rotation2d(43_deg);
+ EXPECT_EQ(rot1, rot2);
+
+ const auto rot3 = Rotation2d(-180_deg);
+ const auto rot4 = Rotation2d(180_deg);
+ EXPECT_EQ(rot3, rot4);
}
TEST(Rotation2dTest, Inequality) {
- const auto one = Rotation2d(43_deg);
- const auto two = Rotation2d(43.5_deg);
- EXPECT_TRUE(one != two);
+ const auto rot1 = Rotation2d(43_deg);
+ const auto rot2 = Rotation2d(43.5_deg);
+ EXPECT_NE(rot1, rot2);
}
diff --git a/wpimath/src/test/native/cpp/geometry/Transform2dTest.cpp b/wpimath/src/test/native/cpp/geometry/Transform2dTest.cpp
index b302fad..968ab29 100644
--- a/wpimath/src/test/native/cpp/geometry/Transform2dTest.cpp
+++ b/wpimath/src/test/native/cpp/geometry/Transform2dTest.cpp
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include <cmath>
@@ -18,16 +15,30 @@
static constexpr double kEpsilon = 1E-9;
TEST(Transform2dTest, Inverse) {
- const Pose2d initial{1_m, 2_m, Rotation2d(45.0_deg)};
- const Transform2d transform{Translation2d{5.0_m, 0.0_m}, Rotation2d(5.0_deg)};
+ const Pose2d initial{1_m, 2_m, 45_deg};
+ const Transform2d transform{{5_m, 0_m}, 5_deg};
auto transformed = initial + transform;
auto untransformed = transformed + transform.Inverse();
- EXPECT_NEAR(initial.X().to<double>(), untransformed.X().to<double>(),
+ EXPECT_NEAR(initial.X().value(), untransformed.X().value(), kEpsilon);
+ EXPECT_NEAR(initial.Y().value(), untransformed.Y().value(), kEpsilon);
+ EXPECT_NEAR(initial.Rotation().Degrees().value(),
+ untransformed.Rotation().Degrees().value(), kEpsilon);
+}
+
+TEST(Transform2dTest, Composition) {
+ const Pose2d initial{1_m, 2_m, 45_deg};
+ const Transform2d transform1{{5_m, 0_m}, 5_deg};
+ const Transform2d transform2{{0_m, 2_m}, 5_deg};
+
+ auto transformedSeparate = initial + transform1 + transform2;
+ auto transformedCombined = initial + (transform1 + transform2);
+
+ EXPECT_NEAR(transformedSeparate.X().value(), transformedCombined.X().value(),
kEpsilon);
- EXPECT_NEAR(initial.Y().to<double>(), untransformed.Y().to<double>(),
+ EXPECT_NEAR(transformedSeparate.Y().value(), transformedCombined.Y().value(),
kEpsilon);
- EXPECT_NEAR(initial.Rotation().Degrees().to<double>(),
- untransformed.Rotation().Degrees().to<double>(), kEpsilon);
+ EXPECT_NEAR(transformedSeparate.Rotation().Degrees().value(),
+ transformedCombined.Rotation().Degrees().value(), kEpsilon);
}
diff --git a/wpimath/src/test/native/cpp/geometry/Translation2dTest.cpp b/wpimath/src/test/native/cpp/geometry/Translation2dTest.cpp
index 8e487f3..efdcace 100644
--- a/wpimath/src/test/native/cpp/geometry/Translation2dTest.cpp
+++ b/wpimath/src/test/native/cpp/geometry/Translation2dTest.cpp
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include <cmath>
@@ -20,8 +17,8 @@
const auto sum = one + two;
- EXPECT_NEAR(sum.X().to<double>(), 3.0, kEpsilon);
- EXPECT_NEAR(sum.Y().to<double>(), 8.0, kEpsilon);
+ EXPECT_NEAR(sum.X().value(), 3.0, kEpsilon);
+ EXPECT_NEAR(sum.Y().value(), 8.0, kEpsilon);
}
TEST(Translation2dTest, Difference) {
@@ -30,51 +27,51 @@
const auto difference = one - two;
- EXPECT_NEAR(difference.X().to<double>(), -1.0, kEpsilon);
- EXPECT_NEAR(difference.Y().to<double>(), -2.0, kEpsilon);
+ EXPECT_NEAR(difference.X().value(), -1.0, kEpsilon);
+ EXPECT_NEAR(difference.Y().value(), -2.0, kEpsilon);
}
TEST(Translation2dTest, RotateBy) {
const Translation2d another{3.0_m, 0.0_m};
const auto rotated = another.RotateBy(Rotation2d(90.0_deg));
- EXPECT_NEAR(rotated.X().to<double>(), 0.0, kEpsilon);
- EXPECT_NEAR(rotated.Y().to<double>(), 3.0, kEpsilon);
+ EXPECT_NEAR(rotated.X().value(), 0.0, kEpsilon);
+ EXPECT_NEAR(rotated.Y().value(), 3.0, kEpsilon);
}
TEST(Translation2dTest, Multiplication) {
const Translation2d original{3.0_m, 5.0_m};
const auto mult = original * 3;
- EXPECT_NEAR(mult.X().to<double>(), 9.0, kEpsilon);
- EXPECT_NEAR(mult.Y().to<double>(), 15.0, kEpsilon);
+ EXPECT_NEAR(mult.X().value(), 9.0, kEpsilon);
+ EXPECT_NEAR(mult.Y().value(), 15.0, kEpsilon);
}
-TEST(Translation2d, Division) {
+TEST(Translation2dTest, Division) {
const Translation2d original{3.0_m, 5.0_m};
const auto div = original / 2;
- EXPECT_NEAR(div.X().to<double>(), 1.5, kEpsilon);
- EXPECT_NEAR(div.Y().to<double>(), 2.5, kEpsilon);
+ EXPECT_NEAR(div.X().value(), 1.5, kEpsilon);
+ EXPECT_NEAR(div.Y().value(), 2.5, kEpsilon);
}
TEST(Translation2dTest, Norm) {
const Translation2d one{3.0_m, 5.0_m};
- EXPECT_NEAR(one.Norm().to<double>(), std::hypot(3, 5), kEpsilon);
+ EXPECT_NEAR(one.Norm().value(), std::hypot(3, 5), kEpsilon);
}
TEST(Translation2dTest, Distance) {
const Translation2d one{1_m, 1_m};
const Translation2d two{6_m, 6_m};
- EXPECT_NEAR(one.Distance(two).to<double>(), 5 * std::sqrt(2), kEpsilon);
+ EXPECT_NEAR(one.Distance(two).value(), 5 * std::sqrt(2), kEpsilon);
}
TEST(Translation2dTest, UnaryMinus) {
const Translation2d original{-4.5_m, 7_m};
const auto inverted = -original;
- EXPECT_NEAR(inverted.X().to<double>(), 4.5, kEpsilon);
- EXPECT_NEAR(inverted.Y().to<double>(), -7, kEpsilon);
+ EXPECT_NEAR(inverted.X().value(), 4.5, kEpsilon);
+ EXPECT_NEAR(inverted.Y().value(), -7, kEpsilon);
}
TEST(Translation2dTest, Equality) {
@@ -91,10 +88,10 @@
TEST(Translation2dTest, PolarConstructor) {
Translation2d one{std::sqrt(2) * 1_m, Rotation2d(45_deg)};
- EXPECT_NEAR(one.X().to<double>(), 1.0, kEpsilon);
- EXPECT_NEAR(one.Y().to<double>(), 1.0, kEpsilon);
+ EXPECT_NEAR(one.X().value(), 1.0, kEpsilon);
+ EXPECT_NEAR(one.Y().value(), 1.0, kEpsilon);
Translation2d two{2_m, Rotation2d(60_deg)};
- EXPECT_NEAR(two.X().to<double>(), 1.0, kEpsilon);
- EXPECT_NEAR(two.Y().to<double>(), std::sqrt(3.0), kEpsilon);
+ EXPECT_NEAR(two.X().value(), 1.0, kEpsilon);
+ EXPECT_NEAR(two.Y().value(), std::sqrt(3.0), kEpsilon);
}
diff --git a/wpimath/src/test/native/cpp/geometry/Twist2dTest.cpp b/wpimath/src/test/native/cpp/geometry/Twist2dTest.cpp
index 4766bd4..fa9eecc 100644
--- a/wpimath/src/test/native/cpp/geometry/Twist2dTest.cpp
+++ b/wpimath/src/test/native/cpp/geometry/Twist2dTest.cpp
@@ -1,13 +1,10 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include <cmath>
-#include <wpi/math>
+#include <wpi/numbers>
#include "frc/geometry/Pose2d.h"
#include "gtest/gtest.h"
@@ -20,29 +17,28 @@
const Twist2d straight{5.0_m, 0.0_m, 0.0_rad};
const auto straightPose = Pose2d().Exp(straight);
- EXPECT_NEAR(straightPose.X().to<double>(), 5.0, kEpsilon);
- EXPECT_NEAR(straightPose.Y().to<double>(), 0.0, kEpsilon);
- EXPECT_NEAR(straightPose.Rotation().Radians().to<double>(), 0.0, kEpsilon);
+ EXPECT_NEAR(straightPose.X().value(), 5.0, kEpsilon);
+ EXPECT_NEAR(straightPose.Y().value(), 0.0, kEpsilon);
+ EXPECT_NEAR(straightPose.Rotation().Radians().value(), 0.0, kEpsilon);
}
TEST(Twist2dTest, QuarterCircle) {
- const Twist2d quarterCircle{5.0_m / 2.0 * wpi::math::pi, 0_m,
- units::radian_t(wpi::math::pi / 2.0)};
+ const Twist2d quarterCircle{5.0_m / 2.0 * wpi::numbers::pi, 0_m,
+ units::radian_t(wpi::numbers::pi / 2.0)};
const auto quarterCirclePose = Pose2d().Exp(quarterCircle);
- EXPECT_NEAR(quarterCirclePose.X().to<double>(), 5.0, kEpsilon);
- EXPECT_NEAR(quarterCirclePose.Y().to<double>(), 5.0, kEpsilon);
- EXPECT_NEAR(quarterCirclePose.Rotation().Degrees().to<double>(), 90.0,
- kEpsilon);
+ EXPECT_NEAR(quarterCirclePose.X().value(), 5.0, kEpsilon);
+ EXPECT_NEAR(quarterCirclePose.Y().value(), 5.0, kEpsilon);
+ EXPECT_NEAR(quarterCirclePose.Rotation().Degrees().value(), 90.0, kEpsilon);
}
TEST(Twist2dTest, DiagonalNoDtheta) {
const Twist2d diagonal{2.0_m, 2.0_m, 0.0_deg};
const auto diagonalPose = Pose2d().Exp(diagonal);
- EXPECT_NEAR(diagonalPose.X().to<double>(), 2.0, kEpsilon);
- EXPECT_NEAR(diagonalPose.Y().to<double>(), 2.0, kEpsilon);
- EXPECT_NEAR(diagonalPose.Rotation().Degrees().to<double>(), 0.0, kEpsilon);
+ EXPECT_NEAR(diagonalPose.X().value(), 2.0, kEpsilon);
+ EXPECT_NEAR(diagonalPose.Y().value(), 2.0, kEpsilon);
+ EXPECT_NEAR(diagonalPose.Rotation().Degrees().value(), 0.0, kEpsilon);
}
TEST(Twist2dTest, Equality) {
@@ -63,7 +59,7 @@
const auto twist = start.Log(end);
- EXPECT_NEAR(twist.dx.to<double>(), 5 / 2.0 * wpi::math::pi, kEpsilon);
- EXPECT_NEAR(twist.dy.to<double>(), 0.0, kEpsilon);
- EXPECT_NEAR(twist.dtheta.to<double>(), wpi::math::pi / 2.0, kEpsilon);
+ EXPECT_NEAR(twist.dx.value(), 5 / 2.0 * wpi::numbers::pi, kEpsilon);
+ EXPECT_NEAR(twist.dy.value(), 0.0, kEpsilon);
+ EXPECT_NEAR(twist.dtheta.value(), wpi::numbers::pi / 2.0, kEpsilon);
}
diff --git a/wpimath/src/test/native/cpp/kinematics/ChassisSpeedsTest.cpp b/wpimath/src/test/native/cpp/kinematics/ChassisSpeedsTest.cpp
index 864860a..7665a97 100644
--- a/wpimath/src/test/native/cpp/kinematics/ChassisSpeedsTest.cpp
+++ b/wpimath/src/test/native/cpp/kinematics/ChassisSpeedsTest.cpp
@@ -1,20 +1,17 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include "frc/kinematics/ChassisSpeeds.h"
#include "gtest/gtest.h"
static constexpr double kEpsilon = 1E-9;
-TEST(ChassisSpeeds, FieldRelativeConstruction) {
+TEST(ChassisSpeedsTest, FieldRelativeConstruction) {
const auto chassisSpeeds = frc::ChassisSpeeds::FromFieldRelativeSpeeds(
1.0_mps, 0.0_mps, 0.5_rad_per_s, frc::Rotation2d(-90.0_deg));
- EXPECT_NEAR(0.0, chassisSpeeds.vx.to<double>(), kEpsilon);
- EXPECT_NEAR(1.0, chassisSpeeds.vy.to<double>(), kEpsilon);
- EXPECT_NEAR(0.5, chassisSpeeds.omega.to<double>(), kEpsilon);
+ EXPECT_NEAR(0.0, chassisSpeeds.vx.value(), kEpsilon);
+ EXPECT_NEAR(1.0, chassisSpeeds.vy.value(), kEpsilon);
+ EXPECT_NEAR(0.5, chassisSpeeds.omega.value(), kEpsilon);
}
diff --git a/wpimath/src/test/native/cpp/kinematics/DifferentialDriveKinematicsTest.cpp b/wpimath/src/test/native/cpp/kinematics/DifferentialDriveKinematicsTest.cpp
index 7c1a28d..224e231 100644
--- a/wpimath/src/test/native/cpp/kinematics/DifferentialDriveKinematicsTest.cpp
+++ b/wpimath/src/test/native/cpp/kinematics/DifferentialDriveKinematicsTest.cpp
@@ -1,11 +1,8 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
-#include <wpi/math>
+#include <wpi/numbers>
#include "frc/kinematics/ChassisSpeeds.h"
#include "frc/kinematics/DifferentialDriveKinematics.h"
@@ -18,62 +15,62 @@
static constexpr double kEpsilon = 1E-9;
-TEST(DifferentialDriveKinematics, InverseKinematicsFromZero) {
+TEST(DifferentialDriveKinematicsTest, InverseKinematicsFromZero) {
const DifferentialDriveKinematics kinematics{0.381_m * 2};
const ChassisSpeeds chassisSpeeds;
const auto wheelSpeeds = kinematics.ToWheelSpeeds(chassisSpeeds);
- EXPECT_NEAR(wheelSpeeds.left.to<double>(), 0, kEpsilon);
- EXPECT_NEAR(wheelSpeeds.right.to<double>(), 0, kEpsilon);
+ EXPECT_NEAR(wheelSpeeds.left.value(), 0, kEpsilon);
+ EXPECT_NEAR(wheelSpeeds.right.value(), 0, kEpsilon);
}
-TEST(DifferentialDriveKinematics, ForwardKinematicsFromZero) {
+TEST(DifferentialDriveKinematicsTest, ForwardKinematicsFromZero) {
const DifferentialDriveKinematics kinematics{0.381_m * 2};
const DifferentialDriveWheelSpeeds wheelSpeeds;
const auto chassisSpeeds = kinematics.ToChassisSpeeds(wheelSpeeds);
- EXPECT_NEAR(chassisSpeeds.vx.to<double>(), 0, kEpsilon);
- EXPECT_NEAR(chassisSpeeds.vy.to<double>(), 0, kEpsilon);
- EXPECT_NEAR(chassisSpeeds.omega.to<double>(), 0, kEpsilon);
+ EXPECT_NEAR(chassisSpeeds.vx.value(), 0, kEpsilon);
+ EXPECT_NEAR(chassisSpeeds.vy.value(), 0, kEpsilon);
+ EXPECT_NEAR(chassisSpeeds.omega.value(), 0, kEpsilon);
}
-TEST(DifferentialDriveKinematics, InverseKinematicsForStraightLine) {
+TEST(DifferentialDriveKinematicsTest, InverseKinematicsForStraightLine) {
const DifferentialDriveKinematics kinematics{0.381_m * 2};
const ChassisSpeeds chassisSpeeds{3.0_mps, 0_mps, 0_rad_per_s};
const auto wheelSpeeds = kinematics.ToWheelSpeeds(chassisSpeeds);
- EXPECT_NEAR(wheelSpeeds.left.to<double>(), 3, kEpsilon);
- EXPECT_NEAR(wheelSpeeds.right.to<double>(), 3, kEpsilon);
+ EXPECT_NEAR(wheelSpeeds.left.value(), 3, kEpsilon);
+ EXPECT_NEAR(wheelSpeeds.right.value(), 3, kEpsilon);
}
-TEST(DifferentialDriveKinematics, ForwardKinematicsForStraightLine) {
+TEST(DifferentialDriveKinematicsTest, ForwardKinematicsForStraightLine) {
const DifferentialDriveKinematics kinematics{0.381_m * 2};
const DifferentialDriveWheelSpeeds wheelSpeeds{3.0_mps, 3.0_mps};
const auto chassisSpeeds = kinematics.ToChassisSpeeds(wheelSpeeds);
- EXPECT_NEAR(chassisSpeeds.vx.to<double>(), 3, kEpsilon);
- EXPECT_NEAR(chassisSpeeds.vy.to<double>(), 0, kEpsilon);
- EXPECT_NEAR(chassisSpeeds.omega.to<double>(), 0, kEpsilon);
+ EXPECT_NEAR(chassisSpeeds.vx.value(), 3, kEpsilon);
+ EXPECT_NEAR(chassisSpeeds.vy.value(), 0, kEpsilon);
+ EXPECT_NEAR(chassisSpeeds.omega.value(), 0, kEpsilon);
}
-TEST(DifferentialDriveKinematics, InverseKinematicsForRotateInPlace) {
+TEST(DifferentialDriveKinematicsTest, InverseKinematicsForRotateInPlace) {
const DifferentialDriveKinematics kinematics{0.381_m * 2};
- const ChassisSpeeds chassisSpeeds{0.0_mps, 0.0_mps,
- units::radians_per_second_t{wpi::math::pi}};
+ const ChassisSpeeds chassisSpeeds{
+ 0.0_mps, 0.0_mps, units::radians_per_second_t{wpi::numbers::pi}};
const auto wheelSpeeds = kinematics.ToWheelSpeeds(chassisSpeeds);
- EXPECT_NEAR(wheelSpeeds.left.to<double>(), -0.381 * wpi::math::pi, kEpsilon);
- EXPECT_NEAR(wheelSpeeds.right.to<double>(), +0.381 * wpi::math::pi, kEpsilon);
+ EXPECT_NEAR(wheelSpeeds.left.value(), -0.381 * wpi::numbers::pi, kEpsilon);
+ EXPECT_NEAR(wheelSpeeds.right.value(), +0.381 * wpi::numbers::pi, kEpsilon);
}
-TEST(DifferentialDriveKinematics, ForwardKinematicsForRotateInPlace) {
+TEST(DifferentialDriveKinematicsTest, ForwardKinematicsForRotateInPlace) {
const DifferentialDriveKinematics kinematics{0.381_m * 2};
const DifferentialDriveWheelSpeeds wheelSpeeds{
- units::meters_per_second_t(+0.381 * wpi::math::pi),
- units::meters_per_second_t(-0.381 * wpi::math::pi)};
+ units::meters_per_second_t(+0.381 * wpi::numbers::pi),
+ units::meters_per_second_t(-0.381 * wpi::numbers::pi)};
const auto chassisSpeeds = kinematics.ToChassisSpeeds(wheelSpeeds);
- EXPECT_NEAR(chassisSpeeds.vx.to<double>(), 0, kEpsilon);
- EXPECT_NEAR(chassisSpeeds.vy.to<double>(), 0, kEpsilon);
- EXPECT_NEAR(chassisSpeeds.omega.to<double>(), -wpi::math::pi, kEpsilon);
+ EXPECT_NEAR(chassisSpeeds.vx.value(), 0, kEpsilon);
+ EXPECT_NEAR(chassisSpeeds.vy.value(), 0, kEpsilon);
+ EXPECT_NEAR(chassisSpeeds.omega.value(), -wpi::numbers::pi, kEpsilon);
}
diff --git a/wpimath/src/test/native/cpp/kinematics/DifferentialDriveOdometryTest.cpp b/wpimath/src/test/native/cpp/kinematics/DifferentialDriveOdometryTest.cpp
index 89d65ea..da16b28 100644
--- a/wpimath/src/test/native/cpp/kinematics/DifferentialDriveOdometryTest.cpp
+++ b/wpimath/src/test/native/cpp/kinematics/DifferentialDriveOdometryTest.cpp
@@ -1,11 +1,8 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
-#include <wpi/math>
+#include <wpi/numbers>
#include "frc/kinematics/DifferentialDriveKinematics.h"
#include "frc/kinematics/DifferentialDriveOdometry.h"
@@ -15,13 +12,13 @@
using namespace frc;
-TEST(DifferentialDriveOdometry, EncoderDistances) {
+TEST(DifferentialDriveOdometryTest, EncoderDistances) {
DifferentialDriveOdometry odometry{Rotation2d(45_deg)};
const auto& pose = odometry.Update(Rotation2d(135_deg), 0_m,
- units::meter_t(5 * wpi::math::pi));
+ units::meter_t(5 * wpi::numbers::pi));
- EXPECT_NEAR(pose.X().to<double>(), 5.0, kEpsilon);
- EXPECT_NEAR(pose.Y().to<double>(), 5.0, kEpsilon);
- EXPECT_NEAR(pose.Rotation().Degrees().to<double>(), 90.0, kEpsilon);
+ EXPECT_NEAR(pose.X().value(), 5.0, kEpsilon);
+ EXPECT_NEAR(pose.Y().value(), 5.0, kEpsilon);
+ EXPECT_NEAR(pose.Rotation().Degrees().value(), 90.0, kEpsilon);
}
diff --git a/wpimath/src/test/native/cpp/kinematics/MecanumDriveKinematicsTest.cpp b/wpimath/src/test/native/cpp/kinematics/MecanumDriveKinematicsTest.cpp
index cb03d97..18b72de 100644
--- a/wpimath/src/test/native/cpp/kinematics/MecanumDriveKinematicsTest.cpp
+++ b/wpimath/src/test/native/cpp/kinematics/MecanumDriveKinematicsTest.cpp
@@ -1,11 +1,8 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
-#include <wpi/math>
+#include <wpi/numbers>
#include "frc/geometry/Translation2d.h"
#include "frc/kinematics/MecanumDriveKinematics.h"
@@ -28,113 +25,69 @@
ChassisSpeeds speeds{5_mps, 0_mps, 0_rad_per_s};
auto moduleStates = kinematics.ToWheelSpeeds(speeds);
- /*
- By equation (13.12) of the state-space-guide, the wheel speeds should
- be as follows:
- velocities: fl 3.535534 fr 3.535534 rl 3.535534 rr 3.535534
- */
-
- EXPECT_NEAR(3.536, moduleStates.frontLeft.to<double>(), 0.1);
- EXPECT_NEAR(3.536, moduleStates.frontRight.to<double>(), 0.1);
- EXPECT_NEAR(3.536, moduleStates.rearLeft.to<double>(), 0.1);
- EXPECT_NEAR(3.536, moduleStates.rearRight.to<double>(), 0.1);
+ EXPECT_NEAR(5.0, moduleStates.frontLeft.value(), 0.1);
+ EXPECT_NEAR(5.0, moduleStates.frontRight.value(), 0.1);
+ EXPECT_NEAR(5.0, moduleStates.rearLeft.value(), 0.1);
+ EXPECT_NEAR(5.0, moduleStates.rearRight.value(), 0.1);
}
TEST_F(MecanumDriveKinematicsTest, StraightLineForwardKinematics) {
- MecanumDriveWheelSpeeds wheelSpeeds{3.536_mps, 3.536_mps, 3.536_mps,
- 3.536_mps};
+ MecanumDriveWheelSpeeds wheelSpeeds{5_mps, 5_mps, 5_mps, 5_mps};
auto chassisSpeeds = kinematics.ToChassisSpeeds(wheelSpeeds);
- /*
- By equation (13.13) of the state-space-guide, the chassis motion from wheel
- velocities: fl 3.535534 fr 3.535534 rl 3.535534 rr 3.535534 will be
- [[5][0][0]]
- */
-
- EXPECT_NEAR(5.0, chassisSpeeds.vx.to<double>(), 0.1);
- EXPECT_NEAR(0.0, chassisSpeeds.vy.to<double>(), 0.1);
- EXPECT_NEAR(0.0, chassisSpeeds.omega.to<double>(), 0.1);
+ EXPECT_NEAR(5.0, chassisSpeeds.vx.value(), 0.1);
+ EXPECT_NEAR(0.0, chassisSpeeds.vy.value(), 0.1);
+ EXPECT_NEAR(0.0, chassisSpeeds.omega.value(), 0.1);
}
TEST_F(MecanumDriveKinematicsTest, StrafeInverseKinematics) {
ChassisSpeeds speeds{0_mps, 4_mps, 0_rad_per_s};
auto moduleStates = kinematics.ToWheelSpeeds(speeds);
- /*
- By equation (13.12) of the state-space-guide, the wheel speeds should
- be as follows:
- velocities: fl -2.828427 fr 2.828427 rl 2.828427 rr -2.828427
- */
-
- EXPECT_NEAR(-2.828427, moduleStates.frontLeft.to<double>(), 0.1);
- EXPECT_NEAR(2.828427, moduleStates.frontRight.to<double>(), 0.1);
- EXPECT_NEAR(2.828427, moduleStates.rearLeft.to<double>(), 0.1);
- EXPECT_NEAR(-2.828427, moduleStates.rearRight.to<double>(), 0.1);
+ EXPECT_NEAR(-4.0, moduleStates.frontLeft.value(), 0.1);
+ EXPECT_NEAR(4.0, moduleStates.frontRight.value(), 0.1);
+ EXPECT_NEAR(4.0, moduleStates.rearLeft.value(), 0.1);
+ EXPECT_NEAR(-4.0, moduleStates.rearRight.value(), 0.1);
}
TEST_F(MecanumDriveKinematicsTest, StrafeForwardKinematics) {
- MecanumDriveWheelSpeeds wheelSpeeds{-2.828427_mps, 2.828427_mps, 2.828427_mps,
- -2.828427_mps};
+ MecanumDriveWheelSpeeds wheelSpeeds{-5_mps, 5_mps, 5_mps, -5_mps};
auto chassisSpeeds = kinematics.ToChassisSpeeds(wheelSpeeds);
- /*
- By equation (13.13) of the state-space-guide, the chassis motion from wheel
- velocities: fl 3.535534 fr 3.535534 rl 3.535534 rr 3.535534 will be
- [[5][0][0]]
- */
-
- EXPECT_NEAR(0.0, chassisSpeeds.vx.to<double>(), 0.1);
- EXPECT_NEAR(4.0, chassisSpeeds.vy.to<double>(), 0.1);
- EXPECT_NEAR(0.0, chassisSpeeds.omega.to<double>(), 0.1);
+ EXPECT_NEAR(0.0, chassisSpeeds.vx.value(), 0.1);
+ EXPECT_NEAR(5.0, chassisSpeeds.vy.value(), 0.1);
+ EXPECT_NEAR(0.0, chassisSpeeds.omega.value(), 0.1);
}
TEST_F(MecanumDriveKinematicsTest, RotationInverseKinematics) {
ChassisSpeeds speeds{0_mps, 0_mps,
- units::radians_per_second_t(2 * wpi::math::pi)};
+ units::radians_per_second_t(2 * wpi::numbers::pi)};
auto moduleStates = kinematics.ToWheelSpeeds(speeds);
- /*
- By equation (13.12) of the state-space-guide, the wheel speeds should
- be as follows:
- velocities: fl -106.629191 fr 106.629191 rl -106.629191 rr 106.629191
- */
-
- EXPECT_NEAR(-106.62919, moduleStates.frontLeft.to<double>(), 0.1);
- EXPECT_NEAR(106.62919, moduleStates.frontRight.to<double>(), 0.1);
- EXPECT_NEAR(-106.62919, moduleStates.rearLeft.to<double>(), 0.1);
- EXPECT_NEAR(106.62919, moduleStates.rearRight.to<double>(), 0.1);
+ EXPECT_NEAR(-150.79644737, moduleStates.frontLeft.value(), 0.1);
+ EXPECT_NEAR(150.79644737, moduleStates.frontRight.value(), 0.1);
+ EXPECT_NEAR(-150.79644737, moduleStates.rearLeft.value(), 0.1);
+ EXPECT_NEAR(150.79644737, moduleStates.rearRight.value(), 0.1);
}
TEST_F(MecanumDriveKinematicsTest, RotationForwardKinematics) {
- MecanumDriveWheelSpeeds wheelSpeeds{-106.62919_mps, 106.62919_mps,
- -106.62919_mps, 106.62919_mps};
+ MecanumDriveWheelSpeeds wheelSpeeds{-150.79644737_mps, 150.79644737_mps,
+ -150.79644737_mps, 150.79644737_mps};
auto chassisSpeeds = kinematics.ToChassisSpeeds(wheelSpeeds);
- /*
- By equation (13.13) of the state-space-guide, the chassis motion from wheel
- velocities: fl -106.629191 fr 106.629191 rl -106.629191 rr 106.629191 should
- be [[0][0][2pi]]
- */
-
- EXPECT_NEAR(0.0, chassisSpeeds.vx.to<double>(), 0.1);
- EXPECT_NEAR(0.0, chassisSpeeds.vy.to<double>(), 0.1);
- EXPECT_NEAR(2 * wpi::math::pi, chassisSpeeds.omega.to<double>(), 0.1);
+ EXPECT_NEAR(0.0, chassisSpeeds.vx.value(), 0.1);
+ EXPECT_NEAR(0.0, chassisSpeeds.vy.value(), 0.1);
+ EXPECT_NEAR(2 * wpi::numbers::pi, chassisSpeeds.omega.value(), 0.1);
}
TEST_F(MecanumDriveKinematicsTest, MixedRotationTranslationInverseKinematics) {
ChassisSpeeds speeds{2_mps, 3_mps, 1_rad_per_s};
auto moduleStates = kinematics.ToWheelSpeeds(speeds);
- /*
- By equation (13.12) of the state-space-guide, the wheel speeds should
- be as follows:
- velocities: fl -17.677670 fr 20.506097 rl -13.435029 rr 16.263456
- */
-
- EXPECT_NEAR(-17.677670, moduleStates.frontLeft.to<double>(), 0.1);
- EXPECT_NEAR(20.506097, moduleStates.frontRight.to<double>(), 0.1);
- EXPECT_NEAR(-13.435, moduleStates.rearLeft.to<double>(), 0.1);
- EXPECT_NEAR(16.26, moduleStates.rearRight.to<double>(), 0.1);
+ EXPECT_NEAR(-25.0, moduleStates.frontLeft.value(), 0.1);
+ EXPECT_NEAR(29.0, moduleStates.frontRight.value(), 0.1);
+ EXPECT_NEAR(-19.0, moduleStates.rearLeft.value(), 0.1);
+ EXPECT_NEAR(23.0, moduleStates.rearRight.value(), 0.1);
}
TEST_F(MecanumDriveKinematicsTest, MixedRotationTranslationForwardKinematics) {
@@ -143,31 +96,19 @@
auto chassisSpeeds = kinematics.ToChassisSpeeds(wheelSpeeds);
- /*
- By equation (13.13) of the state-space-guide, the chassis motion from wheel
- velocities: fl -17.677670 fr 20.506097 rl -13.435029 rr 16.263456 should be
- [[2][3][1]]
- */
-
- EXPECT_NEAR(2.0, chassisSpeeds.vx.to<double>(), 0.1);
- EXPECT_NEAR(3.0, chassisSpeeds.vy.to<double>(), 0.1);
- EXPECT_NEAR(1.0, chassisSpeeds.omega.to<double>(), 0.1);
+ EXPECT_NEAR(1.41335, chassisSpeeds.vx.value(), 0.1);
+ EXPECT_NEAR(2.1221, chassisSpeeds.vy.value(), 0.1);
+ EXPECT_NEAR(0.707, chassisSpeeds.omega.value(), 0.1);
}
TEST_F(MecanumDriveKinematicsTest, OffCenterRotationInverseKinematics) {
ChassisSpeeds speeds{0_mps, 0_mps, 1_rad_per_s};
auto moduleStates = kinematics.ToWheelSpeeds(speeds, m_fl);
- /*
- By equation (13.12) of the state-space-guide, the wheel speeds should
- be as follows:
- velocities: fl 0.000000 fr 16.970563 rl -16.970563 rr 33.941125
- */
-
- EXPECT_NEAR(0, moduleStates.frontLeft.to<double>(), 0.1);
- EXPECT_NEAR(16.971, moduleStates.frontRight.to<double>(), 0.1);
- EXPECT_NEAR(-16.971, moduleStates.rearLeft.to<double>(), 0.1);
- EXPECT_NEAR(33.941, moduleStates.rearRight.to<double>(), 0.1);
+ EXPECT_NEAR(0, moduleStates.frontLeft.value(), 0.1);
+ EXPECT_NEAR(24.0, moduleStates.frontRight.value(), 0.1);
+ EXPECT_NEAR(-24.0, moduleStates.rearLeft.value(), 0.1);
+ EXPECT_NEAR(48.0, moduleStates.rearRight.value(), 0.1);
}
TEST_F(MecanumDriveKinematicsTest, OffCenterRotationForwardKinematics) {
@@ -175,14 +116,9 @@
33.941_mps};
auto chassisSpeeds = kinematics.ToChassisSpeeds(wheelSpeeds);
- /*
- By equation (13.13) of the state-space-guide, the chassis motion from the
- wheel velocities should be [[12][-12][1]]
- */
-
- EXPECT_NEAR(12.0, chassisSpeeds.vx.to<double>(), 0.1);
- EXPECT_NEAR(-12, chassisSpeeds.vy.to<double>(), 0.1);
- EXPECT_NEAR(1.0, chassisSpeeds.omega.to<double>(), 0.1);
+ EXPECT_NEAR(8.48525, chassisSpeeds.vx.value(), 0.1);
+ EXPECT_NEAR(-8.48525, chassisSpeeds.vy.value(), 0.1);
+ EXPECT_NEAR(0.707, chassisSpeeds.omega.value(), 0.1);
}
TEST_F(MecanumDriveKinematicsTest,
@@ -190,15 +126,10 @@
ChassisSpeeds speeds{5_mps, 2_mps, 1_rad_per_s};
auto moduleStates = kinematics.ToWheelSpeeds(speeds, m_fl);
- /*
- By equation (13.12) of the state-space-guide, the wheel speeds should
- be as follows:
- velocities: fl 2.121320 fr 21.920310 rl -12.020815 rr 36.062446
- */
- EXPECT_NEAR(2.12, moduleStates.frontLeft.to<double>(), 0.1);
- EXPECT_NEAR(21.92, moduleStates.frontRight.to<double>(), 0.1);
- EXPECT_NEAR(-12.02, moduleStates.rearLeft.to<double>(), 0.1);
- EXPECT_NEAR(36.06, moduleStates.rearRight.to<double>(), 0.1);
+ EXPECT_NEAR(3.0, moduleStates.frontLeft.value(), 0.1);
+ EXPECT_NEAR(31.0, moduleStates.frontRight.value(), 0.1);
+ EXPECT_NEAR(-17.0, moduleStates.rearLeft.value(), 0.1);
+ EXPECT_NEAR(51.0, moduleStates.rearRight.value(), 0.1);
}
TEST_F(MecanumDriveKinematicsTest,
@@ -207,24 +138,19 @@
36.06_mps};
auto chassisSpeeds = kinematics.ToChassisSpeeds(wheelSpeeds);
- /*
- By equation (13.13) of the state-space-guide, the chassis motion from the
- wheel velocities should be [[17][-10][1]]
- */
-
- EXPECT_NEAR(17.0, chassisSpeeds.vx.to<double>(), 0.1);
- EXPECT_NEAR(-10, chassisSpeeds.vy.to<double>(), 0.1);
- EXPECT_NEAR(1.0, chassisSpeeds.omega.to<double>(), 0.1);
+ EXPECT_NEAR(12.02, chassisSpeeds.vx.value(), 0.1);
+ EXPECT_NEAR(-7.07, chassisSpeeds.vy.value(), 0.1);
+ EXPECT_NEAR(0.707, chassisSpeeds.omega.value(), 0.1);
}
-TEST_F(MecanumDriveKinematicsTest, NormalizeTest) {
+TEST_F(MecanumDriveKinematicsTest, Normalize) {
MecanumDriveWheelSpeeds wheelSpeeds{5_mps, 6_mps, 4_mps, 7_mps};
wheelSpeeds.Normalize(5.5_mps);
double kFactor = 5.5 / 7.0;
- EXPECT_NEAR(wheelSpeeds.frontLeft.to<double>(), 5.0 * kFactor, 1E-9);
- EXPECT_NEAR(wheelSpeeds.frontRight.to<double>(), 6.0 * kFactor, 1E-9);
- EXPECT_NEAR(wheelSpeeds.rearLeft.to<double>(), 4.0 * kFactor, 1E-9);
- EXPECT_NEAR(wheelSpeeds.rearRight.to<double>(), 7.0 * kFactor, 1E-9);
+ EXPECT_NEAR(wheelSpeeds.frontLeft.value(), 5.0 * kFactor, 1E-9);
+ EXPECT_NEAR(wheelSpeeds.frontRight.value(), 6.0 * kFactor, 1E-9);
+ EXPECT_NEAR(wheelSpeeds.rearLeft.value(), 4.0 * kFactor, 1E-9);
+ EXPECT_NEAR(wheelSpeeds.rearRight.value(), 7.0 * kFactor, 1E-9);
}
diff --git a/wpimath/src/test/native/cpp/kinematics/MecanumDriveOdometryTest.cpp b/wpimath/src/test/native/cpp/kinematics/MecanumDriveOdometryTest.cpp
index cb85ec7..152506d 100644
--- a/wpimath/src/test/native/cpp/kinematics/MecanumDriveOdometryTest.cpp
+++ b/wpimath/src/test/native/cpp/kinematics/MecanumDriveOdometryTest.cpp
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include "frc/kinematics/MecanumDriveOdometry.h"
#include "gtest/gtest.h"
@@ -29,9 +26,9 @@
odometry.UpdateWithTime(0_s, Rotation2d(), wheelSpeeds);
auto secondPose = odometry.UpdateWithTime(0.0_s, Rotation2d(), wheelSpeeds);
- EXPECT_NEAR(secondPose.X().to<double>(), 0.0, 0.01);
- EXPECT_NEAR(secondPose.Y().to<double>(), 0.0, 0.01);
- EXPECT_NEAR(secondPose.Rotation().Radians().to<double>(), 0.0, 0.01);
+ EXPECT_NEAR(secondPose.X().value(), 0.0, 0.01);
+ EXPECT_NEAR(secondPose.Y().value(), 0.0, 0.01);
+ EXPECT_NEAR(secondPose.Rotation().Radians().value(), 0.0, 0.01);
}
TEST_F(MecanumDriveOdometryTest, TwoIterations) {
@@ -41,21 +38,21 @@
odometry.UpdateWithTime(0_s, Rotation2d(), MecanumDriveWheelSpeeds{});
auto pose = odometry.UpdateWithTime(0.10_s, Rotation2d(), speeds);
- EXPECT_NEAR(pose.X().to<double>(), 0.5, 0.01);
- EXPECT_NEAR(pose.Y().to<double>(), 0.0, 0.01);
- EXPECT_NEAR(pose.Rotation().Radians().to<double>(), 0.0, 0.01);
+ EXPECT_NEAR(pose.X().value(), 0.3536, 0.01);
+ EXPECT_NEAR(pose.Y().value(), 0.0, 0.01);
+ EXPECT_NEAR(pose.Rotation().Radians().value(), 0.0, 0.01);
}
-TEST_F(MecanumDriveOdometryTest, Test90DegreeTurn) {
+TEST_F(MecanumDriveOdometryTest, 90DegreeTurn) {
odometry.ResetPosition(Pose2d(), 0_rad);
MecanumDriveWheelSpeeds speeds{-13.328_mps, 39.986_mps, -13.329_mps,
39.986_mps};
odometry.UpdateWithTime(0_s, Rotation2d(), MecanumDriveWheelSpeeds{});
auto pose = odometry.UpdateWithTime(1_s, Rotation2d(90_deg), speeds);
- EXPECT_NEAR(pose.X().to<double>(), 12, 0.01);
- EXPECT_NEAR(pose.Y().to<double>(), 12, 0.01);
- EXPECT_NEAR(pose.Rotation().Degrees().to<double>(), 90.0, 0.01);
+ EXPECT_NEAR(pose.X().value(), 8.4855, 0.01);
+ EXPECT_NEAR(pose.Y().value(), 8.4855, 0.01);
+ EXPECT_NEAR(pose.Rotation().Degrees().value(), 90.0, 0.01);
}
TEST_F(MecanumDriveOdometryTest, GyroAngleReset) {
@@ -66,7 +63,7 @@
odometry.UpdateWithTime(0_s, Rotation2d(90_deg), MecanumDriveWheelSpeeds{});
auto pose = odometry.UpdateWithTime(0.10_s, Rotation2d(90_deg), speeds);
- EXPECT_NEAR(pose.X().to<double>(), 0.5, 0.01);
- EXPECT_NEAR(pose.Y().to<double>(), 0.0, 0.01);
- EXPECT_NEAR(pose.Rotation().Radians().to<double>(), 0.0, 0.01);
+ EXPECT_NEAR(pose.X().value(), 0.3536, 0.01);
+ EXPECT_NEAR(pose.Y().value(), 0.0, 0.01);
+ EXPECT_NEAR(pose.Rotation().Radians().value(), 0.0, 0.01);
}
diff --git a/wpimath/src/test/native/cpp/kinematics/SwerveDriveKinematicsTest.cpp b/wpimath/src/test/native/cpp/kinematics/SwerveDriveKinematicsTest.cpp
index 368dbaf..9384d89 100644
--- a/wpimath/src/test/native/cpp/kinematics/SwerveDriveKinematicsTest.cpp
+++ b/wpimath/src/test/native/cpp/kinematics/SwerveDriveKinematicsTest.cpp
@@ -1,11 +1,8 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
-#include <wpi/math>
+#include <wpi/numbers>
#include "frc/geometry/Translation2d.h"
#include "frc/kinematics/SwerveDriveKinematics.h"
@@ -31,15 +28,15 @@
auto [fl, fr, bl, br] = m_kinematics.ToSwerveModuleStates(speeds);
- EXPECT_NEAR(fl.speed.to<double>(), 5.0, kEpsilon);
- EXPECT_NEAR(fr.speed.to<double>(), 5.0, kEpsilon);
- EXPECT_NEAR(bl.speed.to<double>(), 5.0, kEpsilon);
- EXPECT_NEAR(br.speed.to<double>(), 5.0, kEpsilon);
+ EXPECT_NEAR(fl.speed.value(), 5.0, kEpsilon);
+ EXPECT_NEAR(fr.speed.value(), 5.0, kEpsilon);
+ EXPECT_NEAR(bl.speed.value(), 5.0, kEpsilon);
+ EXPECT_NEAR(br.speed.value(), 5.0, kEpsilon);
- EXPECT_NEAR(fl.angle.Radians().to<double>(), 0.0, kEpsilon);
- EXPECT_NEAR(fr.angle.Radians().to<double>(), 0.0, kEpsilon);
- EXPECT_NEAR(bl.angle.Radians().to<double>(), 0.0, kEpsilon);
- EXPECT_NEAR(br.angle.Radians().to<double>(), 0.0, kEpsilon);
+ EXPECT_NEAR(fl.angle.Radians().value(), 0.0, kEpsilon);
+ EXPECT_NEAR(fr.angle.Radians().value(), 0.0, kEpsilon);
+ EXPECT_NEAR(bl.angle.Radians().value(), 0.0, kEpsilon);
+ EXPECT_NEAR(br.angle.Radians().value(), 0.0, kEpsilon);
}
TEST_F(SwerveDriveKinematicsTest, StraightLineForwardKinematics) {
@@ -47,49 +44,49 @@
auto chassisSpeeds = m_kinematics.ToChassisSpeeds(state, state, state, state);
- EXPECT_NEAR(chassisSpeeds.vx.to<double>(), 5.0, kEpsilon);
- EXPECT_NEAR(chassisSpeeds.vy.to<double>(), 0.0, kEpsilon);
- EXPECT_NEAR(chassisSpeeds.omega.to<double>(), 0.0, kEpsilon);
+ EXPECT_NEAR(chassisSpeeds.vx.value(), 5.0, kEpsilon);
+ EXPECT_NEAR(chassisSpeeds.vy.value(), 0.0, kEpsilon);
+ EXPECT_NEAR(chassisSpeeds.omega.value(), 0.0, kEpsilon);
}
TEST_F(SwerveDriveKinematicsTest, StraightStrafeInverseKinematics) {
ChassisSpeeds speeds{0_mps, 5_mps, 0_rad_per_s};
auto [fl, fr, bl, br] = m_kinematics.ToSwerveModuleStates(speeds);
- EXPECT_NEAR(fl.speed.to<double>(), 5.0, kEpsilon);
- EXPECT_NEAR(fr.speed.to<double>(), 5.0, kEpsilon);
- EXPECT_NEAR(bl.speed.to<double>(), 5.0, kEpsilon);
- EXPECT_NEAR(br.speed.to<double>(), 5.0, kEpsilon);
+ EXPECT_NEAR(fl.speed.value(), 5.0, kEpsilon);
+ EXPECT_NEAR(fr.speed.value(), 5.0, kEpsilon);
+ EXPECT_NEAR(bl.speed.value(), 5.0, kEpsilon);
+ EXPECT_NEAR(br.speed.value(), 5.0, kEpsilon);
- EXPECT_NEAR(fl.angle.Degrees().to<double>(), 90.0, kEpsilon);
- EXPECT_NEAR(fr.angle.Degrees().to<double>(), 90.0, kEpsilon);
- EXPECT_NEAR(bl.angle.Degrees().to<double>(), 90.0, kEpsilon);
- EXPECT_NEAR(br.angle.Degrees().to<double>(), 90.0, kEpsilon);
+ EXPECT_NEAR(fl.angle.Degrees().value(), 90.0, kEpsilon);
+ EXPECT_NEAR(fr.angle.Degrees().value(), 90.0, kEpsilon);
+ EXPECT_NEAR(bl.angle.Degrees().value(), 90.0, kEpsilon);
+ EXPECT_NEAR(br.angle.Degrees().value(), 90.0, kEpsilon);
}
TEST_F(SwerveDriveKinematicsTest, StraightStrafeForwardKinematics) {
SwerveModuleState state{5_mps, Rotation2d(90_deg)};
auto chassisSpeeds = m_kinematics.ToChassisSpeeds(state, state, state, state);
- EXPECT_NEAR(chassisSpeeds.vx.to<double>(), 0.0, kEpsilon);
- EXPECT_NEAR(chassisSpeeds.vy.to<double>(), 5.0, kEpsilon);
- EXPECT_NEAR(chassisSpeeds.omega.to<double>(), 0.0, kEpsilon);
+ EXPECT_NEAR(chassisSpeeds.vx.value(), 0.0, kEpsilon);
+ EXPECT_NEAR(chassisSpeeds.vy.value(), 5.0, kEpsilon);
+ EXPECT_NEAR(chassisSpeeds.omega.value(), 0.0, kEpsilon);
}
TEST_F(SwerveDriveKinematicsTest, TurnInPlaceInverseKinematics) {
ChassisSpeeds speeds{0_mps, 0_mps,
- units::radians_per_second_t(2 * wpi::math::pi)};
+ units::radians_per_second_t(2 * wpi::numbers::pi)};
auto [fl, fr, bl, br] = m_kinematics.ToSwerveModuleStates(speeds);
- EXPECT_NEAR(fl.speed.to<double>(), 106.63, kEpsilon);
- EXPECT_NEAR(fr.speed.to<double>(), 106.63, kEpsilon);
- EXPECT_NEAR(bl.speed.to<double>(), 106.63, kEpsilon);
- EXPECT_NEAR(br.speed.to<double>(), 106.63, kEpsilon);
+ EXPECT_NEAR(fl.speed.value(), 106.63, kEpsilon);
+ EXPECT_NEAR(fr.speed.value(), 106.63, kEpsilon);
+ EXPECT_NEAR(bl.speed.value(), 106.63, kEpsilon);
+ EXPECT_NEAR(br.speed.value(), 106.63, kEpsilon);
- EXPECT_NEAR(fl.angle.Degrees().to<double>(), 135.0, kEpsilon);
- EXPECT_NEAR(fr.angle.Degrees().to<double>(), 45.0, kEpsilon);
- EXPECT_NEAR(bl.angle.Degrees().to<double>(), -135.0, kEpsilon);
- EXPECT_NEAR(br.angle.Degrees().to<double>(), -45.0, kEpsilon);
+ EXPECT_NEAR(fl.angle.Degrees().value(), 135.0, kEpsilon);
+ EXPECT_NEAR(fr.angle.Degrees().value(), 45.0, kEpsilon);
+ EXPECT_NEAR(bl.angle.Degrees().value(), -135.0, kEpsilon);
+ EXPECT_NEAR(br.angle.Degrees().value(), -45.0, kEpsilon);
}
TEST_F(SwerveDriveKinematicsTest, TurnInPlaceForwardKinematics) {
@@ -100,25 +97,25 @@
auto chassisSpeeds = m_kinematics.ToChassisSpeeds(fl, fr, bl, br);
- EXPECT_NEAR(chassisSpeeds.vx.to<double>(), 0.0, kEpsilon);
- EXPECT_NEAR(chassisSpeeds.vy.to<double>(), 0.0, kEpsilon);
- EXPECT_NEAR(chassisSpeeds.omega.to<double>(), 2 * wpi::math::pi, kEpsilon);
+ EXPECT_NEAR(chassisSpeeds.vx.value(), 0.0, kEpsilon);
+ EXPECT_NEAR(chassisSpeeds.vy.value(), 0.0, kEpsilon);
+ EXPECT_NEAR(chassisSpeeds.omega.value(), 2 * wpi::numbers::pi, kEpsilon);
}
TEST_F(SwerveDriveKinematicsTest, OffCenterCORRotationInverseKinematics) {
ChassisSpeeds speeds{0_mps, 0_mps,
- units::radians_per_second_t(2 * wpi::math::pi)};
+ units::radians_per_second_t(2 * wpi::numbers::pi)};
auto [fl, fr, bl, br] = m_kinematics.ToSwerveModuleStates(speeds, m_fl);
- EXPECT_NEAR(fl.speed.to<double>(), 0.0, kEpsilon);
- EXPECT_NEAR(fr.speed.to<double>(), 150.796, kEpsilon);
- EXPECT_NEAR(bl.speed.to<double>(), 150.796, kEpsilon);
- EXPECT_NEAR(br.speed.to<double>(), 213.258, kEpsilon);
+ EXPECT_NEAR(fl.speed.value(), 0.0, kEpsilon);
+ EXPECT_NEAR(fr.speed.value(), 150.796, kEpsilon);
+ EXPECT_NEAR(bl.speed.value(), 150.796, kEpsilon);
+ EXPECT_NEAR(br.speed.value(), 213.258, kEpsilon);
- EXPECT_NEAR(fl.angle.Degrees().to<double>(), 0.0, kEpsilon);
- EXPECT_NEAR(fr.angle.Degrees().to<double>(), 0.0, kEpsilon);
- EXPECT_NEAR(bl.angle.Degrees().to<double>(), -90.0, kEpsilon);
- EXPECT_NEAR(br.angle.Degrees().to<double>(), -45.0, kEpsilon);
+ EXPECT_NEAR(fl.angle.Degrees().value(), 0.0, kEpsilon);
+ EXPECT_NEAR(fr.angle.Degrees().value(), 0.0, kEpsilon);
+ EXPECT_NEAR(bl.angle.Degrees().value(), -90.0, kEpsilon);
+ EXPECT_NEAR(br.angle.Degrees().value(), -45.0, kEpsilon);
}
TEST_F(SwerveDriveKinematicsTest, OffCenterCORRotationForwardKinematics) {
@@ -129,9 +126,9 @@
auto chassisSpeeds = m_kinematics.ToChassisSpeeds(fl, fr, bl, br);
- EXPECT_NEAR(chassisSpeeds.vx.to<double>(), 75.398, kEpsilon);
- EXPECT_NEAR(chassisSpeeds.vy.to<double>(), -75.398, kEpsilon);
- EXPECT_NEAR(chassisSpeeds.omega.to<double>(), 2 * wpi::math::pi, kEpsilon);
+ EXPECT_NEAR(chassisSpeeds.vx.value(), 75.398, kEpsilon);
+ EXPECT_NEAR(chassisSpeeds.vy.value(), -75.398, kEpsilon);
+ EXPECT_NEAR(chassisSpeeds.omega.value(), 2 * wpi::numbers::pi, kEpsilon);
}
TEST_F(SwerveDriveKinematicsTest,
@@ -140,15 +137,15 @@
auto [fl, fr, bl, br] =
m_kinematics.ToSwerveModuleStates(speeds, Translation2d(24_m, 0_m));
- EXPECT_NEAR(fl.speed.to<double>(), 23.43, kEpsilon);
- EXPECT_NEAR(fr.speed.to<double>(), 23.43, kEpsilon);
- EXPECT_NEAR(bl.speed.to<double>(), 54.08, kEpsilon);
- EXPECT_NEAR(br.speed.to<double>(), 54.08, kEpsilon);
+ EXPECT_NEAR(fl.speed.value(), 23.43, kEpsilon);
+ EXPECT_NEAR(fr.speed.value(), 23.43, kEpsilon);
+ EXPECT_NEAR(bl.speed.value(), 54.08, kEpsilon);
+ EXPECT_NEAR(br.speed.value(), 54.08, kEpsilon);
- EXPECT_NEAR(fl.angle.Degrees().to<double>(), -140.19, kEpsilon);
- EXPECT_NEAR(fr.angle.Degrees().to<double>(), -39.81, kEpsilon);
- EXPECT_NEAR(bl.angle.Degrees().to<double>(), -109.44, kEpsilon);
- EXPECT_NEAR(br.angle.Degrees().to<double>(), -70.56, kEpsilon);
+ EXPECT_NEAR(fl.angle.Degrees().value(), -140.19, kEpsilon);
+ EXPECT_NEAR(fr.angle.Degrees().value(), -39.81, kEpsilon);
+ EXPECT_NEAR(bl.angle.Degrees().value(), -109.44, kEpsilon);
+ EXPECT_NEAR(br.angle.Degrees().value(), -70.56, kEpsilon);
}
TEST_F(SwerveDriveKinematicsTest,
@@ -160,24 +157,24 @@
auto chassisSpeeds = m_kinematics.ToChassisSpeeds(fl, fr, bl, br);
- EXPECT_NEAR(chassisSpeeds.vx.to<double>(), 0.0, kEpsilon);
- EXPECT_NEAR(chassisSpeeds.vy.to<double>(), -33.0, kEpsilon);
- EXPECT_NEAR(chassisSpeeds.omega.to<double>(), 1.5, kEpsilon);
+ EXPECT_NEAR(chassisSpeeds.vx.value(), 0.0, kEpsilon);
+ EXPECT_NEAR(chassisSpeeds.vy.value(), -33.0, kEpsilon);
+ EXPECT_NEAR(chassisSpeeds.omega.value(), 1.5, kEpsilon);
}
-TEST_F(SwerveDriveKinematicsTest, NormalizeTest) {
+TEST_F(SwerveDriveKinematicsTest, Normalize) {
SwerveModuleState state1{5.0_mps, Rotation2d()};
SwerveModuleState state2{6.0_mps, Rotation2d()};
SwerveModuleState state3{4.0_mps, Rotation2d()};
SwerveModuleState state4{7.0_mps, Rotation2d()};
- std::array<SwerveModuleState, 4> arr{state1, state2, state3, state4};
+ wpi::array<SwerveModuleState, 4> arr{state1, state2, state3, state4};
SwerveDriveKinematics<4>::NormalizeWheelSpeeds(&arr, 5.5_mps);
double kFactor = 5.5 / 7.0;
- EXPECT_NEAR(arr[0].speed.to<double>(), 5.0 * kFactor, kEpsilon);
- EXPECT_NEAR(arr[1].speed.to<double>(), 6.0 * kFactor, kEpsilon);
- EXPECT_NEAR(arr[2].speed.to<double>(), 4.0 * kFactor, kEpsilon);
- EXPECT_NEAR(arr[3].speed.to<double>(), 7.0 * kFactor, kEpsilon);
+ EXPECT_NEAR(arr[0].speed.value(), 5.0 * kFactor, kEpsilon);
+ EXPECT_NEAR(arr[1].speed.value(), 6.0 * kFactor, kEpsilon);
+ EXPECT_NEAR(arr[2].speed.value(), 4.0 * kFactor, kEpsilon);
+ EXPECT_NEAR(arr[3].speed.value(), 7.0 * kFactor, kEpsilon);
}
diff --git a/wpimath/src/test/native/cpp/kinematics/SwerveDriveOdometryTest.cpp b/wpimath/src/test/native/cpp/kinematics/SwerveDriveOdometryTest.cpp
index 40207a1..27d2a6e 100644
--- a/wpimath/src/test/native/cpp/kinematics/SwerveDriveOdometryTest.cpp
+++ b/wpimath/src/test/native/cpp/kinematics/SwerveDriveOdometryTest.cpp
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include "frc/kinematics/SwerveDriveKinematics.h"
#include "frc/kinematics/SwerveDriveOdometry.h"
@@ -34,9 +31,9 @@
auto pose = m_odometry.UpdateWithTime(0.1_s, Rotation2d(), state, state,
state, state);
- EXPECT_NEAR(0.5, pose.X().to<double>(), kEpsilon);
- EXPECT_NEAR(0.0, pose.Y().to<double>(), kEpsilon);
- EXPECT_NEAR(0.0, pose.Rotation().Degrees().to<double>(), kEpsilon);
+ EXPECT_NEAR(0.5, pose.X().value(), kEpsilon);
+ EXPECT_NEAR(0.0, pose.Y().value(), kEpsilon);
+ EXPECT_NEAR(0.0, pose.Rotation().Degrees().value(), kEpsilon);
}
TEST_F(SwerveDriveOdometryTest, 90DegreeTurn) {
@@ -52,9 +49,9 @@
auto pose =
m_odometry.UpdateWithTime(1_s, Rotation2d(90_deg), fl, fr, bl, br);
- EXPECT_NEAR(12.0, pose.X().to<double>(), kEpsilon);
- EXPECT_NEAR(12.0, pose.Y().to<double>(), kEpsilon);
- EXPECT_NEAR(90.0, pose.Rotation().Degrees().to<double>(), kEpsilon);
+ EXPECT_NEAR(12.0, pose.X().value(), kEpsilon);
+ EXPECT_NEAR(12.0, pose.Y().value(), kEpsilon);
+ EXPECT_NEAR(90.0, pose.Rotation().Degrees().value(), kEpsilon);
}
TEST_F(SwerveDriveOdometryTest, GyroAngleReset) {
@@ -68,7 +65,7 @@
auto pose = m_odometry.UpdateWithTime(0.1_s, Rotation2d(90_deg), state, state,
state, state);
- EXPECT_NEAR(0.5, pose.X().to<double>(), kEpsilon);
- EXPECT_NEAR(0.0, pose.Y().to<double>(), kEpsilon);
- EXPECT_NEAR(0.0, pose.Rotation().Degrees().to<double>(), kEpsilon);
+ EXPECT_NEAR(0.5, pose.X().value(), kEpsilon);
+ EXPECT_NEAR(0.0, pose.Y().value(), kEpsilon);
+ EXPECT_NEAR(0.0, pose.Rotation().Degrees().value(), kEpsilon);
}
diff --git a/wpimath/src/test/native/cpp/kinematics/SwerveModuleStateTest.cpp b/wpimath/src/test/native/cpp/kinematics/SwerveModuleStateTest.cpp
new file mode 100644
index 0000000..4880bef
--- /dev/null
+++ b/wpimath/src/test/native/cpp/kinematics/SwerveModuleStateTest.cpp
@@ -0,0 +1,41 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+#include "frc/geometry/Rotation2d.h"
+#include "frc/kinematics/SwerveModuleState.h"
+#include "gtest/gtest.h"
+
+static constexpr double kEpsilon = 1E-9;
+
+TEST(SwerveModuleStateTest, Optimize) {
+ frc::Rotation2d angleA{45_deg};
+ frc::SwerveModuleState refA{-2_mps, 180_deg};
+ auto optimizedA = frc::SwerveModuleState::Optimize(refA, angleA);
+
+ EXPECT_NEAR(optimizedA.speed.value(), 2.0, kEpsilon);
+ EXPECT_NEAR(optimizedA.angle.Degrees().value(), 0.0, kEpsilon);
+
+ frc::Rotation2d angleB{-50_deg};
+ frc::SwerveModuleState refB{4.7_mps, 41_deg};
+ auto optimizedB = frc::SwerveModuleState::Optimize(refB, angleB);
+
+ EXPECT_NEAR(optimizedB.speed.value(), -4.7, kEpsilon);
+ EXPECT_NEAR(optimizedB.angle.Degrees().value(), -139.0, kEpsilon);
+}
+
+TEST(SwerveModuleStateTest, NoOptimize) {
+ frc::Rotation2d angleA{0_deg};
+ frc::SwerveModuleState refA{2_mps, 89_deg};
+ auto optimizedA = frc::SwerveModuleState::Optimize(refA, angleA);
+
+ EXPECT_NEAR(optimizedA.speed.value(), 2.0, kEpsilon);
+ EXPECT_NEAR(optimizedA.angle.Degrees().value(), 89.0, kEpsilon);
+
+ frc::Rotation2d angleB{0_deg};
+ frc::SwerveModuleState refB{-2_mps, -2_deg};
+ auto optimizedB = frc::SwerveModuleState::Optimize(refB, angleB);
+
+ EXPECT_NEAR(optimizedB.speed.value(), -2.0, kEpsilon);
+ EXPECT_NEAR(optimizedB.angle.Degrees().value(), -2.0, kEpsilon);
+}
diff --git a/wpimath/src/test/native/cpp/main.cpp b/wpimath/src/test/native/cpp/main.cpp
index e2126f2..09072ee 100644
--- a/wpimath/src/test/native/cpp/main.cpp
+++ b/wpimath/src/test/native/cpp/main.cpp
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2015-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include "gtest/gtest.h"
diff --git a/wpimath/src/test/native/cpp/spline/CubicHermiteSplineTest.cpp b/wpimath/src/test/native/cpp/spline/CubicHermiteSplineTest.cpp
index fe084e0..69e202f 100644
--- a/wpimath/src/test/native/cpp/spline/CubicHermiteSplineTest.cpp
+++ b/wpimath/src/test/native/cpp/spline/CubicHermiteSplineTest.cpp
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include <chrono>
#include <iostream>
@@ -56,27 +53,26 @@
// Make sure the twist is under the tolerance defined by the Spline class.
auto twist = p0.first.Log(p1.first);
- EXPECT_LT(std::abs(twist.dx.to<double>()),
- SplineParameterizer::kMaxDx.to<double>());
- EXPECT_LT(std::abs(twist.dy.to<double>()),
- SplineParameterizer::kMaxDy.to<double>());
- EXPECT_LT(std::abs(twist.dtheta.to<double>()),
- SplineParameterizer::kMaxDtheta.to<double>());
+ EXPECT_LT(std::abs(twist.dx.value()),
+ SplineParameterizer::kMaxDx.value());
+ EXPECT_LT(std::abs(twist.dy.value()),
+ SplineParameterizer::kMaxDy.value());
+ EXPECT_LT(std::abs(twist.dtheta.value()),
+ SplineParameterizer::kMaxDtheta.value());
}
// Check first point.
- EXPECT_NEAR(poses.front().first.X().to<double>(), a.X().to<double>(), 1E-9);
- EXPECT_NEAR(poses.front().first.Y().to<double>(), a.Y().to<double>(), 1E-9);
- EXPECT_NEAR(poses.front().first.Rotation().Radians().to<double>(),
- a.Rotation().Radians().to<double>(), 1E-9);
+ EXPECT_NEAR(poses.front().first.X().value(), a.X().value(), 1E-9);
+ EXPECT_NEAR(poses.front().first.Y().value(), a.Y().value(), 1E-9);
+ EXPECT_NEAR(poses.front().first.Rotation().Radians().value(),
+ a.Rotation().Radians().value(), 1E-9);
// Check interior waypoints
bool interiorsGood = true;
for (auto& waypoint : waypoints) {
bool found = false;
for (auto& state : poses) {
- if (std::abs(
- waypoint.Distance(state.first.Translation()).to<double>()) <
+ if (std::abs(waypoint.Distance(state.first.Translation()).value()) <
1E-9) {
found = true;
}
@@ -87,10 +83,10 @@
EXPECT_TRUE(interiorsGood);
// Check last point.
- EXPECT_NEAR(poses.back().first.X().to<double>(), b.X().to<double>(), 1E-9);
- EXPECT_NEAR(poses.back().first.Y().to<double>(), b.Y().to<double>(), 1E-9);
- EXPECT_NEAR(poses.back().first.Rotation().Radians().to<double>(),
- b.Rotation().Radians().to<double>(), 1E-9);
+ EXPECT_NEAR(poses.back().first.X().value(), b.X().value(), 1E-9);
+ EXPECT_NEAR(poses.back().first.Y().value(), b.Y().value(), 1E-9);
+ EXPECT_NEAR(poses.back().first.Rotation().Radians().value(),
+ b.Rotation().Radians().value(), 1E-9);
static_cast<void>(duration);
}
diff --git a/wpimath/src/test/native/cpp/spline/QuinticHermiteSplineTest.cpp b/wpimath/src/test/native/cpp/spline/QuinticHermiteSplineTest.cpp
index 30b5b31..25449fb 100644
--- a/wpimath/src/test/native/cpp/spline/QuinticHermiteSplineTest.cpp
+++ b/wpimath/src/test/native/cpp/spline/QuinticHermiteSplineTest.cpp
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include <chrono>
#include <iostream>
@@ -43,25 +40,25 @@
// Make sure the twist is under the tolerance defined by the Spline class.
auto twist = p0.first.Log(p1.first);
- EXPECT_LT(std::abs(twist.dx.to<double>()),
- SplineParameterizer::kMaxDx.to<double>());
- EXPECT_LT(std::abs(twist.dy.to<double>()),
- SplineParameterizer::kMaxDy.to<double>());
- EXPECT_LT(std::abs(twist.dtheta.to<double>()),
- SplineParameterizer::kMaxDtheta.to<double>());
+ EXPECT_LT(std::abs(twist.dx.value()),
+ SplineParameterizer::kMaxDx.value());
+ EXPECT_LT(std::abs(twist.dy.value()),
+ SplineParameterizer::kMaxDy.value());
+ EXPECT_LT(std::abs(twist.dtheta.value()),
+ SplineParameterizer::kMaxDtheta.value());
}
// Check first point.
- EXPECT_NEAR(poses.front().first.X().to<double>(), a.X().to<double>(), 1E-9);
- EXPECT_NEAR(poses.front().first.Y().to<double>(), a.Y().to<double>(), 1E-9);
- EXPECT_NEAR(poses.front().first.Rotation().Radians().to<double>(),
- a.Rotation().Radians().to<double>(), 1E-9);
+ EXPECT_NEAR(poses.front().first.X().value(), a.X().value(), 1E-9);
+ EXPECT_NEAR(poses.front().first.Y().value(), a.Y().value(), 1E-9);
+ EXPECT_NEAR(poses.front().first.Rotation().Radians().value(),
+ a.Rotation().Radians().value(), 1E-9);
// Check last point.
- EXPECT_NEAR(poses.back().first.X().to<double>(), b.X().to<double>(), 1E-9);
- EXPECT_NEAR(poses.back().first.Y().to<double>(), b.Y().to<double>(), 1E-9);
- EXPECT_NEAR(poses.back().first.Rotation().Radians().to<double>(),
- b.Rotation().Radians().to<double>(), 1E-9);
+ EXPECT_NEAR(poses.back().first.X().value(), b.X().value(), 1E-9);
+ EXPECT_NEAR(poses.back().first.Y().value(), b.Y().value(), 1E-9);
+ EXPECT_NEAR(poses.back().first.Rotation().Radians().value(),
+ b.Rotation().Radians().value(), 1E-9);
static_cast<void>(duration);
}
diff --git a/wpimath/src/test/native/cpp/system/DiscretizationTest.cpp b/wpimath/src/test/native/cpp/system/DiscretizationTest.cpp
index dbeb518..b5a0fdc 100644
--- a/wpimath/src/test/native/cpp/system/DiscretizationTest.cpp
+++ b/wpimath/src/test/native/cpp/system/DiscretizationTest.cpp
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include <gtest/gtest.h>
@@ -12,25 +9,23 @@
#include "Eigen/Core"
#include "Eigen/Eigenvalues"
#include "frc/system/Discretization.h"
-#include "frc/system/RungeKutta.h"
+#include "frc/system/NumericalIntegration.h"
+#include "frc/system/RungeKuttaTimeVarying.h"
// Check that for a simple second-order system that we can easily analyze
// analytically,
TEST(DiscretizationTest, DiscretizeA) {
- Eigen::Matrix<double, 2, 2> contA;
- contA << 0, 1, 0, 0;
+ Eigen::Matrix<double, 2, 2> contA{{0, 1}, {0, 0}};
- Eigen::Matrix<double, 2, 1> x0;
- x0 << 1, 1;
+ Eigen::Vector<double, 2> x0{1, 1};
Eigen::Matrix<double, 2, 2> discA;
frc::DiscretizeA<2>(contA, 1_s, &discA);
- Eigen::Matrix<double, 2, 1> x1Discrete = discA * x0;
+ Eigen::Vector<double, 2> x1Discrete = discA * x0;
// We now have pos = vel = 1 and accel = 0, which should give us:
- Eigen::Matrix<double, 2, 1> x1Truth;
- x1Truth(1) = x0(1);
- x1Truth(0) = x0(0) + 1.0 * x0(1);
+ Eigen::Vector<double, 2> x1Truth{1.0 * x0(0) + 1.0 * x0(1),
+ 0.0 * x0(0) + 1.0 * x0(1)};
EXPECT_EQ(x1Truth, x1Discrete);
}
@@ -38,38 +33,30 @@
// Check that for a simple second-order system that we can easily analyze
// analytically,
TEST(DiscretizationTest, DiscretizeAB) {
- Eigen::Matrix<double, 2, 2> contA;
- contA << 0, 1, 0, 0;
+ Eigen::Matrix<double, 2, 2> contA{{0, 1}, {0, 0}};
+ Eigen::Matrix<double, 2, 1> contB{0, 1};
- Eigen::Matrix<double, 2, 1> contB;
- contB << 0, 1;
-
- Eigen::Matrix<double, 2, 1> x0;
- x0 << 1, 1;
- Eigen::Matrix<double, 1, 1> u;
- u << 1;
+ Eigen::Vector<double, 2> x0{1, 1};
+ Eigen::Vector<double, 1> u{1};
Eigen::Matrix<double, 2, 2> discA;
Eigen::Matrix<double, 2, 1> discB;
frc::DiscretizeAB<2, 1>(contA, contB, 1_s, &discA, &discB);
- Eigen::Matrix<double, 2, 1> x1Discrete = discA * x0 + discB * u;
+ Eigen::Vector<double, 2> x1Discrete = discA * x0 + discB * u;
// We now have pos = vel = accel = 1, which should give us:
- Eigen::Matrix<double, 2, 1> x1Truth;
- x1Truth(1) = x0(1) + 1.0 * u(0);
- x1Truth(0) = x0(0) + 1.0 * x0(1) + 0.5 * u(0);
+ Eigen::Vector<double, 2> x1Truth{1.0 * x0(0) + 1.0 * x0(1) + 0.5 * u(0),
+ 0.0 * x0(0) + 1.0 * x0(1) + 1.0 * u(0)};
EXPECT_EQ(x1Truth, x1Discrete);
}
-// Test that the discrete approximation of Q is roughly equal to
-// integral from 0 to dt of e^(A tau) Q e^(A.T tau) dtau
+// dt
+// Test that the discrete approximation of Q ≈ ∫ e^(Aτ) Q e^(Aᵀτ) dτ
+// 0
TEST(DiscretizationTest, DiscretizeSlowModelAQ) {
- Eigen::Matrix<double, 2, 2> contA;
- contA << 0, 1, 0, 0;
-
- Eigen::Matrix<double, 2, 2> contQ;
- contQ << 1, 0, 0, 1;
+ Eigen::Matrix<double, 2, 2> contA{{0, 1}, {0, 0}};
+ Eigen::Matrix<double, 2, 2> contQ{{1, 0}, {0, 1}};
constexpr auto dt = 1_s;
@@ -79,10 +66,10 @@
Eigen::Matrix<double, 2, 2>>(
[&](units::second_t t, const Eigen::Matrix<double, 2, 2>&) {
return Eigen::Matrix<double, 2, 2>(
- (contA * t.to<double>()).exp() * contQ *
- (contA.transpose() * t.to<double>()).exp());
+ (contA * t.value()).exp() * contQ *
+ (contA.transpose() * t.value()).exp());
},
- Eigen::Matrix<double, 2, 2>::Zero(), 0_s, dt);
+ 0_s, Eigen::Matrix<double, 2, 2>::Zero(), dt);
Eigen::Matrix<double, 2, 2> discA;
Eigen::Matrix<double, 2, 2> discQ;
@@ -94,16 +81,14 @@
<< discQIntegrated;
}
-// Test that the discrete approximation of Q is roughly equal to
-// integral from 0 to dt of e^(A tau) Q e^(A.T tau) dtau
+// dt
+// Test that the discrete approximation of Q ≈ ∫ e^(Aτ) Q e^(Aᵀτ) dτ
+// 0
TEST(DiscretizationTest, DiscretizeFastModelAQ) {
- Eigen::Matrix<double, 2, 2> contA;
- contA << 0, 1, 0, -1406.29;
+ Eigen::Matrix<double, 2, 2> contA{{0, 1}, {0, -1406.29}};
+ Eigen::Matrix<double, 2, 2> contQ{{0.0025, 0}, {0, 1}};
- Eigen::Matrix<double, 2, 2> contQ;
- contQ << 0.0025, 0, 0, 1;
-
- constexpr auto dt = 5.05_ms;
+ constexpr auto dt = 5_ms;
Eigen::Matrix<double, 2, 2> discQIntegrated = frc::RungeKuttaTimeVarying<
std::function<Eigen::Matrix<double, 2, 2>(
@@ -111,10 +96,10 @@
Eigen::Matrix<double, 2, 2>>(
[&](units::second_t t, const Eigen::Matrix<double, 2, 2>&) {
return Eigen::Matrix<double, 2, 2>(
- (contA * t.to<double>()).exp() * contQ *
- (contA.transpose() * t.to<double>()).exp());
+ (contA * t.value()).exp() * contQ *
+ (contA.transpose() * t.value()).exp());
},
- Eigen::Matrix<double, 2, 2>::Zero(), 0_s, dt);
+ 0_s, Eigen::Matrix<double, 2, 2>::Zero(), dt);
Eigen::Matrix<double, 2, 2> discA;
Eigen::Matrix<double, 2, 2> discQ;
@@ -128,26 +113,19 @@
// Test that the Taylor series discretization produces nearly identical results.
TEST(DiscretizationTest, DiscretizeSlowModelAQTaylor) {
- Eigen::Matrix<double, 2, 2> contA;
- contA << 0, 1, 0, 0;
-
- Eigen::Matrix<double, 2, 1> contB;
- contB << 0, 1;
-
- Eigen::Matrix<double, 2, 2> contQ;
- contQ << 1, 0, 0, 1;
+ Eigen::Matrix<double, 2, 2> contA{{0, 1}, {0, 0}};
+ Eigen::Matrix<double, 2, 2> contQ{{1, 0}, {0, 1}};
constexpr auto dt = 1_s;
Eigen::Matrix<double, 2, 2> discQTaylor;
Eigen::Matrix<double, 2, 2> discA;
Eigen::Matrix<double, 2, 2> discATaylor;
- Eigen::Matrix<double, 2, 1> discB;
// Continuous Q should be positive semidefinite
- Eigen::SelfAdjointEigenSolver<Eigen::MatrixXd> esCont(contQ);
- for (int i = 0; i < contQ.rows(); i++) {
- EXPECT_GT(esCont.eigenvalues()[i], 0);
+ Eigen::SelfAdjointEigenSolver<Eigen::MatrixXd> esCont{contQ};
+ for (int i = 0; i < contQ.rows(); ++i) {
+ EXPECT_GE(esCont.eigenvalues()[i], 0);
}
Eigen::Matrix<double, 2, 2> discQIntegrated = frc::RungeKuttaTimeVarying<
@@ -156,12 +134,12 @@
Eigen::Matrix<double, 2, 2>>(
[&](units::second_t t, const Eigen::Matrix<double, 2, 2>&) {
return Eigen::Matrix<double, 2, 2>(
- (contA * t.to<double>()).exp() * contQ *
- (contA.transpose() * t.to<double>()).exp());
+ (contA * t.value()).exp() * contQ *
+ (contA.transpose() * t.value()).exp());
},
- Eigen::Matrix<double, 2, 2>::Zero(), 0_s, dt);
+ 0_s, Eigen::Matrix<double, 2, 2>::Zero(), dt);
- frc::DiscretizeAB<2, 1>(contA, contB, dt, &discA, &discB);
+ frc::DiscretizeA<2>(contA, dt, &discA);
frc::DiscretizeAQTaylor<2>(contA, contQ, dt, &discATaylor, &discQTaylor);
EXPECT_LT((discQIntegrated - discQTaylor).norm(), 1e-10)
@@ -171,34 +149,27 @@
EXPECT_LT((discA - discATaylor).norm(), 1e-10);
// Discrete Q should be positive semidefinite
- Eigen::SelfAdjointEigenSolver<Eigen::MatrixXd> esDisc(discQTaylor);
- for (int i = 0; i < discQTaylor.rows(); i++) {
- EXPECT_GT(esDisc.eigenvalues()[i], 0);
+ Eigen::SelfAdjointEigenSolver<Eigen::MatrixXd> esDisc{discQTaylor};
+ for (int i = 0; i < discQTaylor.rows(); ++i) {
+ EXPECT_GE(esDisc.eigenvalues()[i], 0);
}
}
// Test that the Taylor series discretization produces nearly identical results.
TEST(DiscretizationTest, DiscretizeFastModelAQTaylor) {
- Eigen::Matrix<double, 2, 2> contA;
- contA << 0, 1, 0, -1500;
+ Eigen::Matrix<double, 2, 2> contA{{0, 1}, {0, -1500}};
+ Eigen::Matrix<double, 2, 2> contQ{{0.0025, 0}, {0, 1}};
- Eigen::Matrix<double, 2, 1> contB;
- contB << 0, 1;
-
- Eigen::Matrix<double, 2, 2> contQ;
- contQ << 0.0025, 0, 0, 1;
-
- constexpr auto dt = 5.05_ms;
+ constexpr auto dt = 5_ms;
Eigen::Matrix<double, 2, 2> discQTaylor;
Eigen::Matrix<double, 2, 2> discA;
Eigen::Matrix<double, 2, 2> discATaylor;
- Eigen::Matrix<double, 2, 1> discB;
// Continuous Q should be positive semidefinite
Eigen::SelfAdjointEigenSolver<Eigen::MatrixXd> esCont(contQ);
- for (int i = 0; i < contQ.rows(); i++) {
- EXPECT_GT(esCont.eigenvalues()[i], 0);
+ for (int i = 0; i < contQ.rows(); ++i) {
+ EXPECT_GE(esCont.eigenvalues()[i], 0);
}
Eigen::Matrix<double, 2, 2> discQIntegrated = frc::RungeKuttaTimeVarying<
@@ -207,12 +178,12 @@
Eigen::Matrix<double, 2, 2>>(
[&](units::second_t t, const Eigen::Matrix<double, 2, 2>&) {
return Eigen::Matrix<double, 2, 2>(
- (contA * t.to<double>()).exp() * contQ *
- (contA.transpose() * t.to<double>()).exp());
+ (contA * t.value()).exp() * contQ *
+ (contA.transpose() * t.value()).exp());
},
- Eigen::Matrix<double, 2, 2>::Zero(), 0_s, dt);
+ 0_s, Eigen::Matrix<double, 2, 2>::Zero(), dt);
- frc::DiscretizeAB<2, 1>(contA, contB, dt, &discA, &discB);
+ frc::DiscretizeA<2>(contA, dt, &discA);
frc::DiscretizeAQTaylor<2>(contA, contQ, dt, &discATaylor, &discQTaylor);
EXPECT_LT((discQIntegrated - discQTaylor).norm(), 1e-3)
@@ -223,18 +194,15 @@
// Discrete Q should be positive semidefinite
Eigen::SelfAdjointEigenSolver<Eigen::MatrixXd> esDisc(discQTaylor);
- for (int i = 0; i < discQTaylor.rows(); i++) {
- EXPECT_GT(esDisc.eigenvalues()[i], 0);
+ for (int i = 0; i < discQTaylor.rows(); ++i) {
+ EXPECT_GE(esDisc.eigenvalues()[i], 0);
}
}
// Test that DiscretizeR() works
TEST(DiscretizationTest, DiscretizeR) {
- Eigen::Matrix<double, 2, 2> contR;
- contR << 2.0, 0.0, 0.0, 1.0;
-
- Eigen::Matrix<double, 2, 2> discRTruth;
- discRTruth << 4.0, 0.0, 0.0, 2.0;
+ Eigen::Matrix<double, 2, 2> contR{{2.0, 0.0}, {0.0, 1.0}};
+ Eigen::Matrix<double, 2, 2> discRTruth{{4.0, 0.0}, {0.0, 2.0}};
Eigen::Matrix<double, 2, 2> discR = frc::DiscretizeR<2>(contR, 500_ms);
diff --git a/wpimath/src/test/native/cpp/system/LinearSystemIDTest.cpp b/wpimath/src/test/native/cpp/system/LinearSystemIDTest.cpp
index eedf8c0..1fa12c2 100644
--- a/wpimath/src/test/native/cpp/system/LinearSystemIDTest.cpp
+++ b/wpimath/src/test/native/cpp/system/LinearSystemIDTest.cpp
@@ -1,16 +1,12 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include <frc/system/LinearSystem.h>
#include <frc/system/plant/DCMotor.h>
#include <frc/system/plant/LinearSystemId.h>
#include <gtest/gtest.h>
-#include "frc/StateSpaceUtil.h"
#include "frc/system/plant/LinearSystemId.h"
#include "units/length.h"
#include "units/mass.h"
@@ -20,32 +16,37 @@
frc::DCMotor::NEO(4), 70_kg, 0.05_m, 0.4_m, 6.0_kg_sq_m, 6.0);
ASSERT_TRUE(model.A().isApprox(
- frc::MakeMatrix<2, 2>(-10.14132, 3.06598, 3.06598, -10.14132), 0.001));
+ Eigen::Matrix<double, 2, 2>{{-10.14132, 3.06598}, {3.06598, -10.14132}},
+ 0.001));
ASSERT_TRUE(model.B().isApprox(
- frc::MakeMatrix<2, 2>(4.2590, -1.28762, -1.2876, 4.2590), 0.001));
- ASSERT_TRUE(
- model.C().isApprox(frc::MakeMatrix<2, 2>(1.0, 0.0, 0.0, 1.0), 0.001));
- ASSERT_TRUE(
- model.D().isApprox(frc::MakeMatrix<2, 2>(0.0, 0.0, 0.0, 0.0), 0.001));
+ Eigen::Matrix<double, 2, 2>{{4.2590, -1.28762}, {-1.2876, 4.2590}},
+ 0.001));
+ ASSERT_TRUE(model.C().isApprox(
+ Eigen::Matrix<double, 2, 2>{{1.0, 0.0}, {0.0, 1.0}}, 0.001));
+ ASSERT_TRUE(model.D().isApprox(
+ Eigen::Matrix<double, 2, 2>{{0.0, 0.0}, {0.0, 0.0}}, 0.001));
}
TEST(LinearSystemIDTest, ElevatorSystem) {
auto model = frc::LinearSystemId::ElevatorSystem(frc::DCMotor::NEO(2), 5_kg,
0.05_m, 12);
ASSERT_TRUE(model.A().isApprox(
- frc::MakeMatrix<2, 2>(0.0, 1.0, 0.0, -99.05473), 0.001));
- ASSERT_TRUE(model.B().isApprox(frc::MakeMatrix<2, 1>(0.0, 20.8), 0.001));
- ASSERT_TRUE(model.C().isApprox(frc::MakeMatrix<1, 2>(1.0, 0.0), 0.001));
- ASSERT_TRUE(model.D().isApprox(frc::MakeMatrix<1, 1>(0.0), 0.001));
+ Eigen::Matrix<double, 2, 2>{{0.0, 1.0}, {0.0, -99.05473}}, 0.001));
+ ASSERT_TRUE(
+ model.B().isApprox(Eigen::Matrix<double, 2, 1>{0.0, 20.8}, 0.001));
+ ASSERT_TRUE(model.C().isApprox(Eigen::Matrix<double, 1, 2>{1.0, 0.0}, 0.001));
+ ASSERT_TRUE(model.D().isApprox(Eigen::Matrix<double, 1, 1>{0.0}, 0.001));
}
TEST(LinearSystemIDTest, FlywheelSystem) {
auto model = frc::LinearSystemId::FlywheelSystem(frc::DCMotor::NEO(2),
0.00032_kg_sq_m, 1.0);
- ASSERT_TRUE(model.A().isApprox(frc::MakeMatrix<1, 1>(-26.87032), 0.001));
- ASSERT_TRUE(model.B().isApprox(frc::MakeMatrix<1, 1>(1354.166667), 0.001));
- ASSERT_TRUE(model.C().isApprox(frc::MakeMatrix<1, 1>(1.0), 0.001));
- ASSERT_TRUE(model.D().isApprox(frc::MakeMatrix<1, 1>(0.0), 0.001));
+ ASSERT_TRUE(
+ model.A().isApprox(Eigen::Matrix<double, 1, 1>{-26.87032}, 0.001));
+ ASSERT_TRUE(
+ model.B().isApprox(Eigen::Matrix<double, 1, 1>{1354.166667}, 0.001));
+ ASSERT_TRUE(model.C().isApprox(Eigen::Matrix<double, 1, 1>{1.0}, 0.001));
+ ASSERT_TRUE(model.D().isApprox(Eigen::Matrix<double, 1, 1>{0.0}, 0.001));
}
TEST(LinearSystemIDTest, IdentifyPositionSystem) {
@@ -56,9 +57,10 @@
auto model = frc::LinearSystemId::IdentifyPositionSystem<units::meter>(
kv * 1_V / 1_mps, ka * 1_V / 1_mps_sq);
- ASSERT_TRUE(model.A().isApprox(frc::MakeMatrix<2, 2>(0.0, 1.0, 0.0, -kv / ka),
- 0.001));
- ASSERT_TRUE(model.B().isApprox(frc::MakeMatrix<2, 1>(0.0, 1.0 / ka), 0.001));
+ ASSERT_TRUE(model.A().isApprox(
+ Eigen::Matrix<double, 2, 2>{{0.0, 1.0}, {0.0, -kv / ka}}, 0.001));
+ ASSERT_TRUE(
+ model.B().isApprox(Eigen::Matrix<double, 2, 1>{0.0, 1.0 / ka}, 0.001));
}
TEST(LinearSystemIDTest, IdentifyVelocitySystem) {
@@ -70,6 +72,6 @@
auto model = frc::LinearSystemId::IdentifyVelocitySystem<units::meter>(
kv * 1_V / 1_mps, ka * 1_V / 1_mps_sq);
- ASSERT_TRUE(model.A().isApprox(frc::MakeMatrix<1, 1>(-kv / ka), 0.001));
- ASSERT_TRUE(model.B().isApprox(frc::MakeMatrix<1, 1>(1.0 / ka), 0.001));
+ ASSERT_TRUE(model.A().isApprox(Eigen::Matrix<double, 1, 1>{-kv / ka}, 0.001));
+ ASSERT_TRUE(model.B().isApprox(Eigen::Matrix<double, 1, 1>{1.0 / ka}, 0.001));
}
diff --git a/wpimath/src/test/native/cpp/system/NumericalIntegrationTest.cpp b/wpimath/src/test/native/cpp/system/NumericalIntegrationTest.cpp
new file mode 100644
index 0000000..fd9c039
--- /dev/null
+++ b/wpimath/src/test/native/cpp/system/NumericalIntegrationTest.cpp
@@ -0,0 +1,66 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+#include <gtest/gtest.h>
+
+#include <cmath>
+
+#include "frc/system/NumericalIntegration.h"
+
+namespace {
+// Exact solution of dx/dt = e^x with x(0) = 0.
+//
+// Separating variables gives e^(-x) dx = dt, so 1 - e^(-x) = t and therefore
+// x(t) = -ln(1 - t) for t < 1. Note that e^t - 1 solves dx/dt = e^t instead;
+// at t = 0.1 it differs from the true solution by roughly 1.9e-4.
+double ExponentialSolution(double t) { return -std::log(1.0 - t); }
+}  // namespace
+
+// Tests that integrating dx/dt = e^x works.
+TEST(NumericalIntegrationTest, Exponential) {
+  Eigen::Vector<double, 1> y0{0.0};
+
+  Eigen::Vector<double, 1> y1 = frc::RK4(
+      [](const Eigen::Vector<double, 1>& x) {
+        return Eigen::Vector<double, 1>{std::exp(x(0))};
+      },
+      y0, 0.1_s);
+  EXPECT_NEAR(y1(0), ExponentialSolution(0.1), 1e-6);
+}
+
+// Tests that integrating dx/dt = e^x works when we provide a U.
+TEST(NumericalIntegrationTest, ExponentialWithU) {
+  Eigen::Vector<double, 1> y0{0.0};
+
+  Eigen::Vector<double, 1> y1 = frc::RK4(
+      [](const Eigen::Vector<double, 1>& x, const Eigen::Vector<double, 1>& u) {
+        return Eigen::Vector<double, 1>{std::exp(u(0) * x(0))};
+      },
+      y0, Eigen::Vector<double, 1>{1.0}, 0.1_s);
+  EXPECT_NEAR(y1(0), ExponentialSolution(0.1), 1e-6);
+}
+
+// Tests that integrating dx/dt = e^x works with RKF45.
+TEST(NumericalIntegrationTest, ExponentialRKF45) {
+  Eigen::Vector<double, 1> y0{0.0};
+
+  Eigen::Vector<double, 1> y1 = frc::RKF45(
+      [](const Eigen::Vector<double, 1>& x, const Eigen::Vector<double, 1>& u) {
+        return Eigen::Vector<double, 1>{std::exp(x(0))};
+      },
+      y0, Eigen::Vector<double, 1>{0.0}, 0.1_s);
+  EXPECT_NEAR(y1(0), ExponentialSolution(0.1), 1e-3);
+}
+
+// Tests that integrating dx/dt = e^x works with RKDP.
+TEST(NumericalIntegrationTest, ExponentialRKDP) {
+  Eigen::Vector<double, 1> y0{0.0};
+
+  Eigen::Vector<double, 1> y1 = frc::RKDP(
+      [](const Eigen::Vector<double, 1>& x, const Eigen::Vector<double, 1>& u) {
+        return Eigen::Vector<double, 1>{std::exp(x(0))};
+      },
+      y0, Eigen::Vector<double, 1>{0.0}, 0.1_s);
+  EXPECT_NEAR(y1(0), ExponentialSolution(0.1), 1e-3);
+}
diff --git a/wpimath/src/test/native/cpp/system/NumericalJacobianTest.cpp b/wpimath/src/test/native/cpp/system/NumericalJacobianTest.cpp
index ddc3f68..4e64825 100644
--- a/wpimath/src/test/native/cpp/system/NumericalJacobianTest.cpp
+++ b/wpimath/src/test/native/cpp/system/NumericalJacobianTest.cpp
@@ -1,68 +1,58 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include <gtest/gtest.h>
#include "frc/system/NumericalJacobian.h"
-Eigen::Matrix<double, 4, 4> A = (Eigen::Matrix<double, 4, 4>() << 1, 2, 4, 1, 5,
- 2, 3, 4, 5, 1, 3, 2, 1, 1, 3, 7)
- .finished();
-
-Eigen::Matrix<double, 4, 2> B =
- (Eigen::Matrix<double, 4, 2>() << 1, 1, 2, 1, 3, 2, 3, 7).finished();
+Eigen::Matrix<double, 4, 4> A{
+ {1, 2, 4, 1}, {5, 2, 3, 4}, {5, 1, 3, 2}, {1, 1, 3, 7}};
+Eigen::Matrix<double, 4, 2> B{{1, 1}, {2, 1}, {3, 2}, {3, 7}};
// Function from which to recover A and B
-Eigen::Matrix<double, 4, 1> AxBuFn(const Eigen::Matrix<double, 4, 1>& x,
- const Eigen::Matrix<double, 2, 1>& u) {
+Eigen::Vector<double, 4> AxBuFn(const Eigen::Vector<double, 4>& x,
+ const Eigen::Vector<double, 2>& u) {
return A * x + B * u;
}
// Test that we can recover A from AxBuFn() pretty accurately
TEST(NumericalJacobianTest, Ax) {
- Eigen::Matrix<double, 4, 4> newA = frc::NumericalJacobianX<4, 4, 2>(
- AxBuFn, Eigen::Matrix<double, 4, 1>::Zero(),
- Eigen::Matrix<double, 2, 1>::Zero());
+ Eigen::Matrix<double, 4, 4> newA =
+ frc::NumericalJacobianX<4, 4, 2>(AxBuFn, Eigen::Vector<double, 4>::Zero(),
+ Eigen::Vector<double, 2>::Zero());
EXPECT_TRUE(newA.isApprox(A));
}
// Test that we can recover B from AxBuFn() pretty accurately
TEST(NumericalJacobianTest, Bu) {
- Eigen::Matrix<double, 4, 2> newB = frc::NumericalJacobianU<4, 4, 2>(
- AxBuFn, Eigen::Matrix<double, 4, 1>::Zero(),
- Eigen::Matrix<double, 2, 1>::Zero());
+ Eigen::Matrix<double, 4, 2> newB =
+ frc::NumericalJacobianU<4, 4, 2>(AxBuFn, Eigen::Vector<double, 4>::Zero(),
+ Eigen::Vector<double, 2>::Zero());
EXPECT_TRUE(newB.isApprox(B));
}
-Eigen::Matrix<double, 3, 4> C =
- (Eigen::Matrix<double, 3, 4>() << 1, 2, 4, 1, 5, 2, 3, 4, 5, 1, 3, 2)
- .finished();
-
-Eigen::Matrix<double, 3, 2> D =
- (Eigen::Matrix<double, 3, 2>() << 1, 1, 2, 1, 3, 2).finished();
+Eigen::Matrix<double, 3, 4> C{{1, 2, 4, 1}, {5, 2, 3, 4}, {5, 1, 3, 2}};
+Eigen::Matrix<double, 3, 2> D{{1, 1}, {2, 1}, {3, 2}};
// Function from which to recover C and D
-Eigen::Matrix<double, 3, 1> CxDuFn(const Eigen::Matrix<double, 4, 1>& x,
- const Eigen::Matrix<double, 2, 1>& u) {
+Eigen::Vector<double, 3> CxDuFn(const Eigen::Vector<double, 4>& x,
+ const Eigen::Vector<double, 2>& u) {
return C * x + D * u;
}
// Test that we can recover C from CxDuFn() pretty accurately
TEST(NumericalJacobianTest, Cx) {
- Eigen::Matrix<double, 3, 4> newC = frc::NumericalJacobianX<3, 4, 2>(
- CxDuFn, Eigen::Matrix<double, 4, 1>::Zero(),
- Eigen::Matrix<double, 2, 1>::Zero());
+ Eigen::Matrix<double, 3, 4> newC =
+ frc::NumericalJacobianX<3, 4, 2>(CxDuFn, Eigen::Vector<double, 4>::Zero(),
+ Eigen::Vector<double, 2>::Zero());
EXPECT_TRUE(newC.isApprox(C));
}
// Test that we can recover D from CxDuFn() pretty accurately
TEST(NumericalJacobianTest, Du) {
- Eigen::Matrix<double, 3, 2> newD = frc::NumericalJacobianU<3, 4, 2>(
- CxDuFn, Eigen::Matrix<double, 4, 1>::Zero(),
- Eigen::Matrix<double, 2, 1>::Zero());
+ Eigen::Matrix<double, 3, 2> newD =
+ frc::NumericalJacobianU<3, 4, 2>(CxDuFn, Eigen::Vector<double, 4>::Zero(),
+ Eigen::Vector<double, 2>::Zero());
EXPECT_TRUE(newD.isApprox(D));
}
diff --git a/wpimath/src/test/native/cpp/system/RungeKuttaTest.cpp b/wpimath/src/test/native/cpp/system/RungeKuttaTest.cpp
deleted file mode 100644
index a12c1b7..0000000
--- a/wpimath/src/test/native/cpp/system/RungeKuttaTest.cpp
+++ /dev/null
@@ -1,71 +0,0 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
-
-#include <gtest/gtest.h>
-
-#include <cmath>
-
-#include "frc/system/RungeKutta.h"
-
-// Tests that integrating dx/dt = e^x works.
-TEST(RungeKuttaTest, Exponential) {
- Eigen::Matrix<double, 1, 1> y0;
- y0(0) = 0.0;
-
- Eigen::Matrix<double, 1, 1> y1 = frc::RungeKutta(
- [](Eigen::Matrix<double, 1, 1> x) {
- Eigen::Matrix<double, 1, 1> y;
- y(0) = std::exp(x(0));
- return y;
- },
- y0, 0.1_s);
- EXPECT_NEAR(y1(0), std::exp(0.1) - std::exp(0), 1e-3);
-}
-
-// Tests that integrating dx/dt = e^x works when we provide a U.
-TEST(RungeKuttaTest, ExponentialWithU) {
- Eigen::Matrix<double, 1, 1> y0;
- y0(0) = 0.0;
-
- Eigen::Matrix<double, 1, 1> y1 = frc::RungeKutta(
- [](Eigen::Matrix<double, 1, 1> x, Eigen::Matrix<double, 1, 1> u) {
- Eigen::Matrix<double, 1, 1> y;
- y(0) = std::exp(u(0) * x(0));
- return y;
- },
- y0, (Eigen::Matrix<double, 1, 1>() << 1.0).finished(), 0.1_s);
- EXPECT_NEAR(y1(0), std::exp(0.1) - std::exp(0), 1e-3);
-}
-
-namespace {
-Eigen::Matrix<double, 1, 1> RungeKuttaTimeVaryingSolution(double t) {
- return (Eigen::Matrix<double, 1, 1>()
- << 12.0 * std::exp(t) / (std::pow(std::exp(t) + 1.0, 2.0)))
- .finished();
-}
-} // namespace
-
-// Tests RungeKutta with a time varying solution.
-// Now, lets test RK4 with a time varying solution. From
-// http://www2.hawaii.edu/~jmcfatri/math407/RungeKuttaTest.html:
-// x' = x (2 / (e^t + 1) - 1)
-//
-// The true (analytical) solution is:
-//
-// x(t) = 12 * e^t / ((e^t + 1)^2)
-TEST(RungeKuttaTest, RungeKuttaTimeVarying) {
- Eigen::Matrix<double, 1, 1> y0 = RungeKuttaTimeVaryingSolution(5.0);
-
- Eigen::Matrix<double, 1, 1> y1 = frc::RungeKuttaTimeVarying(
- [](units::second_t t, Eigen::Matrix<double, 1, 1> x) {
- return (Eigen::Matrix<double, 1, 1>()
- << x(0) * (2.0 / (std::exp(t.to<double>()) + 1.0) - 1.0))
- .finished();
- },
- y0, 5_s, 1_s);
- EXPECT_NEAR(y1(0), RungeKuttaTimeVaryingSolution(6.0)(0), 1e-3);
-}
diff --git a/wpimath/src/test/native/cpp/system/RungeKuttaTimeVaryingTest.cpp b/wpimath/src/test/native/cpp/system/RungeKuttaTimeVaryingTest.cpp
new file mode 100644
index 0000000..f1be861
--- /dev/null
+++ b/wpimath/src/test/native/cpp/system/RungeKuttaTimeVaryingTest.cpp
@@ -0,0 +1,34 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+#include <gtest/gtest.h>
+
+#include <cmath>
+
+#include "frc/system/RungeKuttaTimeVarying.h"
+
+namespace {
+// Closed-form solution of x' = x (2 / (e^t + 1) - 1), namely
+// x(t) = 12 e^t / (e^t + 1)^2, evaluated at time t (in seconds).
+Eigen::Vector<double, 1> RungeKuttaTimeVaryingSolution(double t) {
+  const double expT = std::exp(t);
+  return Eigen::Vector<double, 1>{12.0 * expT / std::pow(expT + 1.0, 2.0)};
+}
+}  // namespace
+
+// Advances the time-varying ODE above with RungeKuttaTimeVarying (RK4) from
+// t = 5 s by 1 s, starting on the exact solution, and compares the result with
+// the exact solution at t = 6 s. Problem taken from
+// http://www2.hawaii.edu/~jmcfatri/math407/RungeKuttaTest.html.
+TEST(RungeKuttaTimeVaryingTest, RungeKuttaTimeVarying) {
+  const Eigen::Vector<double, 1> y0 = RungeKuttaTimeVaryingSolution(5.0);
+
+  // dx/dt = x (2 / (e^t + 1) - 1)
+  auto dxdt = [](units::second_t t, const Eigen::Vector<double, 1>& x) {
+    return Eigen::Vector<double, 1>{
+        x(0) * (2.0 / (std::exp(t.value()) + 1.0) - 1.0)};
+  };
+  Eigen::Vector<double, 1> y1 = frc::RungeKuttaTimeVarying(dxdt, 5_s, y0, 1_s);
+  EXPECT_NEAR(y1(0), RungeKuttaTimeVaryingSolution(6.0)(0), 1e-3);
+}
diff --git a/wpimath/src/test/native/cpp/trajectory/CentripetalAccelerationConstraintTest.cpp b/wpimath/src/test/native/cpp/trajectory/CentripetalAccelerationConstraintTest.cpp
index 42e9fe2..e2f7112 100644
--- a/wpimath/src/test/native/cpp/trajectory/CentripetalAccelerationConstraintTest.cpp
+++ b/wpimath/src/test/native/cpp/trajectory/CentripetalAccelerationConstraintTest.cpp
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include <memory>
#include <vector>
diff --git a/wpimath/src/test/native/cpp/trajectory/DifferentialDriveKinematicsTest.cpp b/wpimath/src/test/native/cpp/trajectory/DifferentialDriveKinematicsTest.cpp
index 636b002..e3723b5 100644
--- a/wpimath/src/test/native/cpp/trajectory/DifferentialDriveKinematicsTest.cpp
+++ b/wpimath/src/test/native/cpp/trajectory/DifferentialDriveKinematicsTest.cpp
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include <memory>
#include <vector>
diff --git a/wpimath/src/test/native/cpp/trajectory/DifferentialDriveVoltageTest.cpp b/wpimath/src/test/native/cpp/trajectory/DifferentialDriveVoltageTest.cpp
index 6cd8075..b21ef7b 100644
--- a/wpimath/src/test/native/cpp/trajectory/DifferentialDriveVoltageTest.cpp
+++ b/wpimath/src/test/native/cpp/trajectory/DifferentialDriveVoltageTest.cpp
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include <iostream>
#include <memory>
diff --git a/wpimath/src/test/native/cpp/trajectory/EllipticalRegionConstraintTest.cpp b/wpimath/src/test/native/cpp/trajectory/EllipticalRegionConstraintTest.cpp
index 44c4222..88dc2b8 100644
--- a/wpimath/src/test/native/cpp/trajectory/EllipticalRegionConstraintTest.cpp
+++ b/wpimath/src/test/native/cpp/trajectory/EllipticalRegionConstraintTest.cpp
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include <vector>
diff --git a/wpimath/src/test/native/cpp/trajectory/RectangularRegionConstraintTest.cpp b/wpimath/src/test/native/cpp/trajectory/RectangularRegionConstraintTest.cpp
index 33f753f..77310ae 100644
--- a/wpimath/src/test/native/cpp/trajectory/RectangularRegionConstraintTest.cpp
+++ b/wpimath/src/test/native/cpp/trajectory/RectangularRegionConstraintTest.cpp
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include <vector>
diff --git a/wpimath/src/test/native/cpp/trajectory/TrajectoryConcatenateTest.cpp b/wpimath/src/test/native/cpp/trajectory/TrajectoryConcatenateTest.cpp
new file mode 100644
index 0000000..2b733a8
--- /dev/null
+++ b/wpimath/src/test/native/cpp/trajectory/TrajectoryConcatenateTest.cpp
@@ -0,0 +1,34 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+#include "frc/trajectory/TrajectoryConfig.h"
+#include "frc/trajectory/TrajectoryGenerator.h"
+#include "gtest/gtest.h"
+
+TEST(TrajectoryConcatenateTest, States) {
+ auto t1 = frc::TrajectoryGenerator::GenerateTrajectory(
+ {}, {}, {1_m, 1_m, 0_deg}, {2_mps, 2_mps_sq});
+ auto t2 = frc::TrajectoryGenerator::GenerateTrajectory(
+ {1_m, 1_m, 0_deg}, {}, {2_m, 2_m, 45_deg}, {2_mps, 2_mps_sq});
+
+ auto t = t1 + t2;
+
+ double time = -1.0;
+ for (size_t i = 0; i < t.States().size(); ++i) {
+ const auto& state = t.States()[i];
+
+ // Make sure that the timestamps are strictly increasing.
+ EXPECT_GT(state.t.value(), time);
+ time = state.t.value();
+
+ // Ensure that the states in t are the same as those in t1 and t2.
+ if (i < t1.States().size()) {
+ EXPECT_EQ(state, t1.States()[i]);
+ } else {
+ auto st = t2.States()[i - t1.States().size() + 1];
+ st.t += t1.TotalTime();
+ EXPECT_EQ(state, st);
+ }
+ }
+}
diff --git a/wpimath/src/test/native/cpp/trajectory/TrajectoryGeneratorTest.cpp b/wpimath/src/test/native/cpp/trajectory/TrajectoryGeneratorTest.cpp
index 378aff7..175becf 100644
--- a/wpimath/src/test/native/cpp/trajectory/TrajectoryGeneratorTest.cpp
+++ b/wpimath/src/test/native/cpp/trajectory/TrajectoryGeneratorTest.cpp
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include <vector>
diff --git a/wpimath/src/test/native/cpp/trajectory/TrajectoryJsonTest.cpp b/wpimath/src/test/native/cpp/trajectory/TrajectoryJsonTest.cpp
index c18a3e9..90c6dc0 100644
--- a/wpimath/src/test/native/cpp/trajectory/TrajectoryJsonTest.cpp
+++ b/wpimath/src/test/native/cpp/trajectory/TrajectoryJsonTest.cpp
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include "frc/trajectory/TrajectoryConfig.h"
#include "frc/trajectory/TrajectoryUtil.h"
diff --git a/wpimath/src/test/native/cpp/trajectory/TrajectoryTransformTest.cpp b/wpimath/src/test/native/cpp/trajectory/TrajectoryTransformTest.cpp
index 349ff5c..0c6e07a 100644
--- a/wpimath/src/test/native/cpp/trajectory/TrajectoryTransformTest.cpp
+++ b/wpimath/src/test/native/cpp/trajectory/TrajectoryTransformTest.cpp
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include <vector>
@@ -24,14 +21,14 @@
auto a = a2.RelativeTo(a1);
auto b = b2.RelativeTo(b1);
- EXPECT_NEAR(a.X().to<double>(), b.X().to<double>(), 1E-9);
- EXPECT_NEAR(a.Y().to<double>(), b.Y().to<double>(), 1E-9);
- EXPECT_NEAR(a.Rotation().Radians().to<double>(),
- b.Rotation().Radians().to<double>(), 1E-9);
+ EXPECT_NEAR(a.X().value(), b.X().value(), 1E-9);
+ EXPECT_NEAR(a.Y().value(), b.Y().value(), 1E-9);
+ EXPECT_NEAR(a.Rotation().Radians().value(), b.Rotation().Radians().value(),
+ 1E-9);
}
}
-TEST(TrajectoryTransforms, TransformBy) {
+TEST(TrajectoryTransformsTest, TransformBy) {
frc::TrajectoryConfig config{3_mps, 3_mps_sq};
auto trajectory = frc::TrajectoryGenerator::GenerateTrajectory(
frc::Pose2d{}, {}, frc::Pose2d{1_m, 1_m, frc::Rotation2d(90_deg)},
@@ -42,14 +39,14 @@
auto firstPose = transformedTrajectory.Sample(0_s).pose;
- EXPECT_NEAR(firstPose.X().to<double>(), 1.0, 1E-9);
- EXPECT_NEAR(firstPose.Y().to<double>(), 2.0, 1E-9);
- EXPECT_NEAR(firstPose.Rotation().Degrees().to<double>(), 30.0, 1E-9);
+ EXPECT_NEAR(firstPose.X().value(), 1.0, 1E-9);
+ EXPECT_NEAR(firstPose.Y().value(), 2.0, 1E-9);
+ EXPECT_NEAR(firstPose.Rotation().Degrees().value(), 30.0, 1E-9);
TestSameShapedTrajectory(trajectory.States(), transformedTrajectory.States());
}
-TEST(TrajectoryTransforms, RelativeTo) {
+TEST(TrajectoryTransformsTest, RelativeTo) {
frc::TrajectoryConfig config{3_mps, 3_mps_sq};
auto trajectory = frc::TrajectoryGenerator::GenerateTrajectory(
frc::Pose2d{1_m, 2_m, frc::Rotation2d(30_deg)}, {},
@@ -60,9 +57,9 @@
auto firstPose = transformedTrajectory.Sample(0_s).pose;
- EXPECT_NEAR(firstPose.X().to<double>(), 0, 1E-9);
- EXPECT_NEAR(firstPose.Y().to<double>(), 0, 1E-9);
- EXPECT_NEAR(firstPose.Rotation().Degrees().to<double>(), 0, 1E-9);
+ EXPECT_NEAR(firstPose.X().value(), 0, 1E-9);
+ EXPECT_NEAR(firstPose.Y().value(), 0, 1E-9);
+ EXPECT_NEAR(firstPose.Rotation().Degrees().value(), 0, 1E-9);
TestSameShapedTrajectory(trajectory.States(), transformedTrajectory.States());
}
diff --git a/wpimath/src/test/native/cpp/trajectory/TrapezoidProfileTest.cpp b/wpimath/src/test/native/cpp/trajectory/TrapezoidProfileTest.cpp
index 63ea916..6a35261 100644
--- a/wpimath/src/test/native/cpp/trajectory/TrapezoidProfileTest.cpp
+++ b/wpimath/src/test/native/cpp/trajectory/TrapezoidProfileTest.cpp
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#include "frc/trajectory/TrapezoidProfile.h" // NOLINT(build/include_order)
diff --git a/wpimath/src/test/native/include/drake/common/autodiff.h b/wpimath/src/test/native/include/drake/common/autodiff.h
deleted file mode 100644
index 66fd88a..0000000
--- a/wpimath/src/test/native/include/drake/common/autodiff.h
+++ /dev/null
@@ -1,34 +0,0 @@
-#pragma once
-/// @file This header provides a single inclusion point for autodiff-related
-/// header files in the `drake/common` directory. Users should include this
-/// file. Including other individual headers such as `drake/common/autodiffxd.h`
-/// will generate a compile-time warning.
-
-// In each header included below, it asserts that this macro
-// `DRAKE_COMMON_AUTODIFF_HEADER` is defined. If the macro is not defined, it
-// generates diagnostic warning messages.
-#define DRAKE_COMMON_AUTODIFF_HEADER
-
-#include <Eigen/Core>
-#include <unsupported/Eigen/AutoDiff>
-
-static_assert(EIGEN_VERSION_AT_LEAST(3, 3, 3),
- "Drake requires Eigen >= v3.3.3.");
-
-// Do not alpha-sort the following block of hard-coded #includes, which is
-// protected by `clang-format on/off`.
-//
-// Rationale: We want to maximize the use of this header, `autodiff.h`, even
-// inside of the autodiff-related files to avoid any mistakes which might not be
-// detected. By centralizing the list here, we make sure that everyone will see
-// the correct order which respects the inter-dependencies of the autodiff
-// headers. This shields us from triggering undefined behaviors due to
-// order-of-specialization-includes-changed mistakes.
-//
-// clang-format off
-#include "drake/common/eigen_autodiff_limits.h"
-#include "drake/common/eigen_autodiff_types.h"
-#include "drake/common/autodiffxd.h"
-#include "drake/common/autodiff_overloads.h"
-// clang-format on
-#undef DRAKE_COMMON_AUTODIFF_HEADER
diff --git a/wpimath/src/test/native/include/drake/common/autodiff_overloads.h b/wpimath/src/test/native/include/drake/common/autodiff_overloads.h
deleted file mode 100644
index 7eaeb3f..0000000
--- a/wpimath/src/test/native/include/drake/common/autodiff_overloads.h
+++ /dev/null
@@ -1,203 +0,0 @@
-/// @file
-/// Overloads for STL mathematical operations on AutoDiffScalar.
-///
-/// Used via argument-dependent lookup (ADL). These functions appear
-/// in the Eigen namespace so that ADL can automatically choose between
-/// the STL version and the overloaded version to match the type of the
-/// arguments. The proper use would be e.g.
-///
-/// \code{.cc}
-/// void mymethod() {
-/// using std::isinf;
-/// isinf(myval);
-/// }
-/// \endcode{}
-///
-/// @note The if_then_else and cond functions for AutoDiffScalar are in
-/// namespace drake because cond is defined in namespace drake in
-/// "drake/common/cond.h" file.
-
-#pragma once
-
-#ifndef DRAKE_COMMON_AUTODIFF_HEADER
-// TODO(soonho-tri): Change to #error.
-#warning Do not directly include this file. Include "drake/common/autodiff.h".
-#endif
-
-#include <cmath>
-#include <limits>
-
-#include "drake/common/cond.h"
-#include "drake/common/drake_assert.h"
-#include "drake/common/dummy_value.h"
-
-namespace Eigen {
-
-/// Overloads nexttoward to mimic std::nexttoward from <cmath>.
-template <typename DerType>
-double nexttoward(const Eigen::AutoDiffScalar<DerType>& from, long double to) {
- using std::nexttoward;
- return nexttoward(from.value(), to);
-}
-
-/// Overloads round to mimic std::round from <cmath>.
-template <typename DerType>
-double round(const Eigen::AutoDiffScalar<DerType>& x) {
- using std::round;
- return round(x.value());
-}
-
-/// Overloads isinf to mimic std::isinf from <cmath>.
-template <typename DerType>
-bool isinf(const Eigen::AutoDiffScalar<DerType>& x) {
- using std::isinf;
- return isinf(x.value());
-}
-
-/// Overloads isfinite to mimic std::isfinite from <cmath>.
-template <typename DerType>
-bool isfinite(const Eigen::AutoDiffScalar<DerType>& x) {
- using std::isfinite;
- return isfinite(x.value());
-}
-
-/// Overloads isnan to mimic std::isnan from <cmath>.
-template <typename DerType>
-bool isnan(const Eigen::AutoDiffScalar<DerType>& x) {
- using std::isnan;
- return isnan(x.value());
-}
-
-/// Overloads floor to mimic std::floor from <cmath>.
-template <typename DerType>
-double floor(const Eigen::AutoDiffScalar<DerType>& x) {
- using std::floor;
- return floor(x.value());
-}
-
-/// Overloads ceil to mimic std::ceil from <cmath>.
-template <typename DerType>
-double ceil(const Eigen::AutoDiffScalar<DerType>& x) {
- using std::ceil;
- return ceil(x.value());
-}
-
-/// Overloads copysign from <cmath>.
-template <typename DerType, typename T>
-Eigen::AutoDiffScalar<DerType> copysign(const Eigen::AutoDiffScalar<DerType>& x,
- const T& y) {
- using std::isnan;
- if (isnan(x)) return (y >= 0) ? NAN : -NAN;
- if ((x < 0 && y >= 0) || (x >= 0 && y < 0))
- return -x;
- else
- return x;
-}
-
-/// Overloads copysign from <cmath>.
-template <typename DerType>
-double copysign(double x, const Eigen::AutoDiffScalar<DerType>& y) {
- using std::isnan;
- if (isnan(x)) return (y >= 0) ? NAN : -NAN;
- if ((x < 0 && y >= 0) || (x >= 0 && y < 0))
- return -x;
- else
- return x;
-}
-
-/// Overloads pow for an AutoDiffScalar base and exponent, implementing the
-/// chain rule.
-template <typename DerTypeA, typename DerTypeB>
-Eigen::AutoDiffScalar<
- typename internal::remove_all<DerTypeA>::type::PlainObject>
-pow(const Eigen::AutoDiffScalar<DerTypeA>& base,
- const Eigen::AutoDiffScalar<DerTypeB>& exponent) {
- // The two AutoDiffScalars being exponentiated must have the same matrix
- // type. This includes, but is not limited to, the same scalar type and
- // the same dimension.
- static_assert(
- std::is_same<
- typename internal::remove_all<DerTypeA>::type::PlainObject,
- typename internal::remove_all<DerTypeB>::type::PlainObject>::value,
- "The derivative types must match.");
-
- internal::make_coherent(base.derivatives(), exponent.derivatives());
-
- const auto& x = base.value();
- const auto& xgrad = base.derivatives();
- const auto& y = exponent.value();
- const auto& ygrad = exponent.derivatives();
-
- using std::pow;
- using std::log;
- const auto x_to_the_y = pow(x, y);
- if (ygrad.isZero(std::numeric_limits<double>::epsilon()) ||
- ygrad.size() == 0) {
- // The derivative only depends on ∂(x^y)/∂x -- this prevents undefined
- // behavior in the corner case where ∂(x^y)/∂y is infinite when x = 0,
- // despite ∂y/∂v being 0.
- return Eigen::MakeAutoDiffScalar(x_to_the_y, y * pow(x, y - 1) * xgrad);
- }
- return Eigen::MakeAutoDiffScalar(
- // The value is x ^ y.
- x_to_the_y,
- // The multivariable chain rule states:
- // df/dv_i = (∂f/∂x * dx/dv_i) + (∂f/∂y * dy/dv_i)
- // ∂f/∂x is y*x^(y-1)
- y * pow(x, y - 1) * xgrad +
- // ∂f/∂y is (x^y)*ln(x)
- x_to_the_y * log(x) * ygrad);
-}
-
-} // namespace Eigen
-
-namespace drake {
-
-/// Returns the autodiff scalar's value() as a double. Never throws.
-/// Overloads ExtractDoubleOrThrow from common/extract_double.h.
-template <typename DerType>
-double ExtractDoubleOrThrow(const Eigen::AutoDiffScalar<DerType>& scalar) {
- return static_cast<double>(scalar.value());
-}
-
-/// Specializes common/dummy_value.h.
-template <typename DerType>
-struct dummy_value<Eigen::AutoDiffScalar<DerType>> {
- static constexpr Eigen::AutoDiffScalar<DerType> get() {
- constexpr double kNaN = std::numeric_limits<double>::quiet_NaN();
- DerType derivatives;
- derivatives.fill(kNaN);
- return Eigen::AutoDiffScalar<DerType>(kNaN, derivatives);
- }
-};
-
-/// Provides if-then-else expression for Eigen::AutoDiffScalar type. To support
-/// Eigen's generic expressions, we use casting to the plain object after
-/// applying Eigen::internal::remove_all. It is based on the Eigen's
-/// implementation of min/max function for AutoDiffScalar type
-/// (https://bitbucket.org/eigen/eigen/src/10a1de58614569c9250df88bdfc6402024687bc6/unsupported/Eigen/src/AutoDiff/AutoDiffScalar.h?at=default&fileviewer=file-view-default#AutoDiffScalar.h-546).
-template <typename DerType1, typename DerType2>
-inline Eigen::AutoDiffScalar<
- typename Eigen::internal::remove_all<DerType1>::type::PlainObject>
-if_then_else(bool f_cond, const Eigen::AutoDiffScalar<DerType1>& x,
- const Eigen::AutoDiffScalar<DerType2>& y) {
- typedef Eigen::AutoDiffScalar<
- typename Eigen::internal::remove_all<DerType1>::type::PlainObject>
- ADS1;
- typedef Eigen::AutoDiffScalar<
- typename Eigen::internal::remove_all<DerType2>::type::PlainObject>
- ADS2;
- static_assert(std::is_same<ADS1, ADS2>::value,
- "The derivative types must match.");
- return f_cond ? ADS1(x) : ADS2(y);
-}
-
-/// Provides special case of cond expression for Eigen::AutoDiffScalar type.
-template <typename DerType, typename... Rest>
-Eigen::AutoDiffScalar<
- typename Eigen::internal::remove_all<DerType>::type::PlainObject>
-cond(bool f_cond, const Eigen::AutoDiffScalar<DerType>& e_then, Rest... rest) {
- return if_then_else(f_cond, e_then, cond(rest...));
-}
-
-} // namespace drake
diff --git a/wpimath/src/test/native/include/drake/common/autodiffxd.h b/wpimath/src/test/native/include/drake/common/autodiffxd.h
deleted file mode 100644
index a99b2f0..0000000
--- a/wpimath/src/test/native/include/drake/common/autodiffxd.h
+++ /dev/null
@@ -1,423 +0,0 @@
-#pragma once
-
-// This file is a modification of Eigen-3.3.3's AutoDiffScalar.h file which is
-// available at
-// https://bitbucket.org/eigen/eigen/raw/67e894c6cd8f5f1f604b27d37ed47fdf012674ff/unsupported/Eigen/src/AutoDiff/AutoDiffScalar.h
-//
-// Copyright (C) 2009 Gael Guennebaud <gael.guennebaud@inria.fr>
-// Copyright (C) 2017 Drake Authors
-//
-// This Source Code Form is subject to the terms of the Mozilla
-// Public License v. 2.0. If a copy of the MPL was not distributed
-// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-#ifndef DRAKE_COMMON_AUTODIFF_HEADER
-// TODO(soonho-tri): Change to #error.
-#warning Do not directly include this file. Include "drake/common/autodiff.h".
-#endif
-
-#include <cmath>
-#include <ostream>
-
-#include <Eigen/Core>
-
-namespace Eigen {
-
-#if !defined(DRAKE_DOXYGEN_CXX)
-// Explicit template specializations of Eigen::AutoDiffScalar for VectorXd.
-//
-// AutoDiffScalar tries to call internal::make_coherent to promote empty
-// derivatives. However, it fails to do the promotion when an operand is an
-// expression tree (i.e. CwiseBinaryOp). Our solution is to provide special
-// overloading for VectorXd and change the return types of its operators. With
-// this change, the operators evaluate terms immediately and return an
-// AutoDiffScalar<VectorXd> instead of expression trees (such as CwiseBinaryOp).
-// Eigen's implementation of internal::make_coherent makes use of const_cast in
-// order to promote zero sized derivatives. This however interferes badly with
-// our caching system and produces unexpected behaviors. See #10971 for details.
-// Therefore our implementation stops using internal::make_coherent and treats
-// scalars with zero sized derivatives as constants, as it should.
-//
-// We also provide overloading of math functions for AutoDiffScalar<VectorXd>
-// which return AutoDiffScalar<VectorXd> instead of an expression tree.
-//
-// See https://github.com/RobotLocomotion/drake/issues/6944 for more
-// information. See also drake/common/autodiff_overloads.h.
-//
-// TODO(soonho-tri): Next time when we upgrade Eigen, please check if we still
-// need these specializations.
-template <>
-class AutoDiffScalar<VectorXd>
- : public internal::auto_diff_special_op<VectorXd, false> {
- public:
- typedef internal::auto_diff_special_op<VectorXd, false> Base;
- typedef typename internal::remove_all<VectorXd>::type DerType;
- typedef typename internal::traits<DerType>::Scalar Scalar;
- typedef typename NumTraits<Scalar>::Real Real;
-
- using Base::operator+;
- using Base::operator*;
-
- AutoDiffScalar() {}
-
- AutoDiffScalar(const Scalar& value, int nbDer, int derNumber)
- : m_value(value), m_derivatives(DerType::Zero(nbDer)) {
- m_derivatives.coeffRef(derNumber) = Scalar(1);
- }
-
- // NOLINTNEXTLINE(runtime/explicit): Code from Eigen.
- AutoDiffScalar(const Real& value) : m_value(value) {
- if (m_derivatives.size() > 0) m_derivatives.setZero();
- }
-
- AutoDiffScalar(const Scalar& value, const DerType& der)
- : m_value(value), m_derivatives(der) {}
-
- template <typename OtherDerType>
- AutoDiffScalar(
- const AutoDiffScalar<OtherDerType>& other
-#ifndef EIGEN_PARSED_BY_DOXYGEN
- ,
- typename internal::enable_if<
- internal::is_same<
- Scalar, typename internal::traits<typename internal::remove_all<
- OtherDerType>::type>::Scalar>::value,
- void*>::type = 0
-#endif
- )
- : m_value(other.value()), m_derivatives(other.derivatives()) {
- }
-
- friend std::ostream& operator<<(std::ostream& s, const AutoDiffScalar& a) {
- return s << a.value();
- }
-
- AutoDiffScalar(const AutoDiffScalar& other)
- : m_value(other.value()), m_derivatives(other.derivatives()) {}
-
- template <typename OtherDerType>
- inline AutoDiffScalar& operator=(const AutoDiffScalar<OtherDerType>& other) {
- m_value = other.value();
- m_derivatives = other.derivatives();
- return *this;
- }
-
- inline AutoDiffScalar& operator=(const AutoDiffScalar& other) {
- m_value = other.value();
- m_derivatives = other.derivatives();
- return *this;
- }
-
- inline AutoDiffScalar& operator=(const Scalar& other) {
- m_value = other;
- if (m_derivatives.size() > 0) m_derivatives.setZero();
- return *this;
- }
-
- inline const Scalar& value() const { return m_value; }
- inline Scalar& value() { return m_value; }
-
- inline const DerType& derivatives() const { return m_derivatives; }
- inline DerType& derivatives() { return m_derivatives; }
-
- inline bool operator<(const Scalar& other) const { return m_value < other; }
- inline bool operator<=(const Scalar& other) const { return m_value <= other; }
- inline bool operator>(const Scalar& other) const { return m_value > other; }
- inline bool operator>=(const Scalar& other) const { return m_value >= other; }
- inline bool operator==(const Scalar& other) const { return m_value == other; }
- inline bool operator!=(const Scalar& other) const { return m_value != other; }
-
- friend inline bool operator<(const Scalar& a, const AutoDiffScalar& b) {
- return a < b.value();
- }
- friend inline bool operator<=(const Scalar& a, const AutoDiffScalar& b) {
- return a <= b.value();
- }
- friend inline bool operator>(const Scalar& a, const AutoDiffScalar& b) {
- return a > b.value();
- }
- friend inline bool operator>=(const Scalar& a, const AutoDiffScalar& b) {
- return a >= b.value();
- }
- friend inline bool operator==(const Scalar& a, const AutoDiffScalar& b) {
- return a == b.value();
- }
- friend inline bool operator!=(const Scalar& a, const AutoDiffScalar& b) {
- return a != b.value();
- }
-
- template <typename OtherDerType>
- inline bool operator<(const AutoDiffScalar<OtherDerType>& b) const {
- return m_value < b.value();
- }
- template <typename OtherDerType>
- inline bool operator<=(const AutoDiffScalar<OtherDerType>& b) const {
- return m_value <= b.value();
- }
- template <typename OtherDerType>
- inline bool operator>(const AutoDiffScalar<OtherDerType>& b) const {
- return m_value > b.value();
- }
- template <typename OtherDerType>
- inline bool operator>=(const AutoDiffScalar<OtherDerType>& b) const {
- return m_value >= b.value();
- }
- template <typename OtherDerType>
- inline bool operator==(const AutoDiffScalar<OtherDerType>& b) const {
- return m_value == b.value();
- }
- template <typename OtherDerType>
- inline bool operator!=(const AutoDiffScalar<OtherDerType>& b) const {
- return m_value != b.value();
- }
-
- inline const AutoDiffScalar<DerType> operator+(const Scalar& other) const {
- return AutoDiffScalar<DerType>(m_value + other, m_derivatives);
- }
-
- friend inline const AutoDiffScalar<DerType> operator+(
- const Scalar& a, const AutoDiffScalar& b) {
- return AutoDiffScalar<DerType>(a + b.value(), b.derivatives());
- }
-
- inline AutoDiffScalar& operator+=(const Scalar& other) {
- value() += other;
- return *this;
- }
-
- template <typename OtherDerType>
- inline const AutoDiffScalar<DerType> operator+(
- const AutoDiffScalar<OtherDerType>& other) const {
- const bool has_this_der = m_derivatives.size() > 0;
- const bool has_both_der = has_this_der && (other.derivatives().size() > 0);
- return MakeAutoDiffScalar(
- m_value + other.value(),
- has_both_der
- ? VectorXd(m_derivatives + other.derivatives())
- : has_this_der ? m_derivatives : VectorXd(other.derivatives()));
- }
-
- template <typename OtherDerType>
- inline AutoDiffScalar& operator+=(const AutoDiffScalar<OtherDerType>& other) {
- (*this) = (*this) + other;
- return *this;
- }
-
- inline const AutoDiffScalar<DerType> operator-(const Scalar& b) const {
- return AutoDiffScalar<DerType>(m_value - b, m_derivatives);
- }
-
- friend inline const AutoDiffScalar<DerType> operator-(
- const Scalar& a, const AutoDiffScalar& b) {
- return AutoDiffScalar<DerType>(a - b.value(), -b.derivatives());
- }
-
- inline AutoDiffScalar& operator-=(const Scalar& other) {
- value() -= other;
- return *this;
- }
-
- template <typename OtherDerType>
- inline const AutoDiffScalar<DerType> operator-(
- const AutoDiffScalar<OtherDerType>& other) const {
- const bool has_this_der = m_derivatives.size() > 0;
- const bool has_both_der = has_this_der && (other.derivatives().size() > 0);
- return MakeAutoDiffScalar(
- m_value - other.value(),
- has_both_der
- ? VectorXd(m_derivatives - other.derivatives())
- : has_this_der ? m_derivatives : VectorXd(-other.derivatives()));
- }
-
- template <typename OtherDerType>
- inline AutoDiffScalar& operator-=(const AutoDiffScalar<OtherDerType>& other) {
- *this = *this - other;
- return *this;
- }
-
- inline const AutoDiffScalar<DerType> operator-() const {
- return AutoDiffScalar<DerType>(-m_value, -m_derivatives);
- }
-
- inline const AutoDiffScalar<DerType> operator*(const Scalar& other) const {
- return MakeAutoDiffScalar(m_value * other, m_derivatives * other);
- }
-
- friend inline const AutoDiffScalar<DerType> operator*(
- const Scalar& other, const AutoDiffScalar& a) {
- return MakeAutoDiffScalar(a.value() * other, a.derivatives() * other);
- }
-
- inline const AutoDiffScalar<DerType> operator/(const Scalar& other) const {
- return MakeAutoDiffScalar(m_value / other,
- (m_derivatives * (Scalar(1) / other)));
- }
-
- friend inline const AutoDiffScalar<DerType> operator/(
- const Scalar& other, const AutoDiffScalar& a) {
- return MakeAutoDiffScalar(
- other / a.value(),
- a.derivatives() * (Scalar(-other) / (a.value() * a.value())));
- }
-
- template <typename OtherDerType>
- inline const AutoDiffScalar<DerType> operator/(
- const AutoDiffScalar<OtherDerType>& other) const {
- const auto& this_der = m_derivatives;
- const auto& other_der = other.derivatives();
- const bool has_this_der = m_derivatives.size() > 0;
- const bool has_both_der = has_this_der && (other.derivatives().size() > 0);
- const double scale = 1. / (other.value() * other.value());
- return MakeAutoDiffScalar(
- m_value / other.value(),
- has_both_der ?
- VectorXd(this_der * other.value() - other_der * m_value) * scale :
- has_this_der ?
- VectorXd(this_der * other.value()) * scale :
- // has_other_der || has_neither
- VectorXd(other_der * -m_value) * scale);
- }
-
- template <typename OtherDerType>
- inline const AutoDiffScalar<DerType> operator*(
- const AutoDiffScalar<OtherDerType>& other) const {
- const bool has_this_der = m_derivatives.size() > 0;
- const bool has_both_der = has_this_der && (other.derivatives().size() > 0);
- return MakeAutoDiffScalar(
- m_value * other.value(),
- has_both_der ? VectorXd(m_derivatives * other.value() +
- other.derivatives() * m_value)
- : has_this_der ? VectorXd(m_derivatives * other.value())
- : VectorXd(other.derivatives() * m_value));
- }
-
- inline AutoDiffScalar& operator*=(const Scalar& other) {
- *this = *this * other;
- return *this;
- }
-
- template <typename OtherDerType>
- inline AutoDiffScalar& operator*=(const AutoDiffScalar<OtherDerType>& other) {
- *this = *this * other;
- return *this;
- }
-
- inline AutoDiffScalar& operator/=(const Scalar& other) {
- *this = *this / other;
- return *this;
- }
-
- template <typename OtherDerType>
- inline AutoDiffScalar& operator/=(const AutoDiffScalar<OtherDerType>& other) {
- *this = *this / other;
- return *this;
- }
-
- protected:
- Scalar m_value;
- DerType m_derivatives;
-};
-
-#define DRAKE_EIGEN_AUTODIFFXD_DECLARE_GLOBAL_UNARY(FUNC, CODE) \
- inline const AutoDiffScalar<VectorXd> FUNC( \
- const AutoDiffScalar<VectorXd>& x) { \
- EIGEN_UNUSED typedef double Scalar; \
- CODE; \
- }
-
-DRAKE_EIGEN_AUTODIFFXD_DECLARE_GLOBAL_UNARY(
- abs, using std::abs; return Eigen::MakeAutoDiffScalar(
- abs(x.value()), x.derivatives() * (x.value() < 0 ? -1 : 1));)
-
-DRAKE_EIGEN_AUTODIFFXD_DECLARE_GLOBAL_UNARY(
- abs2, using numext::abs2; return Eigen::MakeAutoDiffScalar(
- abs2(x.value()), x.derivatives() * (Scalar(2) * x.value()));)
-
-DRAKE_EIGEN_AUTODIFFXD_DECLARE_GLOBAL_UNARY(
- sqrt, using std::sqrt; Scalar sqrtx = sqrt(x.value());
- return Eigen::MakeAutoDiffScalar(sqrtx,
- x.derivatives() * (Scalar(0.5) / sqrtx));)
-
-DRAKE_EIGEN_AUTODIFFXD_DECLARE_GLOBAL_UNARY(
- cos, using std::cos; using std::sin;
- return Eigen::MakeAutoDiffScalar(cos(x.value()),
- x.derivatives() * (-sin(x.value())));)
-
-DRAKE_EIGEN_AUTODIFFXD_DECLARE_GLOBAL_UNARY(
- sin, using std::sin; using std::cos;
- return Eigen::MakeAutoDiffScalar(sin(x.value()),
- x.derivatives() * cos(x.value()));)
-
-DRAKE_EIGEN_AUTODIFFXD_DECLARE_GLOBAL_UNARY(
- exp, using std::exp; Scalar expx = exp(x.value());
- return Eigen::MakeAutoDiffScalar(expx, x.derivatives() * expx);)
-
-DRAKE_EIGEN_AUTODIFFXD_DECLARE_GLOBAL_UNARY(
- log, using std::log; return Eigen::MakeAutoDiffScalar(
- log(x.value()), x.derivatives() * (Scalar(1) / x.value()));)
-
-DRAKE_EIGEN_AUTODIFFXD_DECLARE_GLOBAL_UNARY(
- tan, using std::tan; using std::cos; return Eigen::MakeAutoDiffScalar(
- tan(x.value()),
- x.derivatives() * (Scalar(1) / numext::abs2(cos(x.value()))));)
-
-DRAKE_EIGEN_AUTODIFFXD_DECLARE_GLOBAL_UNARY(
- asin, using std::sqrt; using std::asin; return Eigen::MakeAutoDiffScalar(
- asin(x.value()),
- x.derivatives() * (Scalar(1) / sqrt(1 - numext::abs2(x.value()))));)
-
-DRAKE_EIGEN_AUTODIFFXD_DECLARE_GLOBAL_UNARY(
- acos, using std::sqrt; using std::acos; return Eigen::MakeAutoDiffScalar(
- acos(x.value()),
- x.derivatives() * (Scalar(-1) / sqrt(1 - numext::abs2(x.value()))));)
-
-DRAKE_EIGEN_AUTODIFFXD_DECLARE_GLOBAL_UNARY(
- atan, using std::atan; return Eigen::MakeAutoDiffScalar(
- atan(x.value()),
- x.derivatives() * (Scalar(1) / (1 + x.value() * x.value())));)
-
-DRAKE_EIGEN_AUTODIFFXD_DECLARE_GLOBAL_UNARY(
- tanh, using std::cosh; using std::tanh; return Eigen::MakeAutoDiffScalar(
- tanh(x.value()),
- x.derivatives() * (Scalar(1) / numext::abs2(cosh(x.value()))));)
-
-DRAKE_EIGEN_AUTODIFFXD_DECLARE_GLOBAL_UNARY(
- sinh, using std::sinh; using std::cosh;
- return Eigen::MakeAutoDiffScalar(sinh(x.value()),
- x.derivatives() * cosh(x.value()));)
-
-DRAKE_EIGEN_AUTODIFFXD_DECLARE_GLOBAL_UNARY(
- cosh, using std::sinh; using std::cosh;
- return Eigen::MakeAutoDiffScalar(cosh(x.value()),
- x.derivatives() * sinh(x.value()));)
-
-#undef DRAKE_EIGEN_AUTODIFFXD_DECLARE_GLOBAL_UNARY
-
-// We have this specialization here because the Eigen-3.3.3's atan2
-// implementation for AutoDiffScalar does not make a return with properly sized
-// derivatives.
-inline const AutoDiffScalar<VectorXd> atan2(const AutoDiffScalar<VectorXd>& a,
- const AutoDiffScalar<VectorXd>& b) {
- const bool has_a_der = a.derivatives().size() > 0;
- const bool has_both_der = has_a_der && (b.derivatives().size() > 0);
- const double squared_hypot = a.value() * a.value() + b.value() * b.value();
- return MakeAutoDiffScalar(
- std::atan2(a.value(), b.value()),
- VectorXd((has_both_der
- ? VectorXd(a.derivatives() * b.value() -
- a.value() * b.derivatives())
- : has_a_der ? VectorXd(a.derivatives() * b.value())
- : VectorXd(-a.value() * b.derivatives())) /
- squared_hypot));
-}
-
-inline const AutoDiffScalar<VectorXd> pow(const AutoDiffScalar<VectorXd>& a,
- double b) {
- using std::pow;
- return MakeAutoDiffScalar(pow(a.value(), b),
- a.derivatives() * (b * pow(a.value(), b - 1)));
-}
-
-#endif
-
-} // namespace Eigen
diff --git a/wpimath/src/test/native/include/drake/common/cond.h b/wpimath/src/test/native/include/drake/common/cond.h
deleted file mode 100644
index 16dd21e..0000000
--- a/wpimath/src/test/native/include/drake/common/cond.h
+++ /dev/null
@@ -1,44 +0,0 @@
-#pragma once
-
-#include <functional>
-#include <type_traits>
-
-#include "drake/common/double_overloads.h"
-
-namespace drake {
-/** @name cond
- Constructs conditional expression (similar to Lisp's cond).
-
- @verbatim
- cond(cond_1, expr_1,
- cond_2, expr_2,
- ..., ...,
- cond_n, expr_n,
- expr_{n+1})
- @endverbatim
-
- The value returned by the above cond expression is @c expr_1 if @c cond_1 is
- true; else if @c cond_2 is true then @c expr_2; ... ; else if @c cond_n is
- true then @c expr_n. If none of the conditions are true, it returns @c
- expr_{n+1}.
-
- @note This functions assumes that @p ScalarType provides @c operator< and the
- type of @c f_cond is the type of the return type of <tt>operator<(ScalarType,
- ScalarType)</tt>. For example, @c symbolic::Expression can be used as a @p
- ScalarType because it provides <tt>symbolic::Formula
- operator<(symbolic::Expression, symbolic::Expression)</tt>.
-
-
- @{
- */
-template <typename ScalarType>
-ScalarType cond(const ScalarType& e) {
- return e;
-}
-template <typename ScalarType, typename... Rest>
-ScalarType cond(const decltype(ScalarType() < ScalarType()) & f_cond,
- const ScalarType& e_then, Rest... rest) {
- return if_then_else(f_cond, e_then, cond(rest...));
-}
-///@}
-} // namespace drake
diff --git a/wpimath/src/test/native/include/drake/common/constants.h b/wpimath/src/test/native/include/drake/common/constants.h
deleted file mode 100644
index 0ccddca..0000000
--- a/wpimath/src/test/native/include/drake/common/constants.h
+++ /dev/null
@@ -1,19 +0,0 @@
-#pragma once
-
-namespace drake {
-
-constexpr int kQuaternionSize = 4;
-
-constexpr int kSpaceDimension = 3;
-
-constexpr int kRpySize = 3;
-
-/// https://en.wikipedia.org/wiki/Screw_theory#Twist
-constexpr int kTwistSize = 6;
-
-/// http://www.euclideanspace.com/maths/geometry/affine/matrix4x4/
-constexpr int kHomogeneousTransformSize = 16;
-
-const int kRotmatSize = kSpaceDimension * kSpaceDimension;
-
-} // namespace drake
diff --git a/wpimath/src/test/native/include/drake/common/double_overloads.h b/wpimath/src/test/native/include/drake/common/double_overloads.h
deleted file mode 100644
index 125f113..0000000
--- a/wpimath/src/test/native/include/drake/common/double_overloads.h
+++ /dev/null
@@ -1,21 +0,0 @@
-/// @file
-/// Provides necessary operations on double to have it as a ScalarType in drake.
-
-#pragma once
-
-namespace drake {
-/// Provides if-then-else expression for double. The value returned by the
-/// if-then-else expression is @p v_then if @p f_cond is @c true. Otherwise, it
-/// returns @p v_else.
-
-/// The semantics is similar but not exactly the same as C++'s conditional
-/// expression constructed by its ternary operator, @c ?:. In
-/// <tt>if_then_else(f_cond, v_then, v_else)</tt>, both of @p v_then and @p
-/// v_else are evaluated regardless of the evaluation of @p f_cond. In contrast,
-/// only one of @p v_then or @p v_else is evaluated in C++'s conditional
-/// expression <tt>f_cond ? v_then : v_else</tt>.
-inline double if_then_else(bool f_cond, double v_then, double v_else) {
- return f_cond ? v_then : v_else;
-}
-
-} // namespace drake
diff --git a/wpimath/src/test/native/include/drake/common/drake_deprecated.h b/wpimath/src/test/native/include/drake/common/drake_deprecated.h
deleted file mode 100644
index 5ce6328..0000000
--- a/wpimath/src/test/native/include/drake/common/drake_deprecated.h
+++ /dev/null
@@ -1,65 +0,0 @@
-#pragma once
-
-/** @file
-Provides a portable macro for use in generating compile-time warnings for
-use of code that is permitted but discouraged. */
-
-#ifdef DRAKE_DOXYGEN_CXX
-/** Use `DRAKE_DEPRECATED("removal_date", "message")` to discourage use of
-certain APIs. It can be used on classes, typedefs, variables, non-static data
-members, functions, arguments, enumerations, and template specializations. When
-code refers to the deprecated item, a compile time warning will be issued
-displaying the given message, preceded by "DRAKE DEPRECATED: ". The Doxygen API
-reference will show that the API is deprecated, along with the message.
-
-This is typically used for constructs that have been replaced by something
-better and it is good practice to suggest the appropriate replacement in the
-deprecation message. Deprecation warnings are conventionally used to convey to
-users that a feature they are depending on will be removed in a future release.
-
-Every deprecation notice must include a date (as YYYY-MM-DD string) where the
-deprecated item is planned for removal. Future commits may change the date
-(e.g., delaying the removal) but should generally always reflect the best
-current expectation for removal.
-
-Absent any other particular need, we prefer to use a deprecation period of
-three months by default, often rounded up to the next first of the month. So
-for code announced as deprecated on 2018-01-22 the removal_date would nominally
-be set to 2018-05-01.
-
-Try to keep the date string immediately after the DRAKE_DEPRECATED macro name,
-even if the message itself must be wrapped to a new line:
-@code
- DRAKE_DEPRECATED("2038-01-19",
- "foo is being replaced with a safer, const-method named bar().")
- int foo();
-@endcode
-
-Sample uses: @code
- // Attribute comes *before* declaration of a deprecated function or variable;
- // no semicolon is allowed.
- DRAKE_DEPRECATED("2038-01-19", "f() is slow; use g() instead.")
- int f(int arg);
-
- // Attribute comes *after* struct, class, enum keyword.
- class DRAKE_DEPRECATED("2038-01-19", "Use MyNewClass instead.")
- MyClass {
- };
-
- // Type alias goes before the '='.
- using NewType
- DRAKE_DEPRECATED("2038-01-19", "Use NewType instead.")
- = OldType;
-@endcode
-*/
-#define DRAKE_DEPRECATED(removal_date, message)
-
-#else // DRAKE_DOXYGEN_CXX
-
-#define DRAKE_DEPRECATED(removal_date, message) \
- [[deprecated( \
- "\nDRAKE DEPRECATED: " message \
- "\nThe deprecated code will be removed from Drake" \
- " on or after " removal_date ".")]]
-
-#endif // DRAKE_DOXYGEN_CXX
diff --git a/wpimath/src/test/native/include/drake/common/drake_nodiscard.h b/wpimath/src/test/native/include/drake/common/drake_nodiscard.h
deleted file mode 100644
index 29f078d..0000000
--- a/wpimath/src/test/native/include/drake/common/drake_nodiscard.h
+++ /dev/null
@@ -1,13 +0,0 @@
-#pragma once
-
-// TODO(jwnimmer-tri) Once we are in --std=c++17 mode as our minimum version,
-// we can remove this file and just say [[nodiscard]] directly everywhere.
-
-#if defined(DRAKE_DOXYGEN_CXX) || __has_cpp_attribute(nodiscard)
-/** Synonym for [[nodiscard]], iff the current compiler supports it;
-see https://en.cppreference.com/w/cpp/language/attributes/nodiscard. */
-// NOLINTNEXTLINE(whitespace/braces)
-#define DRAKE_NODISCARD [[nodiscard]]
-#else
-#define DRAKE_NODISCARD
-#endif
diff --git a/wpimath/src/test/native/include/drake/common/dummy_value.h b/wpimath/src/test/native/include/drake/common/dummy_value.h
deleted file mode 100644
index b9c616a..0000000
--- a/wpimath/src/test/native/include/drake/common/dummy_value.h
+++ /dev/null
@@ -1,35 +0,0 @@
-#pragma once
-
-#include <limits>
-
-namespace drake {
-
-/// Provides a "dummy" value for a ScalarType -- a value that is unlikely to be
-/// mistaken for a purposefully-computed value, useful for initializing a value
-/// before the true result is available.
-///
-/// Defaults to using std::numeric_limits::quiet_NaN when available; it is a
-/// compile-time error to call the unspecialized dummy_value::get() when
-/// quiet_NaN is unavailable.
-///
-/// See autodiff_overloads.h to use this with Eigen's AutoDiffScalar.
-template <typename T>
-struct dummy_value {
- static constexpr T get() {
- static_assert(std::numeric_limits<T>::has_quiet_NaN,
- "Custom scalar types should specialize this struct");
- return std::numeric_limits<T>::quiet_NaN();
- }
-};
-
-template <>
-struct dummy_value<int> {
- static constexpr int get() {
- // D is for "Dummy". We assume as least 32 bits (per cppguide) -- if `int`
- // is larger than 32 bits, this will leave some fraction of the bytes zero
- // instead of 0xDD, but that's okay.
- return 0xDDDDDDDD;
- }
-};
-
-} // namespace drake
diff --git a/wpimath/src/test/native/include/drake/common/eigen_autodiff_limits.h b/wpimath/src/test/native/include/drake/common/eigen_autodiff_limits.h
deleted file mode 100644
index 49175ce..0000000
--- a/wpimath/src/test/native/include/drake/common/eigen_autodiff_limits.h
+++ /dev/null
@@ -1,20 +0,0 @@
-#pragma once
-
-#ifndef DRAKE_COMMON_AUTODIFF_HEADER
-// TODO(soonho-tri): Change to #error.
-#warning Do not directly include this file. Include "drake/common/autodiff.h".
-#endif
-
-#include <limits>
-
-// Eigen provides `numeric_limits<AutoDiffScalar<T>>` starting with v3.3.4.
-#if !EIGEN_VERSION_AT_LEAST(3, 3, 4) // Eigen Version < v3.3.4
-
-namespace std {
-template <typename T>
-class numeric_limits<Eigen::AutoDiffScalar<T>>
- : public numeric_limits<typename T::Scalar> {};
-
-} // namespace std
-
-#endif // Eigen Version < v3.3.4
diff --git a/wpimath/src/test/native/include/drake/common/eigen_autodiff_types.h b/wpimath/src/test/native/include/drake/common/eigen_autodiff_types.h
deleted file mode 100644
index 10ffec6..0000000
--- a/wpimath/src/test/native/include/drake/common/eigen_autodiff_types.h
+++ /dev/null
@@ -1,38 +0,0 @@
-#pragma once
-
-/// @file
-/// This file contains abbreviated definitions for certain uses of
-/// AutoDiffScalar that are commonly used in Drake.
-/// @see also eigen_types.h
-
-#ifndef DRAKE_COMMON_AUTODIFF_HEADER
-// TODO(soonho-tri): Change to #error.
-#warning Do not directly include this file. Include "drake/common/autodiff.h".
-#endif
-
-#include <type_traits>
-
-#include <Eigen/Core>
-
-#include "drake/common/eigen_types.h"
-
-namespace drake {
-
-/// An autodiff variable with a dynamic number of partials.
-using AutoDiffXd = Eigen::AutoDiffScalar<Eigen::VectorXd>;
-
-// TODO(hongkai-dai): Recursive template to get arbitrary gradient order.
-
-/// An autodiff variable with `num_vars` partials.
-template <int num_vars>
-using AutoDiffd = Eigen::AutoDiffScalar<Eigen::Matrix<double, num_vars, 1> >;
-
-/// A vector of `rows` autodiff variables, each with `num_vars` partials.
-template <int num_vars, int rows>
-using AutoDiffVecd = Eigen::Matrix<AutoDiffd<num_vars>, rows, 1>;
-
-/// A dynamic-sized vector of autodiff variables, each with a dynamic-sized
-/// vector of partials.
-typedef AutoDiffVecd<Eigen::Dynamic, Eigen::Dynamic> AutoDiffVecXd;
-
-} // namespace drake
diff --git a/wpimath/src/test/native/include/drake/common/eigen_types.h b/wpimath/src/test/native/include/drake/common/eigen_types.h
deleted file mode 100644
index abe3e0b..0000000
--- a/wpimath/src/test/native/include/drake/common/eigen_types.h
+++ /dev/null
@@ -1,461 +0,0 @@
-#pragma once
-
-/// @file
-/// This file contains abbreviated definitions for certain specializations of
-/// Eigen::Matrix that are commonly used in Drake.
-/// These convenient definitions are templated on the scalar type of the Eigen
-/// object. While Drake uses `<T>` for scalar types across the entire code base
-/// we decided in this file to use `<Scalar>` to be more consistent with the
-/// usage of `<Scalar>` in Eigen's code base.
-/// @see also eigen_autodiff_types.h
-
-#include <utility>
-
-#include <Eigen/Core>
-
-#include "drake/common/constants.h"
-#include "drake/common/drake_assert.h"
-#include "drake/common/drake_copyable.h"
-#include "drake/common/drake_deprecated.h"
-
-namespace drake {
-
-/// The empty column vector (zero rows, one column), templated on scalar type.
-template <typename Scalar>
-using Vector0 = Eigen::Matrix<Scalar, 0, 1>;
-
-/// A column vector of size 1 (that is, a scalar), templated on scalar type.
-template <typename Scalar>
-using Vector1 = Eigen::Matrix<Scalar, 1, 1>;
-
-/// A column vector of size 1 of doubles.
-using Vector1d = Eigen::Matrix<double, 1, 1>;
-
-/// A column vector of size 2, templated on scalar type.
-template <typename Scalar>
-using Vector2 = Eigen::Matrix<Scalar, 2, 1>;
-
-/// A column vector of size 3, templated on scalar type.
-template <typename Scalar>
-using Vector3 = Eigen::Matrix<Scalar, 3, 1>;
-
-/// A column vector of size 4, templated on scalar type.
-template <typename Scalar>
-using Vector4 = Eigen::Matrix<Scalar, 4, 1>;
-
-/// A column vector of size 6.
-template <typename Scalar>
-using Vector6 = Eigen::Matrix<Scalar, 6, 1>;
-
-/// A column vector templated on the number of rows.
-template <typename Scalar, int Rows>
-using Vector = Eigen::Matrix<Scalar, Rows, 1>;
-
-/// A column vector of any size, templated on scalar type.
-template <typename Scalar>
-using VectorX = Eigen::Matrix<Scalar, Eigen::Dynamic, 1>;
-
-/// A vector of dynamic size templated on scalar type, up to a maximum of 6
-/// elements.
-template <typename Scalar>
-using VectorUpTo6 = Eigen::Matrix<Scalar, Eigen::Dynamic, 1, 0, 6, 1>;
-
-/// A row vector of size 2, templated on scalar type.
-template <typename Scalar>
-using RowVector2 = Eigen::Matrix<Scalar, 1, 2>;
-
-/// A row vector of size 3, templated on scalar type.
-template <typename Scalar>
-using RowVector3 = Eigen::Matrix<Scalar, 1, 3>;
-
-/// A row vector of size 4, templated on scalar type.
-template <typename Scalar>
-using RowVector4 = Eigen::Matrix<Scalar, 1, 4>;
-
-/// A row vector of size 6.
-template <typename Scalar>
-using RowVector6 = Eigen::Matrix<Scalar, 1, 6>;
-
-/// A row vector templated on the number of columns.
-template <typename Scalar, int Cols>
-using RowVector = Eigen::Matrix<Scalar, 1, Cols>;
-
-/// A row vector of any size, templated on scalar type.
-template <typename Scalar>
-using RowVectorX = Eigen::Matrix<Scalar, 1, Eigen::Dynamic>;
-
-
-/// A matrix of 2 rows and 2 columns, templated on scalar type.
-template <typename Scalar>
-using Matrix2 = Eigen::Matrix<Scalar, 2, 2>;
-
-/// A matrix of 3 rows and 3 columns, templated on scalar type.
-template <typename Scalar>
-using Matrix3 = Eigen::Matrix<Scalar, 3, 3>;
-
-/// A matrix of 4 rows and 4 columns, templated on scalar type.
-template <typename Scalar>
-using Matrix4 = Eigen::Matrix<Scalar, 4, 4>;
-
-/// A matrix of 6 rows and 6 columns, templated on scalar type.
-template <typename Scalar>
-using Matrix6 = Eigen::Matrix<Scalar, 6, 6>;
-
-/// A matrix of 2 rows, dynamic columns, templated on scalar type.
-template <typename Scalar>
-using Matrix2X = Eigen::Matrix<Scalar, 2, Eigen::Dynamic>;
-
-/// A matrix of 3 rows, dynamic columns, templated on scalar type.
-template <typename Scalar>
-using Matrix3X = Eigen::Matrix<Scalar, 3, Eigen::Dynamic>;
-
-/// A matrix of 4 rows, dynamic columns, templated on scalar type.
-template <typename Scalar>
-using Matrix4X = Eigen::Matrix<Scalar, 4, Eigen::Dynamic>;
-
-/// A matrix of 6 rows, dynamic columns, templated on scalar type.
-template <typename Scalar>
-using Matrix6X = Eigen::Matrix<Scalar, 6, Eigen::Dynamic>;
-
-/// A matrix of dynamic size, templated on scalar type.
-template <typename Scalar>
-using MatrixX = Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic>;
-
-/// A matrix of dynamic size templated on scalar type, up to a maximum of 6 rows
-/// and 6 columns. Rectangular matrices, with different number of rows and
-/// columns, are allowed.
-template <typename Scalar>
-using MatrixUpTo6 =
-Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, 0, 6, 6>;
-
-/// A quaternion templated on scalar type.
-template <typename Scalar>
-using Quaternion = Eigen::Quaternion<Scalar>;
-
-/// An AngleAxis templated on scalar type.
-template <typename Scalar>
-using AngleAxis = Eigen::AngleAxis<Scalar>;
-
-/// An Isometry templated on scalar type.
-template <typename Scalar>
-using Isometry3 = Eigen::Transform<Scalar, 3, Eigen::Isometry>;
-
-/// A translation in 3D templated on scalar type.
-template <typename Scalar>
-using Translation3 = Eigen::Translation<Scalar, 3>;
-
-/// A column vector consisting of one twist.
-template <typename Scalar>
-using TwistVector = Eigen::Matrix<Scalar, kTwistSize, 1>;
-
-/// A matrix with one twist per column, and dynamically many columns.
-template <typename Scalar>
-using TwistMatrix = Eigen::Matrix<Scalar, kTwistSize, Eigen::Dynamic>;
-
-/// A six-by-six matrix.
-template <typename Scalar>
-using SquareTwistMatrix = Eigen::Matrix<Scalar, kTwistSize, kTwistSize>;
-
-/// A column vector consisting of one wrench (spatial force) = `[r X f; f]`,
-/// where f is a force (translational force) applied at a point `P` and `r` is
-/// the position vector from a point `O` (called the "moment center") to point
-/// `P`.
-template <typename Scalar>
-using WrenchVector = Eigen::Matrix<Scalar, 6, 1>;
-
-/// A column vector consisting of a concatenated rotational and translational
-/// force. The wrench is a special case of a SpatialForce. For a general
-/// SpatialForce the rotational force can be a pure torque or the accumulation
-/// of moments and need not necessarily be a function of the force term.
-template <typename Scalar>
-using SpatialForce
-DRAKE_DEPRECATED("2019-10-15", "Please use Vector6<> instead.")
- = Eigen::Matrix<Scalar, 6, 1>;
-
-/// EigenSizeMinPreferDynamic<a, b>::value gives the min between compile-time
-/// sizes @p a and @p b. 0 has absolute priority, followed by 1, followed by
-/// Dynamic, followed by other finite values.
-///
-/// Note that this is a type-trait version of EIGEN_SIZE_MIN_PREFER_DYNAMIC
-/// macro in "Eigen/Core/util/Macros.h".
-template <int a, int b>
-struct EigenSizeMinPreferDynamic {
- // clang-format off
- static constexpr int value = (a == 0 || b == 0) ? 0 :
- (a == 1 || b == 1) ? 1 :
- (a == Eigen::Dynamic || b == Eigen::Dynamic) ? Eigen::Dynamic :
- a <= b ? a : b;
- // clang-format on
-};
-
-/// EigenSizeMinPreferFixed is a variant of EigenSizeMinPreferDynamic. The
-/// difference is that finite values now have priority over Dynamic, so that
-/// EigenSizeMinPreferFixed<3, Dynamic>::value gives 3.
-///
-/// Note that this is a type-trait version of EIGEN_SIZE_MIN_PREFER_FIXED macro
-/// in "Eigen/Core/util/Macros.h".
-template <int a, int b>
-struct EigenSizeMinPreferFixed {
- // clang-format off
- static constexpr int value = (a == 0 || b == 0) ? 0 :
- (a == 1 || b == 1) ? 1 :
- (a == Eigen::Dynamic && b == Eigen::Dynamic) ? Eigen::Dynamic :
- (a == Eigen::Dynamic) ? b :
- (b == Eigen::Dynamic) ? a :
- a <= b ? a : b;
- // clang-format on
-};
-
-/// MultiplyEigenSizes<a, b> gives a * b if both of a and b are fixed
-/// sizes. Otherwise it gives Eigen::Dynamic.
-template <int a, int b>
-struct MultiplyEigenSizes {
- static constexpr int value =
- (a == Eigen::Dynamic || b == Eigen::Dynamic) ? Eigen::Dynamic : a * b;
-};
-
-/*
- * Determines if a type is derived from EigenBase<> (e.g. ArrayBase<>,
- * MatrixBase<>).
- */
-template <typename Derived>
-struct is_eigen_type : std::is_base_of<Eigen::EigenBase<Derived>, Derived> {};
-
-/*
- * Determines if an EigenBase<> has a specific scalar type.
- */
-template <typename Derived, typename Scalar>
-struct is_eigen_scalar_same
- : std::integral_constant<
- bool, is_eigen_type<Derived>::value &&
- std::is_same<typename Derived::Scalar, Scalar>::value> {};
-
-/*
- * Determines if an EigenBase<> type is a compile-time (column) vector.
- * This will not check for run-time size.
- */
-template <typename Derived>
-struct is_eigen_vector
- : std::integral_constant<bool, is_eigen_type<Derived>::value &&
- Derived::ColsAtCompileTime == 1> {};
-
-/*
- * Determines if an EigenBase<> type is a compile-time (column) vector of a
- * scalar type. This will not check for run-time size.
- */
-template <typename Derived, typename Scalar>
-struct is_eigen_vector_of
- : std::integral_constant<
- bool, is_eigen_scalar_same<Derived, Scalar>::value &&
- is_eigen_vector<Derived>::value> {};
-
-// TODO(eric.cousineau): A 1x1 matrix will be disqualified in this case, and
-// this logic will qualify it as a vector. Address the downstream logic if this
-// becomes an issue.
-/*
- * Determines if a EigenBase<> type is a compile-time non-column-vector matrix
- * of a scalar type. This will not check for run-time size.
- * @note For an EigenBase<> of the correct Scalar type, this logic is
- * exclusive to is_eigen_vector_of<> such that distinct specializations are not
- * ambiguous.
- */
-template <typename Derived, typename Scalar>
-struct is_eigen_nonvector_of
- : std::integral_constant<
- bool, is_eigen_scalar_same<Derived, Scalar>::value &&
- !is_eigen_vector<Derived>::value> {};
-
-// TODO(eric.cousineau): Add alias is_eigen_matrix_of = is_eigen_scalar_same if
-// appropriate.
-
-/// This wrapper class provides a way to write non-template functions taking raw
-/// pointers to Eigen objects as parameters while limiting the number of copies,
-/// similar to `Eigen::Ref`. Internally, it keeps an instance of `Eigen::Ref<T>`
-/// and provides access to it via `operator*` and `operator->`.
-///
-/// The motivation of this class is to follow <a
-/// href="https://google.github.io/styleguide/cppguide.html#Reference_Arguments">GSG's
-/// "output arguments should be pointers" rule</a> while taking advantage of
-/// using `Eigen::Ref`. Here is an example.
-///
-/// @code
-/// // This function is taking an Eigen::Ref of a matrix and modifies it in
-/// // the body. This violates GSG's rule on output parameters.
-/// void foo(Eigen::Ref<Eigen::MatrixXd> M) {
-/// M(0, 0) = 0;
-/// }
-/// // At Call-site, we have:
-/// foo(M);
-/// foo(M.block(0, 0, 2, 2));
-///
-/// // We can rewrite the above function into the following using EigenPtr.
-/// void foo(EigenPtr<Eigen::MatrixXd> M) {
-/// (*M)(0, 0) = 0;
-/// }
-/// // Note that, call sites should be changed to:
-/// foo(&M);
-///
-/// // We need tmp to avoid taking the address of a temporary object such as the
-/// // return value of .block().
-/// auto tmp = M.block(0, 0, 2, 2);
-/// foo(&tmp);
-/// @endcode
-///
-/// Notice that methods taking an EigenPtr can mutate the entries of a matrix as
-/// in method `foo()` in the example code above, but cannot change its size.
-/// This is because `operator*` and `operator->` return an `Eigen::Ref<T>`
-/// object and only plain matrices/arrays can be resized and not expressions.
-/// This **is** the desired behavior, since resizing the block of a matrix or
-/// even a more general expression should not be allowed. If you do want to be
-/// able to resize a mutable matrix argument, then you must pass it as a
-/// `Matrix<T>*`, like so:
-/// @code
-/// void bar(Eigen::MatrixXd* M) {
-/// DRAKE_THROW_UNLESS(M != nullptr);
-/// // In this case this method only works with 4x3 matrices.
-/// if (M->rows() != 4 && M->cols() != 3) {
-/// M->resize(4, 3);
-/// }
-/// (*M)(0, 0) = 0;
-/// }
-/// @endcode
-///
-/// @note This class provides a way to avoid the `const_cast` hack introduced in
-/// <a
-/// href="https://eigen.tuxfamily.org/dox/TopicFunctionTakingEigenTypes.html#TopicPlainFunctionsFailing">Eigen's
-/// documentation</a>.
-template <typename PlainObjectType>
-class EigenPtr {
- public:
- typedef Eigen::Ref<PlainObjectType> RefType;
-
- EigenPtr() : EigenPtr(nullptr) {}
-
- /// Overload for `nullptr`.
- // NOLINTNEXTLINE(runtime/explicit) This conversion is desirable.
- EigenPtr(std::nullptr_t) {}
-
- /// Constructs with a reference to the given matrix type.
- // NOLINTNEXTLINE(runtime/explicit) This conversion is desirable.
- EigenPtr(const EigenPtr& other) { assign(other); }
-
- /// Constructs with a reference to another matrix type.
- /// May be `nullptr`.
- template <typename PlainObjectTypeIn>
- // NOLINTNEXTLINE(runtime/explicit) This conversion is desirable.
- EigenPtr(PlainObjectTypeIn* m) {
- if (m) {
- m_.set_value(m);
- }
- }
-
- /// Constructs from another EigenPtr.
- template <typename PlainObjectTypeIn>
- // NOLINTNEXTLINE(runtime/explicit) This conversion is desirable.
- EigenPtr(const EigenPtr<PlainObjectTypeIn>& other) {
- // Cannot directly construct `m_` from `other.m_`.
- assign(other);
- }
-
- EigenPtr& operator=(const EigenPtr& other) {
- // We must explicitly override this version of operator=.
- // The template below will not take precedence over this one.
- return assign(other);
- }
-
- template <typename PlainObjectTypeIn>
- EigenPtr& operator=(const EigenPtr<PlainObjectTypeIn>& other) {
- return assign(other);
- }
-
- /// @throws std::runtime_error if this is a null dereference.
- RefType& operator*() const { return get_reference(); }
-
- /// @throws std::runtime_error if this is a null dereference.
- RefType* operator->() const { return &get_reference(); }
-
- /// Returns whether or not this contains a valid reference.
- operator bool() const { return is_valid(); }
-
- bool operator==(std::nullptr_t) const { return !is_valid(); }
-
- bool operator!=(std::nullptr_t) const { return is_valid(); }
-
- private:
- // Simple reassignable container without requirement of heap allocation.
- // This is used because `drake::optional<>` does not work with `Eigen::Ref<>`
- // because `Ref` deletes the necessary `operator=` overload for
- // `std::is_copy_assignable`.
- class ReassignableRef {
- public:
- DRAKE_NO_COPY_NO_MOVE_NO_ASSIGN(ReassignableRef)
- ReassignableRef() {}
- ~ReassignableRef() {
- reset();
- }
-
- // Reset value to null.
- void reset() {
- if (has_value_) {
- raw_value().~RefType();
- has_value_ = false;
- }
- }
-
- // Set value.
- template <typename PlainObjectTypeIn>
- void set_value(PlainObjectTypeIn* value_in) {
- if (has_value_) {
- raw_value().~RefType();
- }
- new (&raw_value()) RefType(*value_in);
- has_value_ = true;
- }
-
- // Access to value.
- RefType& value() {
- DRAKE_ASSERT(has_value());
- return raw_value();
- }
-
- // Indicates if it has a value.
- bool has_value() const { return has_value_; }
-
- private:
- // Unsafe access to value.
- RefType& raw_value() { return reinterpret_cast<RefType&>(storage_); }
-
- bool has_value_{};
- typename std::aligned_storage<sizeof(RefType), alignof(RefType)>::type
- storage_;
- };
-
- // Use mutable, reassignable ref to permit pointer-like semantics (with
- // ownership) on the stack.
- mutable ReassignableRef m_;
-
- // Consolidate assignment here, so that both the copy constructor and the
- // construction from another type may be used.
- template <typename PlainObjectTypeIn>
- EigenPtr& assign(const EigenPtr<PlainObjectTypeIn>& other) {
- if (other) {
- m_.set_value(&(*other));
- } else {
- m_.reset();
- }
- return *this;
- }
-
- // Consolidate getting a reference here.
- RefType& get_reference() const {
- if (!m_.has_value())
- throw std::runtime_error("EigenPtr: nullptr dereference");
- return m_.value();
- }
-
- bool is_valid() const {
- return m_.has_value();
- }
-};
-
-} // namespace drake
diff --git a/wpimath/src/test/native/include/drake/common/test_utilities/eigen_matrix_compare.h b/wpimath/src/test/native/include/drake/common/test_utilities/eigen_matrix_compare.h
index 8847d87..d6bcbb8 100644
--- a/wpimath/src/test/native/include/drake/common/test_utilities/eigen_matrix_compare.h
+++ b/wpimath/src/test/native/include/drake/common/test_utilities/eigen_matrix_compare.h
@@ -7,7 +7,7 @@
#include <Eigen/Core>
#include <gtest/gtest.h>
-#include "drake/common/drake_nodiscard.h"
+// #include "drake/common/text_logging.h"
namespace drake {
@@ -28,7 +28,7 @@
* @return true if the two matrices are equal based on the specified tolerance.
*/
template <typename DerivedA, typename DerivedB>
-DRAKE_NODISCARD ::testing::AssertionResult CompareMatrices(
+[[nodiscard]] ::testing::AssertionResult CompareMatrices(
const Eigen::MatrixBase<DerivedA>& m1,
const Eigen::MatrixBase<DerivedB>& m2, double tolerance = 0.0,
MatrixCompareType compare_type = MatrixCompareType::absolute) {
diff --git a/wpimath/src/test/native/include/drake/common/unused.h b/wpimath/src/test/native/include/drake/common/unused.h
deleted file mode 100644
index 5a28b01..0000000
--- a/wpimath/src/test/native/include/drake/common/unused.h
+++ /dev/null
@@ -1,53 +0,0 @@
-#pragma once
-
-namespace drake {
-
-/// Documents the argument(s) as unused, placating GCC's -Wunused-parameter
-/// warning. This can be called within function bodies to mark that certain
-/// parameters are unused.
-///
-/// When possible, removing the unused parameter is better than placating the
-/// warning. However, in some cases the parameter is part of a virtual API or
-/// template concept that is used elsewhere, so we can't remove it. In those
-/// cases, this function might be an appropriate work-around.
-///
-/// Here's rough advice on how to fix Wunused-parameter warnings:
-///
-/// (1) If the parameter can be removed entirely, prefer that as the first
-/// choice. (This may not be possible if, e.g., a method must match some
-/// virtual API or template concept.)
-///
-/// (2) Unless the parameter name has acute value, prefer to omit the name of
-/// the parameter, leaving only the type, e.g.
-/// @code
-/// void Print(const State& state) override { /* No state to print. */ }
-/// @endcode
-/// changes to
-/// @code
-/// void Print(const State&) override { /* No state to print. */}
-/// @endcode
-/// This no longer triggers the warning and further makes it clear that a
-/// parameter required by the API is definitively unused in the function.
-///
-/// This is an especially good solution in the context of method
-/// definitions (vs declarations); the parameter name used in a definition
-/// is entirely irrelevant to Doxygen and most readers.
-///
-/// (3) When leaving the parameter name intact has acute value, it is
-/// acceptable to keep the name and mark it `unused`. For example, when
-/// the name appears as part of a virtual method's base class declaration,
-/// the name is used by Doxygen to document the method, e.g.,
-/// @code
-/// /** Sets the default State of a System. This default implementation is to
-/// set all zeros. Subclasses may override to use non-zero defaults. The
-/// custom defaults may be based on the given @p context, when relevant. */
-/// virtual void SetDefault(const Context<T>& context, State<T>* state) const {
-/// unused(context);
-/// state->SetZero();
-/// }
-/// @endcode
-///
-template <typename ... Args>
-void unused(const Args& ...) {}
-
-} // namespace drake
diff --git a/wpimath/src/test/native/include/drake/math/autodiff.h b/wpimath/src/test/native/include/drake/math/autodiff.h
deleted file mode 100644
index 52edb11..0000000
--- a/wpimath/src/test/native/include/drake/math/autodiff.h
+++ /dev/null
@@ -1,332 +0,0 @@
-/// @file
-/// Utilities for arithmetic on AutoDiffScalar.
-
-// TODO(russt): rename methods to be GSG compliant.
-
-#pragma once
-
-#include <cmath>
-#include <tuple>
-
-#include <Eigen/Core>
-
-#include "drake/common/autodiff.h"
-#include "drake/common/unused.h"
-
-namespace drake {
-namespace math {
-
-template <typename Derived>
-struct AutoDiffToValueMatrix {
- typedef typename Eigen::Matrix<typename Derived::Scalar::Scalar,
- Derived::RowsAtCompileTime,
- Derived::ColsAtCompileTime> type;
-};
-
-template <typename Derived>
-typename AutoDiffToValueMatrix<Derived>::type autoDiffToValueMatrix(
- const Eigen::MatrixBase<Derived>& auto_diff_matrix) {
- typename AutoDiffToValueMatrix<Derived>::type ret(auto_diff_matrix.rows(),
- auto_diff_matrix.cols());
- for (int i = 0; i < auto_diff_matrix.rows(); i++) {
- for (int j = 0; j < auto_diff_matrix.cols(); ++j) {
- ret(i, j) = auto_diff_matrix(i, j).value();
- }
- }
- return ret;
-}
-
-/** `B = DiscardGradient(A)` enables casting from a matrix of AutoDiffScalars
- * to AutoDiffScalar::Scalar type, explicitly throwing away any gradient
- * information. For a matrix of type, e.g. `MatrixX<AutoDiffXd> A`, the
- * comparable operation
- * `B = A.cast<double>()`
- * should (and does) fail to compile. Use `DiscardGradient(A)` if you want to
- * force the cast (and explicitly declare that information is lost).
- *
- * This method is overloaded to permit the user to call it for double types and
- * AutoDiffScalar types (to avoid the calling function having to handle the
- * two cases differently).
- *
- * @see DiscardZeroGradient
- */
-template <typename Derived>
-typename std::enable_if<
- !std::is_same<typename Derived::Scalar, double>::value,
- Eigen::Matrix<typename Derived::Scalar::Scalar, Derived::RowsAtCompileTime,
- Derived::ColsAtCompileTime, 0, Derived::MaxRowsAtCompileTime,
- Derived::MaxColsAtCompileTime>>::type
-DiscardGradient(const Eigen::MatrixBase<Derived>& auto_diff_matrix) {
- return autoDiffToValueMatrix(auto_diff_matrix);
-}
-
-/// @see DiscardGradient().
-template <typename Derived>
-typename std::enable_if<
- std::is_same<typename Derived::Scalar, double>::value,
- const Eigen::MatrixBase<Derived>&>::type
-DiscardGradient(const Eigen::MatrixBase<Derived>& matrix) {
- return matrix;
-}
-
-/// @see DiscardGradient().
-template <typename _Scalar, int _Dim, int _Mode, int _Options>
-typename std::enable_if<
- !std::is_same<_Scalar, double>::value,
- Eigen::Transform<typename _Scalar::Scalar, _Dim, _Mode, _Options>>::type
-DiscardGradient(const Eigen::Transform<_Scalar, _Dim, _Mode, _Options>&
- auto_diff_transform) {
- return Eigen::Transform<typename _Scalar::Scalar, _Dim, _Mode, _Options>(
- autoDiffToValueMatrix(auto_diff_transform.matrix()));
-}
-
-/// @see DiscardGradient().
-template <typename _Scalar, int _Dim, int _Mode, int _Options>
-typename std::enable_if<std::is_same<_Scalar, double>::value,
- const Eigen::Transform<_Scalar, _Dim, _Mode,
- _Options>&>::type
-DiscardGradient(
- const Eigen::Transform<_Scalar, _Dim, _Mode, _Options>& transform) {
- return transform;
-}
-
-
-/** \brief Initialize a single autodiff matrix given the corresponding value
- *matrix.
- *
- * Set the values of \p auto_diff_matrix to be equal to \p val, and for each
- *element i of \p auto_diff_matrix,
- * resize the derivatives vector to \p num_derivatives, and set derivative
- *number \p deriv_num_start + i to one (all other elements of the derivative
- *vector set to zero).
- *
- * \param[in] mat 'regular' matrix of values
- * \param[out] ret AutoDiff matrix
- * \param[in] num_derivatives the size of the derivatives vector @default the
- *size of mat
- * \param[in] deriv_num_start starting index into derivative vector (i.e.
- *element deriv_num_start in derivative vector corresponds to mat(0, 0)).
- *@default 0
- */
-template <typename Derived, typename DerivedAutoDiff>
-void initializeAutoDiff(const Eigen::MatrixBase<Derived>& val,
- // TODO(#2274) Fix NOLINTNEXTLINE(runtime/references).
- Eigen::MatrixBase<DerivedAutoDiff>& auto_diff_matrix,
- Eigen::DenseIndex num_derivatives = Eigen::Dynamic,
- Eigen::DenseIndex deriv_num_start = 0) {
- using ADScalar = typename DerivedAutoDiff::Scalar;
- static_assert(static_cast<int>(Derived::RowsAtCompileTime) ==
- static_cast<int>(DerivedAutoDiff::RowsAtCompileTime),
- "auto diff matrix has wrong number of rows at compile time");
- static_assert(static_cast<int>(Derived::ColsAtCompileTime) ==
- static_cast<int>(DerivedAutoDiff::ColsAtCompileTime),
- "auto diff matrix has wrong number of columns at compile time");
-
- if (num_derivatives == Eigen::Dynamic) num_derivatives = val.size();
-
- auto_diff_matrix.resize(val.rows(), val.cols());
- Eigen::DenseIndex deriv_num = deriv_num_start;
- for (Eigen::DenseIndex i = 0; i < val.size(); i++) {
- auto_diff_matrix(i) = ADScalar(val(i), num_derivatives, deriv_num++);
- }
-}
-
-/** \brief The appropriate AutoDiffScalar gradient type given the value type and
- * the number of derivatives at compile time
- */
-template <typename Derived, int Nq>
-using AutoDiffMatrixType = Eigen::Matrix<
- Eigen::AutoDiffScalar<Eigen::Matrix<typename Derived::Scalar, Nq, 1>>,
- Derived::RowsAtCompileTime, Derived::ColsAtCompileTime, 0,
- Derived::MaxRowsAtCompileTime, Derived::MaxColsAtCompileTime>;
-
-/** \brief Initialize a single autodiff matrix given the corresponding value
- *matrix.
- *
- * Create autodiff matrix that matches \p mat in size with derivatives of
- *compile time size \p Nq and runtime size \p num_derivatives.
- * Set its values to be equal to \p val, and for each element i of \p
- *auto_diff_matrix, set derivative number \p deriv_num_start + i to one (all
- *other derivatives set to zero).
- *
- * \param[in] mat 'regular' matrix of values
- * \param[in] num_derivatives the size of the derivatives vector @default the
- *size of mat
- * \param[in] deriv_num_start starting index into derivative vector (i.e.
- *element deriv_num_start in derivative vector corresponds to mat(0, 0)).
- *@default 0
- * \return AutoDiff matrix
- */
-template <int Nq = Eigen::Dynamic, typename Derived>
-AutoDiffMatrixType<Derived, Nq> initializeAutoDiff(
- const Eigen::MatrixBase<Derived>& mat,
- Eigen::DenseIndex num_derivatives = -1,
- Eigen::DenseIndex deriv_num_start = 0) {
- if (num_derivatives == -1) num_derivatives = mat.size();
-
- AutoDiffMatrixType<Derived, Nq> ret(mat.rows(), mat.cols());
- initializeAutoDiff(mat, ret, num_derivatives, deriv_num_start);
- return ret;
-}
-
-namespace internal {
-template <typename Derived, typename Scalar>
-struct ResizeDerivativesToMatchScalarImpl {
- // TODO(#2274) Fix NOLINTNEXTLINE(runtime/references).
- static void run(Eigen::MatrixBase<Derived>&, const Scalar&) {}
-};
-
-template <typename Derived, typename DerivType>
-struct ResizeDerivativesToMatchScalarImpl<Derived,
- Eigen::AutoDiffScalar<DerivType>> {
- using Scalar = Eigen::AutoDiffScalar<DerivType>;
- // TODO(#2274) Fix NOLINTNEXTLINE(runtime/references).
- static void run(Eigen::MatrixBase<Derived>& mat, const Scalar& scalar) {
- for (int i = 0; i < mat.size(); i++) {
- auto& derivs = mat(i).derivatives();
- if (derivs.size() == 0) {
- derivs.resize(scalar.derivatives().size());
- derivs.setZero();
- }
- }
- }
-};
-} // namespace internal
-
-/** Resize derivatives vector of each element of a matrix to to match the size
- * of the derivatives vector of a given scalar.
- * \brief If the mat and scalar inputs are AutoDiffScalars, resize the
- * derivatives vector of each element of the matrix mat to match
- * the number of derivatives of the scalar. This is useful in functions that
- * return matrices that do not depend on an AutoDiffScalar
- * argument (e.g. a function with a constant output), while it is desired that
- * information about the number of derivatives is preserved.
- * \param mat matrix, for which the derivative vectors of the elements will be
- * resized
- * \param scalar scalar to match the derivative size vector against.
- */
-template <typename Derived>
-// TODO(#2274) Fix NOLINTNEXTLINE(runtime/references).
-void resizeDerivativesToMatchScalar(Eigen::MatrixBase<Derived>& mat,
- const typename Derived::Scalar& scalar) {
- internal::ResizeDerivativesToMatchScalarImpl<
- Derived, typename Derived::Scalar>::run(mat, scalar);
-}
-
-namespace internal {
-/** \brief Helper for totalSizeAtCompileTime function (recursive)
- */
-template <typename Head, typename... Tail>
-struct TotalSizeAtCompileTime {
- static constexpr int eval() {
- return Head::SizeAtCompileTime == Eigen::Dynamic ||
- TotalSizeAtCompileTime<Tail...>::eval() == Eigen::Dynamic
- ? Eigen::Dynamic
- : Head::SizeAtCompileTime +
- TotalSizeAtCompileTime<Tail...>::eval();
- }
-};
-
-/** \brief Helper for totalSizeAtCompileTime function (base case)
- */
-template <typename Head>
-struct TotalSizeAtCompileTime<Head> {
- static constexpr int eval() { return Head::SizeAtCompileTime; }
-};
-
-/** \brief Determine the total size at compile time of a number of arguments
- * based on their SizeAtCompileTime static members
- */
-template <typename... Args>
-constexpr int totalSizeAtCompileTime() {
- return TotalSizeAtCompileTime<Args...>::eval();
-}
-
-/** \brief Determine the total size at runtime of a number of arguments using
- * their size() methods (base case).
- */
-constexpr Eigen::DenseIndex totalSizeAtRunTime() { return 0; }
-
-/** \brief Determine the total size at runtime of a number of arguments using
- * their size() methods (recursive)
- */
-template <typename Head, typename... Tail>
-Eigen::DenseIndex totalSizeAtRunTime(const Eigen::MatrixBase<Head>& head,
- const Tail&... tail) {
- return head.size() + totalSizeAtRunTime(tail...);
-}
-
-/** \brief Helper for initializeAutoDiffTuple function (recursive)
- */
-template <size_t Index>
-struct InitializeAutoDiffTupleHelper {
- template <typename... ValueTypes, typename... AutoDiffTypes>
- static void run(const std::tuple<ValueTypes...>& values,
- // TODO(#2274) Fix NOLINTNEXTLINE(runtime/references).
- std::tuple<AutoDiffTypes...>& auto_diffs,
- Eigen::DenseIndex num_derivatives,
- Eigen::DenseIndex deriv_num_start) {
- constexpr size_t tuple_index = sizeof...(AutoDiffTypes)-Index;
- const auto& value = std::get<tuple_index>(values);
- auto& auto_diff = std::get<tuple_index>(auto_diffs);
- auto_diff.resize(value.rows(), value.cols());
- initializeAutoDiff(value, auto_diff, num_derivatives, deriv_num_start);
- InitializeAutoDiffTupleHelper<Index - 1>::run(
- values, auto_diffs, num_derivatives, deriv_num_start + value.size());
- }
-};
-
-/** \brief Helper for initializeAutoDiffTuple function (base case)
- */
-template <>
-struct InitializeAutoDiffTupleHelper<0> {
- template <typename... ValueTypes, typename... AutoDiffTypes>
- static void run(const std::tuple<ValueTypes...>& values,
- const std::tuple<AutoDiffTypes...>& auto_diffs,
- Eigen::DenseIndex num_derivatives,
- Eigen::DenseIndex deriv_num_start) {
- unused(values, auto_diffs, num_derivatives, deriv_num_start);
- }
-};
-} // namespace internal
-
-/** \brief Given a series of Eigen matrices, create a tuple of corresponding
- *AutoDiff matrices with values equal to the input matrices and properly
- *initialized derivative vectors.
- *
- * The size of the derivative vector of each element of the matrices in the
- *output tuple will be the same, and will equal the sum of the number of
- *elements of the matrices in \p args.
- * If all of the matrices in \p args have fixed size, then the derivative
- *vectors will also have fixed size (being the sum of the sizes at compile time
- *of all of the input arguments),
- * otherwise the derivative vectors will have dynamic size.
- * The 0th element of the derivative vectors will correspond to the derivative
- *with respect to the 0th element of the first argument.
- * Subsequent derivative vector elements correspond first to subsequent elements
- *of the first input argument (traversed first by row, then by column), and so
- *on for subsequent arguments.
- *
- * \param args a series of Eigen matrices
- * \return a tuple of properly initialized AutoDiff matrices corresponding to \p
- *args
- *
- */
-template <typename... Deriveds>
-std::tuple<AutoDiffMatrixType<
- Deriveds, internal::totalSizeAtCompileTime<Deriveds...>()>...>
-initializeAutoDiffTuple(const Eigen::MatrixBase<Deriveds>&... args) {
- Eigen::DenseIndex dynamic_num_derivs = internal::totalSizeAtRunTime(args...);
- std::tuple<AutoDiffMatrixType<
- Deriveds, internal::totalSizeAtCompileTime<Deriveds...>()>...>
- ret(AutoDiffMatrixType<Deriveds,
- internal::totalSizeAtCompileTime<Deriveds...>()>(
- args.rows(), args.cols())...);
- auto values = std::forward_as_tuple(args...);
- internal::InitializeAutoDiffTupleHelper<sizeof...(args)>::run(
- values, ret, dynamic_num_derivs, 0);
- return ret;
-}
-
-} // namespace math
-} // namespace drake
diff --git a/wpimath/src/test/native/include/frc/system/RungeKuttaTimeVarying.h b/wpimath/src/test/native/include/frc/system/RungeKuttaTimeVarying.h
new file mode 100644
index 0000000..23a20e8
--- /dev/null
+++ b/wpimath/src/test/native/include/frc/system/RungeKuttaTimeVarying.h
@@ -0,0 +1,34 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+#pragma once
+
+#include <array>
+
+#include "Eigen/Core"
+#include "units/time.h"
+
+namespace frc {
+
+/**
+ * Performs 4th order Runge-Kutta integration of dy/dt = f(t, y) for dt.
+ *
+ * @param f The function to integrate. It must take two arguments t and y.
+ * @param t The initial value of t.
+ * @param y The initial value of y.
+ * @param dt The time over which to integrate.
+ */
+template <typename F, typename T>
+T RungeKuttaTimeVarying(F&& f, units::second_t t, T y, units::second_t dt) {
+ const auto h = dt.value();
+
+ T k1 = f(t, y);
+ T k2 = f(t + dt * 0.5, y + h * k1 * 0.5);
+ T k3 = f(t + dt * 0.5, y + h * k2 * 0.5);
+ T k4 = f(t + dt, y + h * k3);
+
+ return y + h / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4);
+}
+
+} // namespace frc
diff --git a/wpimath/src/test/native/include/trajectory/TestTrajectory.h b/wpimath/src/test/native/include/trajectory/TestTrajectory.h
index 1cac87f..de7b8b8 100644
--- a/wpimath/src/test/native/include/trajectory/TestTrajectory.h
+++ b/wpimath/src/test/native/include/trajectory/TestTrajectory.h
@@ -1,9 +1,6 @@
-/*----------------------------------------------------------------------------*/
-/* Copyright (c) 2019-2020 FIRST. All Rights Reserved. */
-/* Open Source Software - may be modified and shared by FRC teams. The code */
-/* must be accompanied by the FIRST BSD license file in the root directory of */
-/* the project. */
-/*----------------------------------------------------------------------------*/
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
#pragma once
diff --git a/wpimath/wpimath-config.cmake.in b/wpimath/wpimath-config.cmake.in
index 2f661d9..4769e43 100644
--- a/wpimath/wpimath-config.cmake.in
+++ b/wpimath/wpimath-config.cmake.in
@@ -2,4 +2,5 @@
@FILENAME_DEP_REPLACE@
@WPIUTIL_DEP_REPLACE@
+@FILENAME_DEP_REPLACE@
include(${SELF_DIR}/wpimath.cmake)