polish code before testing

Signed-off-by: AlienTornadosaurusHex <>
This commit is contained in:
AlienTornadosaurusHex 2023-06-07 19:49:58 +00:00
parent fce35d2a4f
commit 6bdd7a26e7
5 changed files with 151 additions and 78 deletions

View File

@ -115,7 +115,7 @@ library CurveChainedOracles {
// mload always reads bytes32, always offset by 26 bytes because data is always stored fully
// packed from left to right
assembly {
chunk := mload(add(_oracle, mul(26, add(o, 1))))
chunk := mload(add(_oracle, add(32, mul(26, o))))
}
// Properly read in all of the packed data
@ -135,6 +135,7 @@ library CurveChainedOracles {
uint256 priceFromOracle;
// Execute according to selector
// These are guaranteed to all be 1e18
if (selector == PRICE_ORACLE_SELECTOR) {
priceFromOracle = ICurvePriceOracle(oracle).price_oracle();
} else if (selector == PRICE_ORACLE_UINT256_SELECTOR) {
@ -174,7 +175,7 @@ contract CurveFeeOracle is IFeeOracle {
/* @dev For each instance, a set of data which translates to a set of chained price oracle calls, we call
this data "chainedPriceOracles" because it encodes all data necessary to execute the chain and calc the
price */
mapping(ITornadoInstance => bytes) public chainedPriceOracles;
mapping(ITornadoInstance => bytes) internal chainedPriceOracles;
/* @dev When setting, store the names as a historical record, key is keccak256(bytes) */
mapping(bytes32 => string) public chainedPriceOracleNames;
@ -265,7 +266,7 @@ contract CurveFeeOracle is IFeeOracle {
/* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SETTERS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ */
function setChainedOracleForInstance(
function modifyChainedOracleForInstance(
ITornadoInstance _instance,
ICurvePriceOracle[] memory _oracles,
bytes4[] memory _selectors,

View File

@ -18,12 +18,6 @@ import { InstanceRegistry } from "./InstanceRegistry.sol";
import { IFeeOracle } from "./interfaces/IFeeOracle.sol";
struct FeeData {
uint160 feeAmount;
uint64 feePercent;
uint32 lastUpdated;
}
/**
* @title FeeManagerLegacyStorage
* @author AlienTornadosaurusHex
@ -47,6 +41,15 @@ contract FeeManagerLegacyStorage {
mapping(ITornadoInstance => uint256) internal _oldFeesForInstanceUpdateTime;
}
/**
* @notice Fee data which is valid and also influences fee updating across all oracles.
*/
struct FeeData {
uint160 feeAmount;
uint64 feePercent;
uint32 lastUpdated;
}
/**
* @title FeeOracleManager
* @author AlienTornadosaurusHex
@ -111,7 +114,7 @@ contract FeeOracleManager is FeeManagerLegacyStorage, Initializable {
// For each instance
ITornadoInstance instance = ITornadoInstance(_instanceAddresses[i]);
// Store it's old data and the percent fees which will Governance will command
// Store it's old data and the percent fees which Governance will command
feeDataForInstance[instance] = FeeData({
feeAmount: _oldFeesForInstance[instance],
feePercent: uint64(_feePercents[i]),
@ -161,7 +164,7 @@ contract FeeOracleManager is FeeManagerLegacyStorage, Initializable {
FeeData memory feeData = feeDataForInstance[_instance];
// Now update if we do not respect the interval or we respect it and are in the interval
if (!_respectFeeUpdateInterval || feeUpdateInterval < -feeData.lastUpdated + now) {
if (!_respectFeeUpdateInterval || (feeUpdateInterval < -feeData.lastUpdated + now)) {
// This will revert if no contract is set
feeData.feeAmount = instanceFeeOracles[_instance].getFee(
torn,
@ -194,26 +197,36 @@ contract FeeOracleManager is FeeManagerLegacyStorage, Initializable {
}
function setFeeOracle(address _instanceAddress, address _oracleAddress) external onlyGovernance {
// Prepare all contracts
ITornadoInstance instance = ITornadoInstance(_instanceAddress);
IFeeOracle oracle = IFeeOracle(_oracleAddress);
// Nominally fee percent should be set first for an instance, but we cannot be sure
// whether fee percent 0 is intentional, so we don't check
FeeData memory feeData = feeDataForInstance[instance];
// Reverts if no oracle or does not conform to interface
uint160 fee = oracle.getFee(
torn,
instance,
instanceRegistry.getInstanceData(instance),
feeData.feePercent,
FEE_PERCENT_DIVISOR
);
// Fee which may be recalculated
uint160 fee;
// An address(0) oracle means we're removing an oracle for an instance
if (_oracleAddress != address(0)) {
// Reverts if oracle does not conform to interface
fee = oracle.getFee(
torn,
instance,
instanceRegistry.getInstanceData(instance),
feeData.feePercent,
FEE_PERCENT_DIVISOR
);
// Note down updated fee
feeDataForInstance[instance] =
FeeData({ feeAmount: fee, feePercent: feeData.feePercent, lastUpdated: uint32(now) });
}
// Ok, set the oracle
instanceFeeOracles[instance] = oracle;
// Note down updated fee
feeDataForInstance[instance] =
FeeData({ feeAmount: fee, feePercent: feeData.feePercent, lastUpdated: uint32(now) });
// Logs
emit OracleUpdated(_instanceAddress, _oracleAddress);
emit FeeUpdated(_instanceAddress, fee);
@ -249,22 +262,33 @@ contract FeeOracleManager is FeeManagerLegacyStorage, Initializable {
}
function getLastFeeForInstance(ITornadoInstance instance) public view virtual returns (uint160) {
return uint160(feeDataForInstance[instance].feeAmount);
return feeDataForInstance[instance].feeAmount;
}
function getLastUpdatedTimeForInstance(ITornadoInstance instance) public view virtual returns (uint32) {
return feeDataForInstance[instance].lastUpdated;
}
function getFeeDeviations() public view virtual returns (int256[] memory deviations) {
ITornadoInstance[] memory instances = instanceRegistry.getAllInstances();
function getFeePercentForInstance(ITornadoInstance instance) public view virtual returns (uint64) {
return feeDataForInstance[instance].feePercent;
}
uint256 numInstances = instances.length;
function getAllFeeDeviations() public view virtual returns (int256[] memory) {
return getFeeDeviationsForInstances(instanceRegistry.getAllInstances());
}
function getFeeDeviationsForInstances(ITornadoInstance[] memory _instances)
public
view
virtual
returns (int256[] memory deviations)
{
uint256 numInstances = _instances.length;
deviations = new int256[](numInstances);
for (uint256 i = 0; i < numInstances; i++) {
ITornadoInstance instance = instances[i];
ITornadoInstance instance = _instances[i];
FeeData memory feeData = feeDataForInstance[instance];
@ -279,7 +303,7 @@ contract FeeOracleManager is FeeManagerLegacyStorage, Initializable {
int256 deviation;
if (marketFee != 0) {
deviation = int256((feeDataForInstance[instance].feeAmount * 1000) / marketFee) - 1000;
deviation = int256((feeData.feeAmount * 1000) / marketFee) - 1000;
}
deviations[i] = deviation;

View File

@ -111,29 +111,33 @@ contract InstanceRegistry is InstanceRegistryLegacyStorage, EnsResolve, Initiali
_;
}
function initialize(ITornadoInstance[] memory _instances, address _router)
function initialize(ITornadoInstance[] memory _instances, TornadoRouter _router)
external
onlyGovernance
initializer
{
uint256 numInstances = _instances.length;
router = TornadoRouter(_router);
for (uint256 i = 0; i < numInstances; i++) {
addInstance(_instances[i]);
}
router = _router;
}
function addInstance(ITornadoInstance _instance) public virtual onlyGovernance {
// The instance may not already be enabled
bool isEnabled = instanceData[_instance].isEnabled;
require(!isEnabled, "InstanceRegistry: can't add the same instance.");
bool isERC20 = false;
// Determine whether it is an ERC20 or not
IERC20 token = IERC20(address(0));
bool isERC20 = false;
// ETH instances do not know of a `token()` call
try _instance.token() returns (address _tokenAddress) {
token = IERC20(_tokenAddress);
@ -143,7 +147,7 @@ contract InstanceRegistry is InstanceRegistryLegacyStorage, EnsResolve, Initiali
}
// If it is ERC20 then make the router give an approval for the Tornado instance to allow the token
// amount... if it hasn't already done so
// amount, if it hasn't already done so
if (isERC20) {
uint256 routerAllowanceForInstance = token.allowance(address(router), address(_instance));
@ -155,11 +159,12 @@ contract InstanceRegistry is InstanceRegistryLegacyStorage, EnsResolve, Initiali
// Add it to the enumerable
instances.push(_instance);
// Read out the index of the instance in the enumerable
uint64 instanceIndex = uint64(instances.length - 1);
// Set data
// Store the collected data of the instance
instanceData[_instance] =
InstanceData({ token: token, index: instanceIndex, isERC20: isERC20, isEnabled: isEnabled });
InstanceData({ token: token, index: instanceIndex, isERC20: isERC20, isEnabled: true });
// Log
emit InstanceAdded(address(_instance), instanceIndex, isERC20);
@ -167,7 +172,7 @@ contract InstanceRegistry is InstanceRegistryLegacyStorage, EnsResolve, Initiali
/**
* @notice Remove an instance, only callable by Governance.
* @dev The modifier is in the internal call.
* @dev The access modifier is in the internal call.
* @param _instanceIndex The index of the instance to remove.
*/
function removeInstanceByIndex(uint256 _instanceIndex) public virtual {
@ -176,7 +181,7 @@ contract InstanceRegistry is InstanceRegistryLegacyStorage, EnsResolve, Initiali
/**
* @notice Remove an instance, only callable by Governance.
* @dev The modifier is in the internal call.
* @dev The access modifier is in the internal call.
* @param _instanceAddress The adress of the instance to remove.
*/
function removeInstanceByAddress(address _instanceAddress) public virtual {
@ -184,6 +189,8 @@ contract InstanceRegistry is InstanceRegistryLegacyStorage, EnsResolve, Initiali
}
function _removeInstanceByAddress(address _instanceAddress) internal virtual onlyGovernance {
// Grab data needed to remove the instance
ITornadoInstance instance = ITornadoInstance(_instanceAddress);
InstanceData memory data = instanceData[instance];
@ -238,25 +245,43 @@ contract InstanceRegistry is InstanceRegistryLegacyStorage, EnsResolve, Initiali
}
function getAllInstances() public view virtual returns (ITornadoInstance[] memory allInstances) {
uint256 numInstances = instances.length;
return getInstances(0, instances.length - 1);
}
allInstances = new ITornadoInstance[](instances.length);
function getInstances(uint256 _inclusiveStartIndex, uint256 _inclusiveEndIndex)
public
view
virtual
returns (ITornadoInstance[] memory allInstances)
{
allInstances = new ITornadoInstance[](-_inclusiveStartIndex + 1 + _inclusiveEndIndex);
for (uint256 i = 0; i < numInstances; i++) {
for (uint256 i = _inclusiveStartIndex; i < _inclusiveEndIndex + 1; i++) {
allInstances[i] = instances[i];
}
}
function getAllInstanceData() public view virtual returns (InstanceData[] memory data) {
uint256 numInstances = instances.length;
return getInstanceData(0, instances.length - 1);
}
data = new InstanceData[](numInstances);
function getInstanceData(uint256 _inclusiveStartIndex, uint256 _inclusiveEndIndex)
public
view
virtual
returns (InstanceData[] memory data)
{
data = new InstanceData[](-_inclusiveStartIndex + 1 + _inclusiveEndIndex);
for (uint256 i = 0; i < numInstances; i++) {
for (uint256 i = _inclusiveStartIndex; i < _inclusiveEndIndex + 1; i++) {
data[i] = instanceData[instances[i]];
}
}
function getInstanceData(uint256 _index) public view virtual returns (InstanceData memory data) {
return instanceData[instances[_index]];
}
function getInstanceData(ITornadoInstance _instance)
public
view
@ -274,26 +299,16 @@ contract InstanceRegistry is InstanceRegistryLegacyStorage, EnsResolve, Initiali
return instanceData[_instance].index;
}
function isRegisteredInstance(address _instanceAddress) public view virtual returns (bool) {
return isEnabledInstance(_instanceAddress);
function isRegisteredInstance(ITornadoInstance _instance) public view virtual returns (bool) {
return isEnabledInstance(_instance);
}
function isEnabledInstance(address _instanceAddress) public view virtual returns (bool) {
return instanceData[ITornadoInstance(_instanceAddress)].isEnabled;
function isEnabledInstance(ITornadoInstance _instance) public view virtual returns (bool) {
return instanceData[_instance].isEnabled;
}
/* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ENS GETTERS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ */
function getInstanceByENSName(string memory _instanceENSName)
public
view
virtual
returns (ITornadoInstance)
{
(, bytes32 node) = NameEncoder.dnsEncodeName(_instanceENSName);
return ITornadoInstance(resolve(node));
}
function isRegisteredInstanceByENSName(string memory _instanceENSName)
public
view
@ -304,6 +319,16 @@ contract InstanceRegistry is InstanceRegistryLegacyStorage, EnsResolve, Initiali
}
function isEnabledInstanceByENSName(string memory _instanceENSName) public view virtual returns (bool) {
return isEnabledInstance(address(getInstanceByENSName(_instanceENSName)));
return isEnabledInstance(getInstanceByENSName(_instanceENSName));
}
function getInstanceByENSName(string memory _instanceENSName)
public
view
virtual
returns (ITornadoInstance)
{
(, bytes32 node) = NameEncoder.dnsEncodeName(_instanceENSName);
return ITornadoInstance(resolve(node));
}
}

View File

@ -53,13 +53,13 @@ contract TornadoRouter is Initializable {
governanceProxyAddress = _governanceProxyAddress;
}
modifier onlyGovernance() {
require(msg.sender == governanceProxyAddress, "TornadoRouter: onlyGovernance");
modifier onlyInstanceRegistry() {
require(msg.sender == address(instanceRegistry), "TornadoRouter: onlyInstanceRegistry");
_;
}
modifier onlyInstanceRegistry() {
require(msg.sender == address(instanceRegistry), "TornadoRouter: onlyInstanceRegistry");
modifier onlyGovernance() {
require(msg.sender == governanceProxyAddress, "TornadoRouter: onlyGovernance");
_;
}

View File

@ -125,27 +125,30 @@ contract UniswapV3FeeOracle is IFeeOracle {
GlobalOracleConfig memory global = globals;
// Check whether all globals needed are initialized
// Check whether globals are initialized
require(global.tornPoolFee != 0, "UniswapV3FeeOracle: torn pool fee not initialized");
require(global.twapIntervalSeconds != 0, "UniswapV3FeeOracle: time period not initialized");
require(global.minObservationCardinality != 0, "UniswapV3FeeOracle: cardinality not initialized");
// Check whether a pool exists for the token + fee combination
// Only do this if not zeroing out
if (_tokenPoolFee != 0) {
// Check whether a pool exists for the token + fee combination
address poolAddress = UniswapV3OracleHelper.UniswapV3Factory.getPool(
address(_token), UniswapV3OracleHelper.WETH, _tokenPoolFee
);
address poolAddress = UniswapV3OracleHelper.UniswapV3Factory.getPool(
address(_token), UniswapV3OracleHelper.WETH, _tokenPoolFee
);
require(poolAddress != address(0), "UniswapV3FeeOracle: pool for token and fee does not exist");
require(poolAddress != address(0), "UniswapV3FeeOracle: pool for token and fee does not exist");
// Check whether the pool has a large enough observation cardinality
// Check whether the pool has a large enough observation cardinality
(,,,, uint16 observationCardinalityNext,,) = IUniswapV3PoolState(poolAddress).slot0();
(,,,, uint16 observationCardinalityNext,,) = IUniswapV3PoolState(poolAddress).slot0();
require(
global.minObservationCardinality <= observationCardinalityNext,
"UniswapV3FeeOracle: pool observation cardinality low"
);
require(
global.minObservationCardinality <= observationCardinalityNext,
"UniswapV3FeeOracle: pool observation cardinality low"
);
}
// Store & log
@ -154,11 +157,17 @@ contract UniswapV3FeeOracle is IFeeOracle {
emit PoolFeeUpdated(_token, _tokenPoolFee);
}
function setGlobalTornPoolFee(uint24 _newGlobalTornPoolFee) public virtual onlyGovernance {
function setGlobalTornPoolFee(uint24 _newGlobalTornPoolFee, bool _setSpecific)
public
virtual
onlyGovernance
{
globals.tornPoolFee = _newGlobalTornPoolFee;
// For `getPriceRatioOfTokens`
poolFeesByToken[IERC20(0x77777FeDdddFfC19Ff86DB637967013e6C6A116C)] = _newGlobalTornPoolFee;
if (_setSpecific) {
poolFeesByToken[IERC20(0x77777FeDdddFfC19Ff86DB637967013e6C6A116C)] = _newGlobalTornPoolFee;
}
emit GlobalTornPoolFeeUpdated(_newGlobalTornPoolFee);
}
@ -180,4 +189,18 @@ contract UniswapV3FeeOracle is IFeeOracle {
globals.minObservationCardinality = _newGlobalMinObservationCardinality;
emit GlobalMinObservationCardinalityUpdated(_newGlobalMinObservationCardinality);
}
/* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GETTERS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ */
function getGlobalTornPoolFee() public view virtual returns (uint24) {
return globals.tornPoolFee;
}
function getGlobalTwapIntervalSeconds() public view virtual returns (uint32) {
return globals.twapIntervalSeconds;
}
function getGlobalMinObservationCardinality() public view virtual returns (uint16) {
return globals.minObservationCardinality;
}
}