lib: refactor rpc client to be fully async

This commit is contained in:
Oscar Mira 2023-03-08 18:57:33 +01:00
parent 4fc05895cf
commit 518e9234fe
23 changed files with 567 additions and 360 deletions

View File

@ -2,9 +2,9 @@ package im.molly.monero.demo.data
import android.content.Context
import im.molly.monero.*
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import im.molly.monero.loadbalancer.RoundRobinRule
import kotlinx.coroutines.*
import kotlinx.coroutines.flow.Flow
import okhttp3.OkHttpClient
class MoneroSdkClient(
@ -13,32 +13,37 @@ class MoneroSdkClient(
private val httpClient: OkHttpClient,
private val ioDispatcher: CoroutineDispatcher = Dispatchers.IO,
) {
suspend fun createWallet(moneroNetwork: MoneroNetwork): MoneroWallet =
withContext(ioDispatcher) {
val walletClient = WalletClient.forNetwork(
context = context,
network = moneroNetwork,
)
val wallet = walletClient.createNewWallet()
private val providerDeferred = CoroutineScope(ioDispatcher).async {
WalletProvider.connect(context)
}
suspend fun createWallet(moneroNetwork: MoneroNetwork): MoneroWallet {
val provider = providerDeferred.await()
return withContext(ioDispatcher) {
val wallet = provider.createNewWallet(moneroNetwork)
walletDataFileStorage.tryWriteData(wallet.publicAddress, false) { output ->
walletClient.saveWallet(wallet, output)
provider.saveWallet(wallet, output)
}
wallet
}
}
suspend fun openWallet(
publicAddress: String,
remoteNodeSelector: RemoteNodeSelector?,
): MoneroWallet =
withContext(ioDispatcher) {
val walletClient = WalletClient.forNetwork(
context = context,
network = MoneroNetwork.of(publicAddress),
nodeSelector = remoteNodeSelector,
remoteNodes: Flow<List<RemoteNode>>,
): MoneroWallet {
val provider = providerDeferred.await()
return withContext(ioDispatcher) {
val network = MoneroNetwork.of(publicAddress)
val client = RemoteNodeClient.forNetwork(
network = network,
remoteNodes = remoteNodes,
loadBalancerRule = RoundRobinRule(),
httpClient = httpClient,
)
walletDataFileStorage.readData(publicAddress).use { input ->
walletClient.openWallet(input)
provider.openWallet(network, client, input)
}
}
}
}

View File

@ -2,11 +2,7 @@ package im.molly.monero.demo.data
import im.molly.monero.*
import kotlinx.coroutines.*
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.emitAll
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.flow
import java.io.IOException
import kotlinx.coroutines.flow.*
import java.util.concurrent.ConcurrentHashMap
class WalletRepository(
@ -19,22 +15,16 @@ class WalletRepository(
suspend fun getWallet(walletId: Long): MoneroWallet {
return walletIdMap.computeIfAbsent(walletId) {
externalScope.async {
val config = getWalletConfig(walletId).first()
val config = getWalletConfig(walletId)
val wallet = moneroSdkClient.openWallet(
publicAddress = config.publicAddress,
remoteNodeSelector = object : RemoteNodeSelector {
override fun select(): RemoteNode? {
return config.remoteNodes.firstOrNull()?.let {
RemoteNode(
uri = it.uri,
username = it.username,
password = it.password,
)
}
}
override fun connectFailed(remoteNode: RemoteNode, ioe: IOException) {
TODO("Not yet implemented")
publicAddress = config.first().publicAddress,
remoteNodes = config.map {
it.remoteNodes.map { node ->
RemoteNode(
uri = node.uri,
username = node.username,
password = node.password,
)
}
}
)

View File

@ -40,7 +40,7 @@ class SyncService(
val wallet = walletRepository.getWallet(walletId)
lifecycleScope.launch {
while (isActive) {
val result = wallet.awaitRefresh();
val result = wallet.awaitRefresh()
if (result.isError()) {
break;
}

View File

@ -0,0 +1,6 @@
package im.molly.monero;
oneway interface IHttpRequestCallback {
void onResponse(int code, String contentType, in ParcelFileDescriptor body);
void onFailure();
}

View File

@ -1,10 +1,8 @@
package im.molly.monero;
import im.molly.monero.HttpResponse;
import im.molly.monero.RemoteNode;
import im.molly.monero.IHttpRequestCallback;
interface IRemoteNodeClient {
// RemoteNode getRemoteNode();
HttpResponse makeRequest(String method, String path, String header, in byte[] body);
oneway void cancelAll();
oneway void requestAsync(int requestId, String method, String path, String header, in byte[] bodyBytes, in IHttpRequestCallback callback);
oneway void cancelAsync(int requestId);
}

View File

@ -9,8 +9,8 @@ interface IWallet {
void addBalanceListener(in IBalanceListener listener);
void removeBalanceListener(in IBalanceListener listener);
void save(in ParcelFileDescriptor destination);
oneway void restartRefresh(boolean skipCoinbaseOutputs, in IRefreshCallback callback);
oneway void stopRefresh();
oneway void resumeRefresh(boolean skipCoinbaseOutputs, in IRefreshCallback callback);
void cancelRefresh();
void setRefreshSince(long heightOrTimestamp);
void close();
}

View File

@ -57,8 +57,8 @@ bool RemoteNodeClient::invoke(const boost::string_ref uri,
}
try {
ScopedJvmLocalRef<jobject> j_response = {
env, m_remote_node_client.callObjectMethod(
env, IRemoteNodeClient_makeRequest,
env, m_wallet_native.callObjectMethod(
env, WalletNative_callRemoteNode,
nativeToJvmString(env, method.data()).obj(),
nativeToJvmString(env, uri.data()).obj(),
nativeToJvmString(env, header.str()).obj(),
@ -80,7 +80,7 @@ bool RemoteNodeClient::invoke(const boost::string_ref uri,
http_response.body.read(&m_response_info.m_body);
}
} catch (std::runtime_error& e) {
LOGE("Unhandled exception in RemoteNodeClient");
LOGE("Unhandled exception: %s", e.what());
return false;
}
if (ppresponse_info) {

View File

@ -14,8 +14,8 @@ class RemoteNodeClient : public AbstractHttpClient {
public:
RemoteNodeClient(
JNIEnv* env,
const JvmRef<jobject>& remote_node_client)
: m_remote_node_client(env, remote_node_client) {}
const JvmRef<jobject>& wallet_native)
: m_wallet_native(env, wallet_native) {}
bool set_proxy(const std::string& address) override;
void set_server(std::string host,
@ -53,7 +53,7 @@ class RemoteNodeClient : public AbstractHttpClient {
};
private:
const ScopedJvmGlobalRef<jobject> m_remote_node_client;
const ScopedJvmGlobalRef<jobject> m_wallet_native;
epee::net_utils::http::http_response_info m_response_info;
};
@ -63,16 +63,16 @@ class RemoteNodeClientFactory : public HttpClientFactory {
public:
RemoteNodeClientFactory(
JNIEnv* env,
const JvmRef<jobject>& remote_node_client)
: m_remote_node_client(env, remote_node_client) {}
const JvmRef<jobject>& wallet_native)
: m_wallet_native(env, wallet_native) {}
std::unique_ptr<AbstractHttpClient> create() override {
return std::unique_ptr<AbstractHttpClient>(new RemoteNodeClient(getJniEnv(),
m_remote_node_client));
m_wallet_native));
}
private:
const ScopedJvmGlobalRef<jobject> m_remote_node_client;
const ScopedJvmGlobalRef<jobject> m_wallet_native;
};
} // namespace monero

View File

@ -7,11 +7,11 @@ ScopedJvmGlobalRef<jclass> OwnedTxOut;
jmethodID HttpResponse_getBody;
jmethodID HttpResponse_getCode;
jmethodID HttpResponse_getContentType;
jmethodID IRemoteNodeClient_cancelAll;
jmethodID IRemoteNodeClient_makeRequest;
jmethodID Logger_logFromNative;
jmethodID OwnedTxOut_ctor;
jmethodID Wallet_onRefresh;
jmethodID WalletNative_callRemoteNode;
jmethodID WalletNative_onRefresh;
jmethodID WalletNative_onSuspendRefresh;
// android.os
jmethodID ParcelFileDescriptor_detachFd;
@ -19,10 +19,9 @@ jmethodID ParcelFileDescriptor_detachFd;
void initializeJniCache(JNIEnv* env) {
// im.molly.monero
auto httpResponse = findClass(env, "im/molly/monero/HttpResponse");
auto iRemoteNodeClient = findClass(env, "im/molly/monero/IRemoteNodeClient");
auto logger = findClass(env, "im/molly/monero/Logger");
auto ownedTxOut = findClass(env, "im/molly/monero/OwnedTxOut");
auto wallet = findClass(env, "im/molly/monero/WalletNative");
auto walletNative = findClass(env, "im/molly/monero/WalletNative");
HttpResponse_getBody = httpResponse
.getMethodId(env, "getBody", "()Landroid/os/ParcelFileDescriptor;");
@ -30,18 +29,18 @@ void initializeJniCache(JNIEnv* env) {
.getMethodId(env, "getCode", "()I");
HttpResponse_getContentType = httpResponse
.getMethodId(env, "getContentType", "()Ljava/lang/String;");
IRemoteNodeClient_cancelAll = iRemoteNodeClient
.getMethodId(env, "cancelAll", "()V");
IRemoteNodeClient_makeRequest = iRemoteNodeClient
.getMethodId(env,
"makeRequest",
"(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;[B)Lim/molly/monero/HttpResponse;");
Logger_logFromNative = logger
.getMethodId(env, "logFromNative", "(ILjava/lang/String;Ljava/lang/String;)V");
OwnedTxOut_ctor = ownedTxOut
.getMethodId(env, "<init>", "([BJJJ)V");
Wallet_onRefresh = wallet
WalletNative_callRemoteNode = walletNative
.getMethodId(env,
"callRemoteNode",
"(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;[B)Lim/molly/monero/HttpResponse;");
WalletNative_onRefresh = walletNative
.getMethodId(env, "onRefresh", "(JZ)V");
WalletNative_onSuspendRefresh = walletNative
.getMethodId(env, "onSuspendRefresh", "(Z)V");
OwnedTxOut = ownedTxOut;

View File

@ -13,11 +13,11 @@ extern ScopedJvmGlobalRef<jclass> OwnedTxOut;
extern jmethodID HttpResponse_getBody;
extern jmethodID HttpResponse_getCode;
extern jmethodID HttpResponse_getContentType;
extern jmethodID IRemoteNodeClient_cancelAll;
extern jmethodID IRemoteNodeClient_makeRequest;
extern jmethodID Logger_logFromNative;
extern jmethodID OwnedTxOut_ctor;
extern jmethodID Wallet_onRefresh;
extern jmethodID WalletNative_callRemoteNode;
extern jmethodID WalletNative_onRefresh;
extern jmethodID WalletNative_onSuspendRefresh;
// android.os
extern jmethodID ParcelFileDescriptor_detachFd;

View File

@ -24,22 +24,21 @@ static_assert(CRYPTONOTE_MAX_BLOCK_NUMBER == 500000000,
Wallet::Wallet(
JNIEnv* env,
int network_id,
const JvmRef<jobject>& remote_node_client,
const JvmRef<jobject>& callback)
const JvmRef<jobject>& wallet_native)
: m_wallet(static_cast<cryptonote::network_type>(network_id),
0, /* kdf_rounds */
true, /* unattended */
std::make_unique<RemoteNodeClientFactory>(env, remote_node_client)),
m_remote_node_client(env, remote_node_client),
m_callback(env, callback),
std::make_unique<RemoteNodeClientFactory>(env, wallet_native)),
m_callback(env, wallet_native),
m_account_ready(false),
m_blockchain_height(1),
m_restore_height(0),
m_refresh_running(false),
m_refresh_stopped(false) {
m_refresh_canceled(false) {
// Use a bogus ipv6 address as a placeholder for the daemon address.
LOG_FATAL_IF(!m_wallet.init("[100::/64]", {}, {}, 0, false),
"Init failed");
m_wallet.stop();
m_wallet.callback(this);
}
@ -95,7 +94,7 @@ bool Wallet::parseFrom(std::istream& input) {
}
bool Wallet::writeTo(std::ostream& output) {
return pauseRefreshAndRunLocked([&]() -> bool {
return suspendRefreshAndRunLocked([&]() -> bool {
binary_archive<true> ar(output);
if (!serialization::serialize_noeof(ar, *this))
return false;
@ -132,87 +131,85 @@ void Wallet::handleBalanceChanged(uint64_t at_block_height) {
m_tx_outs_mutex.unlock();
m_blockchain_height = at_block_height;
JNIEnv* env = getJniEnv();
m_callback.callVoidMethod(env, Wallet_onRefresh, at_block_height, true);
m_callback.callVoidMethod(env, WalletNative_onRefresh, at_block_height, true);
}
void Wallet::handleNewBlock(uint64_t height) {
void Wallet::handleNewBlock(uint64_t height, bool debounce) {
m_blockchain_height = height;
// Notify blockchain height every one second.
static std::chrono::steady_clock::time_point last_time;
auto now = std::chrono::steady_clock::now();
if (now - last_time >= 1.s) {
// Notify blockchain height every one second.
if (!debounce || (now - last_time >= 1.s)) {
last_time = now;
m_callback.callVoidMethod(getJniEnv(), Wallet_onRefresh, height, false);
m_callback.callVoidMethod(getJniEnv(), WalletNative_onRefresh, height, false);
}
}
Wallet::Status Wallet::refreshLoopUntilSynced(bool skip_coinbase) {
Wallet::Status Wallet::nonReentrantRefresh(bool skip_coinbase) {
LOG_FATAL_IF(m_refresh_running.exchange(true),
"Refresh should not be called concurrently");
Status ret;
std::unique_lock<std::mutex> refresh_lock(m_refresh_mutex);
m_refresh_stopped = false;
for (;;) {
std::lock_guard<std::mutex> wallet_lock(m_wallet_mutex);
if (m_refresh_stopped) {
ret = Status::INTERRUPTED;
break;
}
m_refresh_running = true;
m_wallet.set_refresh_type(skip_coinbase ? tools::wallet2::RefreshType::RefreshNoCoinbase
: tools::wallet2::RefreshType::RefreshDefault);
std::unique_lock<std::mutex> wallet_lock(m_wallet_mutex);
m_wallet.set_refresh_type(skip_coinbase ? tools::wallet2::RefreshType::RefreshNoCoinbase
: tools::wallet2::RefreshType::RefreshDefault);
while (!m_refresh_canceled) {
m_wallet.set_refresh_from_block_height(m_restore_height);
try {
// Calling refresh() will block until stop() is called or it sync successfully.
// refresh() will block until stop() is called or it syncs successfully.
m_wallet.refresh(false);
if (!m_wallet.stopped()) {
m_wallet.stop();
ret = Status::OK;
break;
}
} catch (const tools::error::no_connection_to_daemon&) {
m_refresh_running = false;
ret = Status::NO_NETWORK_CONNECTIVITY;
break;
} catch (const tools::error::refresh_error) {
m_refresh_running = false;
} catch (const tools::error::refresh_error&) {
ret = Status::REFRESH_ERROR;
break;
}
m_refresh_running = false;
if (!m_wallet.stopped()) {
ret = Status::OK;
break;
}
m_refresh_cond.wait(refresh_lock);
m_refresh_cond.wait(wallet_lock);
}
refresh_lock.unlock();
if (m_refresh_canceled) {
m_refresh_canceled = false;
ret = Status::INTERRUPTED;
}
m_refresh_running.store(false);
// Always notify the last block height.
m_callback.callVoidMethod(getJniEnv(), Wallet_onRefresh, m_blockchain_height, false);
handleNewBlock(m_blockchain_height, false);
return ret;
}
template<typename T>
auto Wallet::pauseRefreshAndRunLocked(T block) -> decltype(block()) {
std::unique_lock<std::mutex> refresh_lock(m_refresh_mutex, std::try_to_lock);
if (!refresh_lock.owns_lock()) {
auto Wallet::suspendRefreshAndRunLocked(T block) -> decltype(block()) {
std::unique_lock<std::mutex> wallet_lock(m_wallet_mutex, std::try_to_lock);
if (!wallet_lock.owns_lock()) {
JNIEnv* env = getJniEnv();
do {
if (refresh_is_running()) {
for (;;) {
if (!m_wallet.stopped()) {
m_wallet.stop();
m_remote_node_client.callVoidMethod(env, IRemoteNodeClient_cancelAll);
m_callback.callVoidMethod(env, WalletNative_onSuspendRefresh, true);
}
if (wallet_lock.try_lock()) {
break;
}
std::this_thread::yield();
} while (!refresh_lock.try_lock());
}
m_callback.callVoidMethod(env, WalletNative_onSuspendRefresh, false);
m_refresh_cond.notify_one();
}
LOG_FATAL_IF(refresh_is_running());
std::lock_guard<std::mutex> wallet_lock(m_wallet_mutex);
m_refresh_mutex.unlock();
m_refresh_cond.notify_one();
return block();
}
void Wallet::stopRefresh() {
pauseRefreshAndRunLocked([&]() {
m_refresh_stopped = true;
void Wallet::cancelRefresh() {
suspendRefreshAndRunLocked([&]() {
m_refresh_canceled = true;
});
}
void Wallet::setRefreshSince(long height_or_timestamp) {
pauseRefreshAndRunLocked([&]() {
suspendRefreshAndRunLocked([&]() {
if (height_or_timestamp < CRYPTONOTE_MAX_BLOCK_NUMBER) {
m_restore_height = height_or_timestamp;
} else {
@ -226,11 +223,8 @@ JNIEXPORT jlong JNICALL
Java_im_molly_monero_WalletNative_nativeCreate(
JNIEnv* env,
jobject thiz,
jint network_id,
jobject p_remote_node_client) {
auto wallet = new Wallet(env, network_id,
JvmParamRef<jobject>(p_remote_node_client),
JvmParamRef<jobject>(thiz));
jint network_id) {
auto wallet = new Wallet(env, network_id, JvmParamRef<jobject>(thiz));
return nativeToJvmPointer(wallet);
}
@ -284,33 +278,23 @@ Java_im_molly_monero_WalletNative_nativeSave(
extern "C"
JNIEXPORT jint JNICALL
Java_im_molly_monero_WalletNative_nativeRefreshLoopUntilSynced
Java_im_molly_monero_WalletNative_nativeNonReentrantRefresh
(JNIEnv* env,
jobject thiz,
jlong handle,
jboolean skip_coinbase) {
auto* wallet = reinterpret_cast<Wallet*>(handle);
return wallet->refreshLoopUntilSynced(skip_coinbase);
return wallet->nonReentrantRefresh(skip_coinbase);
}
extern "C"
JNIEXPORT void JNICALL
Java_im_molly_monero_WalletNative_nativeStopRefresh(
Java_im_molly_monero_WalletNative_nativeCancelRefresh(
JNIEnv* env,
jobject thiz,
jlong handle) {
auto* wallet = reinterpret_cast<Wallet*>(handle);
wallet->stopRefresh();
}
extern "C"
JNIEXPORT jboolean JNICALL
Java_im_molly_monero_WalletNative_nativeRefreshIsRunning(
JNIEnv* env,
jobject thiz,
jlong handle) {
auto* wallet = reinterpret_cast<Wallet*>(handle);
return wallet->refresh_is_running();
wallet->cancelRefresh();
}
extern "C"

View File

@ -24,8 +24,7 @@ class Wallet : tools::i_wallet2_callback {
Wallet(JNIEnv* env,
int network_id,
const JvmRef<jobject>& remote_node_client,
const JvmRef<jobject>& callback);
const JvmRef<jobject>& wallet_native);
void restoreAccount(const std::vector<char>& secret_scalar, uint64_t account_timestamp);
uint64_t estimateRestoreHeight(uint64_t timestamp);
@ -33,8 +32,8 @@ class Wallet : tools::i_wallet2_callback {
bool parseFrom(std::istream& input);
bool writeTo(std::ostream& output);
Wallet::Status refreshLoopUntilSynced(bool skip_coinbase);
void stopRefresh();
Wallet::Status nonReentrantRefresh(bool skip_coinbase);
void cancelRefresh();
void setRefreshSince(long height_or_timestamp);
template<typename Callback>
@ -44,9 +43,7 @@ class Wallet : tools::i_wallet2_callback {
uint64_t current_blockchain_height() const { return m_blockchain_height; }
bool refresh_is_running() const { return m_refresh_running; }
// Extra object's state that need to be persistent.
// Extra state that must be persistent and isn't restored by wallet2's serializer.
BEGIN_SERIALIZE_OBJECT()
VERSION_FIELD(0)
FIELD(m_restore_height)
@ -67,19 +64,18 @@ class Wallet : tools::i_wallet2_callback {
std::mutex m_tx_outs_mutex;
std::mutex m_refresh_mutex;
// Reference to Kotlin instances.
const ScopedJvmGlobalRef<jobject> m_remote_node_client;
// Reference to Kotlin wallet instance.
const ScopedJvmGlobalRef<jobject> m_callback;
std::condition_variable m_refresh_cond;
bool m_refresh_running;
bool m_refresh_stopped;
std::atomic<bool> m_refresh_running;
bool m_refresh_canceled;
template<typename T>
auto pauseRefreshAndRunLocked(T block) -> decltype(block());
auto suspendRefreshAndRunLocked(T block) -> decltype(block());
void handleBalanceChanged(uint64_t at_block_height);
void handleNewBlock(uint64_t height);
void handleNewBlock(uint64_t height, bool debounce = true);
// Implementation of i_wallet2_callback follows.
private:

View File

@ -0,0 +1,33 @@
package im.molly.monero
data class HttpRequest(
val method: String?,
val path: String?,
val header: String?,
val bodyBytes: ByteArray?,
) {
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (javaClass != other?.javaClass) return false
other as HttpRequest
if (method != other.method) return false
if (path != other.path) return false
if (header != other.header) return false
if (bodyBytes != null) {
if (other.bodyBytes == null) return false
if (!bodyBytes.contentEquals(other.bodyBytes)) return false
} else if (other.bodyBytes != null) return false
return true
}
override fun hashCode(): Int {
var result = method?.hashCode() ?: 0
result = 31 * result + (path?.hashCode() ?: 0)
result = 31 * result + (header?.hashCode() ?: 0)
result = 31 * result + (bodyBytes?.contentHashCode() ?: 0)
return result
}
}

View File

@ -1,14 +1,13 @@
package im.molly.monero
import android.os.ParcelFileDescriptor
import android.os.Parcelable
import kotlinx.parcelize.Parcelize
@Parcelize
data class HttpResponse
@CalledByNative("http_client.cc")
constructor(
data class HttpResponse(
val code: Int,
val contentType: String? = null,
val body: ParcelFileDescriptor? = null,
) : Parcelable
) : AutoCloseable {
override fun close() {
body?.close()
}
}

View File

@ -7,7 +7,10 @@ import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.callbackFlow
import kotlinx.coroutines.suspendCancellableCoroutine
class MoneroWallet(private val wallet: IWallet) : IWallet by wallet, AutoCloseable {
class MoneroWallet internal constructor(
private val wallet: IWallet,
client: RemoteNodeClient?,
) : IWallet by wallet, AutoCloseable {
val publicAddress: String = wallet.primaryAccountAddress
@ -44,9 +47,9 @@ class MoneroWallet(private val wallet: IWallet) : IWallet by wallet, AutoCloseab
}
}
restartRefresh(skipCoinbaseOutputs, callback)
resumeRefresh(skipCoinbaseOutputs, callback)
continuation.invokeOnCancellation { stopRefresh() }
continuation.invokeOnCancellation { cancelRefresh() }
}
}

View File

@ -3,7 +3,6 @@ package im.molly.monero
import android.net.Uri
import android.os.Parcelable
import kotlinx.parcelize.Parcelize
import java.io.IOException
@Parcelize
data class RemoteNode(
@ -14,18 +13,3 @@ data class RemoteNode(
fun uriForPath(path: String): Uri =
uri.buildUpon().appendPath(path.trimStart('/')).build()
}
interface RemoteNodeSelector {
/**
* Selects an appropriate remote node to access the Monero network.
*/
fun select(): RemoteNode?
/**
* Called to indicate that a connection could not be established to a remote node.
*
* An implementation of this method can temporarily remove the node or reorder the sequence
* of nodes returned by [select].
*/
fun connectFailed(remoteNode: RemoteNode, ioe: IOException)
}

View File

@ -2,141 +2,236 @@ package im.molly.monero
import android.net.Uri
import android.os.ParcelFileDescriptor
import im.molly.monero.loadbalancer.LoadBalancer
import im.molly.monero.loadbalancer.Rule
import kotlinx.coroutines.*
import kotlinx.coroutines.flow.Flow
import okhttp3.*
import java.io.FileOutputStream
import java.io.IOException
import java.util.concurrent.ConcurrentHashMap
import kotlin.coroutines.resume
import kotlin.coroutines.resumeWithException
import kotlin.coroutines.suspendCoroutine
internal class RemoteNodeClient(
private val nodeSelector: RemoteNodeSelector,
// TODO: Hide IRemoteNodeClient methods
class RemoteNodeClient private constructor(
val network: MoneroNetwork,
private val loadBalancer: LoadBalancer,
private val loadBalancerRule: Rule,
private val httpClient: OkHttpClient,
ioDispatcher: CoroutineDispatcher,
) : IRemoteNodeClient.Stub() {
private val retryBackoff: BackoffPolicy,
private val requestsScope: CoroutineScope,
) : IRemoteNodeClient.Stub(), AutoCloseable {
companion object {
/**
* Constructs a [RemoteNodeClient] to connect to the Monero [network].
*/
fun forNetwork(
network: MoneroNetwork,
remoteNodes: Flow<List<RemoteNode>>,
loadBalancerRule: Rule,
httpClient: OkHttpClient,
retryBackoff: BackoffPolicy = ExponentialBackoff.Default,
ioDispatcher: CoroutineDispatcher = Dispatchers.IO,
): RemoteNodeClient {
val scope = CoroutineScope(ioDispatcher + SupervisorJob())
return RemoteNodeClient(
network,
LoadBalancer(remoteNodes, scope),
loadBalancerRule,
httpClient,
retryBackoff,
scope
)
}
}
private val logger = loggerFor<RemoteNodeClient>()
private val requestsScope = CoroutineScope(ioDispatcher + SupervisorJob())
private val requestList = ConcurrentHashMap<Int, Job>()
// /** Disable connecting to the Monero network */
// var offline = false
private fun selectedNode() = nodeSelector.select()
@CalledByNative("http_client.cc")
override fun makeRequest(
override fun requestAsync(
requestId: Int,
method: String?,
path: String?,
header: String?,
body: ByteArray?,
): HttpResponse? {
val selected = selectedNode()
if (selected == null) {
logger.w("No remote node selected")
return null
callback: IHttpRequestCallback?,
) {
requireNotNull(path)
requireNotNull(method)
logger.d("HTTP: $method $path, header_len=${header?.length}, body_size=${body?.size}")
val requestJob = requestsScope.launch {
runCatching {
requestWithRetry(method, path, header, body)
}.onSuccess { response ->
val statusCode = response.code()
val responseBody = response.body()
if (responseBody == null) {
callback?.onResponse(statusCode, null, null)
} else {
val contentType = responseBody.contentType()?.toString()
val pipe = ParcelFileDescriptor.createPipe()
callback?.onResponse(statusCode, contentType, pipe[0])
responseBody.use {
pipe[1].use { writeSide ->
FileOutputStream(writeSide.fileDescriptor).use { out ->
runCatching { it.byteStream().copyTo(out) }
}
}
}
}
// TODO: Log response times
}.onFailure { throwable ->
logger.e("HTTP: Request failed", throwable)
callback?.onFailure()
}
}.also {
requestList[requestId] = it
}
val uri = selected.uriForPath(path ?: "")
return try {
execute(method, uri, header, body, selected.username, selected.password)
} catch (ioe: IOException) {
logger.e("HTTP: Request failed", ioe)
return null
} catch (e: IllegalArgumentException) {
logger.e("HTTP: Bad request", e)
return null
requestJob.invokeOnCompletion {
requestList.remove(requestId)
}
}
@CalledByNative("http_client.cc")
override fun cancelAll() {
requestsScope.coroutineContext.cancelChildren()
override fun cancelAsync(requestId: Int) {
requestList[requestId]?.cancel()
}
private fun execute(
method: String?,
uri: Uri,
override fun close() {
requestsScope.cancel()
}
private suspend fun requestWithRetry(
method: String,
path: String,
header: String?,
body: ByteArray?,
): Response {
val attempts = mutableMapOf<Uri, Int>()
while (true) {
val selected = loadBalancerRule.chooseNode(loadBalancer)
if (selected == null) {
logger.i("No remote node available")
return Response.Builder().code(499).build()
}
val uri = selected.uriForPath(path)
val retryCount = attempts[uri] ?: 0
delay(retryBackoff.waitTime(retryCount))
val response = try {
executeCall(
method = method,
uri = uri,
username = selected.username,
password = selected.password,
header = header,
body = body,
)
} catch (e: IOException) {
logger.e("HTTP: Request failed", e)
// TODO: Notify loadBalancer
continue
} finally {
attempts[uri] = retryCount + 1
}
if (response.isSuccessful) {
// TODO: Notify loadBalancer
return response
}
}
}
private suspend fun executeCall(
method: String?,
uri: Uri,
username: String?,
password: String?,
): HttpResponse {
logger.d("HTTP: $method $uri, header_len=${header?.length}, body_size=${body?.size}")
val headers = header?.parseHttpHeader()
val contentType = headers?.get("Content-Type")?.let { value ->
header: String?,
body: ByteArray?,
): Response {
val headers = parseHttpHeader(header)
val contentType = headers.get("Content-Type")?.let { value ->
MediaType.get(value)
}
// TODO: Log unsupported headers
val request = with(Request.Builder()) {
when (method) {
"GET" -> {}
"POST" -> post(RequestBody.create(contentType, body ?: ByteArray(0)))
else -> {
throw IllegalArgumentException("Unsupported method")
when {
method.equals("GET", ignoreCase = true) -> {}
method.equals("POST", ignoreCase = true) -> {
val content = body ?: ByteArray(0)
post(RequestBody.create(contentType, content))
}
else -> throw IllegalArgumentException("Unsupported method")
}
// TODO: Add authentication
url(uri.toString())
build()
}
val response = runBlocking(requestsScope.coroutineContext) {
val call = httpClient.newCall(request)
try {
call.await()
} catch (ioe: IOException) {
if (!call.isCanceled) {
throw ioe
}
null
}
}
return if (response == null) {
HttpResponse(code = 499)
} else if (response.isSuccessful) {
val responseBody = requireNotNull(response.body())
val pipe = ParcelFileDescriptor.createPipe()
requestsScope.launch {
pipe[1].use { writeSide ->
FileOutputStream(writeSide.fileDescriptor).use { outputStream ->
responseBody.byteStream().copyTo(outputStream)
}
}
}
HttpResponse(
code = response.code(),
contentType = responseBody.contentType()?.toString(),
body = pipe[0],
)
} else {
HttpResponse(code = response.code())
}
return httpClient.newCall(request).await()
}
@OptIn(ExperimentalCoroutinesApi::class)
private suspend fun Call.await() = suspendCancellableCoroutine { continuation ->
enqueue(object : Callback {
override fun onResponse(call: Call, response: Response) {
continuation.resume(response) {
response.close()
}
}
override fun onFailure(call: Call, e: IOException) {
continuation.resumeWithException(e)
}
})
continuation.invokeOnCancellation { cancel() }
}
fun String.parseHttpHeader(): Headers =
private fun parseHttpHeader(header: String?): Headers =
with(Headers.Builder()) {
splitToSequence("\r\n")
.filter { line -> line.isNotEmpty() }
.forEach { line ->
add(line)
}
header?.splitToSequence("\r\n")
?.filter { line -> line.isNotEmpty() }
?.forEach { line -> add(line) }
build()
}
private suspend fun Call.await() =
suspendCoroutine { continuation ->
enqueue(object : Callback {
override fun onResponse(call: Call, response: Response) {
continuation.resume(response)
}
override fun onFailure(call: Call, e: IOException) {
continuation.resumeWithException(e)
}
})
}
// private val Response.roundTripMillis: Long
// get() = sentRequestAtMillis() - receivedResponseAtMillis()
}
@OptIn(ExperimentalCoroutinesApi::class)
internal suspend fun IRemoteNodeClient.request(request: HttpRequest): HttpResponse? =
suspendCancellableCoroutine { continuation ->
val requestId = request.hashCode()
val callback = object : IHttpRequestCallback.Stub() {
override fun onResponse(
code: Int,
contentType: String?,
body: ParcelFileDescriptor?,
) {
continuation.resume(HttpResponse(code, contentType, body)) {
body?.close()
}
}
override fun onFailure() {
continuation.resume(null) {}
}
}
with(request) {
requestAsync(requestId, method, path, header, bodyBytes, callback)
}
continuation.invokeOnCancellation {
cancelAsync(requestId)
}
}

View File

@ -1,9 +0,0 @@
package im.molly.monero
enum class RemoteNodeState {
Online,
Connecting,
Unauthorized,
Disconnected,
Offline;
}

View File

@ -3,7 +3,6 @@ package im.molly.monero
import kotlin.math.pow
import kotlin.random.Random
import kotlin.time.Duration
import kotlin.time.Duration.Companion.minutes
import kotlin.time.Duration.Companion.seconds
import kotlin.time.DurationUnit
import kotlin.time.toDuration
@ -22,10 +21,10 @@ interface BackoffPolicy {
* @param maxBackoff Set a hard maximum [Duration] for exponential backoff.
*/
class ExponentialBackoff(
private val minBackoff: Duration = 1.seconds,
private val maxBackoff: Duration = 20.seconds,
private val multiplier: Double = 1.6,
private val jitter: Double = 0.2,
private val minBackoff: Duration,
private val maxBackoff: Duration,
private val multiplier: Double,
private val jitter: Double,
) : BackoffPolicy {
init {
require(minBackoff.isPositive())
@ -46,4 +45,13 @@ class ExponentialBackoff(
val jitter = Random.nextDouble(-jitterAmount, jitterAmount)
return waitTime + jitter.toDuration(DurationUnit.MILLISECONDS)
}
companion object {
val Default = ExponentialBackoff(
minBackoff = 1.seconds,
maxBackoff = 20.seconds,
multiplier = 1.6,
jitter = 0.2,
)
}
}

View File

@ -4,18 +4,20 @@ import android.os.ParcelFileDescriptor
import androidx.annotation.GuardedBy
import kotlinx.coroutines.*
import java.io.Closeable
import java.util.*
import java.util.concurrent.locks.ReentrantLock
import kotlin.concurrent.withLock
import kotlin.coroutines.CoroutineContext
class WalletNative private constructor(
networkId: Int,
remoteNodeClient: IRemoteNodeClient?,
private val remoteNodeClient: IRemoteNodeClient?,
private val scope: CoroutineScope,
private val ioDispatcher: CoroutineDispatcher,
) : IWallet.Stub(), Closeable {
companion object {
// TODO: Full node wallet != local synchronization wallet
fun fullNode(
networkId: Int,
secretSpendKey: SecretKey? = null,
@ -53,8 +55,7 @@ class WalletNative private constructor(
MoneroJni.loadLibrary(logger = logger)
}
private val handle: Long =
nativeCreate(networkId, remoteNodeClient ?: IRemoteNodeClient.Default())
private val handle: Long = nativeCreate(networkId)
override fun getPrimaryAccountAddress() = nativeGetPrimaryAccountAddress(handle)
@ -71,26 +72,28 @@ class WalletNative private constructor(
private val balanceListenersLock = ReentrantLock()
private val refreshDispatcher = ioDispatcher.limitedParallelism(1)
@OptIn(ExperimentalCoroutinesApi::class)
private val singleThreadedDispatcher = ioDispatcher.limitedParallelism(1)
override fun restartRefresh(
@OptIn(ExperimentalCoroutinesApi::class)
override fun resumeRefresh(
skipCoinbaseOutputs: Boolean,
callback: IRefreshCallback?,
) {
if (nativeRefreshIsRunning(handle)) {
nativeStopRefresh(handle)
}
val refreshJob = scope.launch(refreshDispatcher) {
val status = nativeRefreshLoopUntilSynced(handle, skipCoinbaseOutputs)
scope.launch {
val status = suspendCancellableCoroutine { continuation ->
launch(singleThreadedDispatcher) {
continuation.resume(nativeNonReentrantRefresh(handle, skipCoinbaseOutputs)) {}
}
continuation.invokeOnCancellation {
nativeCancelRefresh(handle)
}
}
callback?.onResult(currentBlockchainHeight, status)
}
// Spin until the refresh thread enters in a cancellable state
while (refreshJob.isActive && !nativeRefreshIsRunning(handle)) {
Thread.yield()
}
}
override fun stopRefresh() = nativeStopRefresh(handle)
override fun cancelRefresh() = nativeCancelRefresh(handle)
override fun setRefreshSince(blockHeightOrTimestamp: Long) {
nativeSetRefreshSince(handle, blockHeightOrTimestamp)
@ -136,6 +139,48 @@ class WalletNative private constructor(
}
}
@CalledByNative("wallet.cc")
private fun onSuspendRefresh(suspending: Boolean) {
if (suspending) {
pendingRequestLock.withLock {
pendingRequest?.cancel()
requestsAllowed = false
}
} else {
requestsAllowed = true
}
}
private var requestsAllowed = true
@GuardedBy("pendingRequestLock")
private var pendingRequest: Deferred<HttpResponse?>? = null
private val pendingRequestLock = ReentrantLock()
@CalledByNative("wallet.cc")
private fun callRemoteNode(
method: String?,
path: String?,
header: String?,
body: ByteArray?,
): HttpResponse? = runBlocking {
pendingRequestLock.withLock {
pendingRequest = if (requestsAllowed) {
async {
remoteNodeClient?.request(HttpRequest(method, path, header, body))
}
} else null
}
runCatching {
pendingRequest?.await()
}.onFailure { throwable ->
if (throwable !is CancellationException) {
throw throwable
}
}.getOrNull()
}
override fun close() {
scope.cancel()
}
@ -152,21 +197,18 @@ class WalletNative private constructor(
const val REFRESH_ERROR: Int = 3
}
private external fun nativeCreate(networkId: Int, remoteNodeClient: IRemoteNodeClient): Long
private external fun nativeCancelRefresh(handle: Long)
private external fun nativeCreate(networkId: Int): Long
private external fun nativeDispose(handle: Long)
private external fun nativeGetCurrentBlockchainHeight(handle: Long): Long
private external fun nativeGetOwnedTxOuts(handle: Long): Array<OwnedTxOut>
private external fun nativeGetPrimaryAccountAddress(handle: Long): String
private external fun nativeLoad(handle: Long, fd: Int): Boolean
private external fun nativeRefreshIsRunning(handle: Long): Boolean
private external fun nativeRefreshLoopUntilSynced(handle: Long, skipCoinbase: Boolean): Int
private external fun nativeNonReentrantRefresh(handle: Long, skipCoinbase: Boolean): Int
private external fun nativeRestoreAccount(
handle: Long,
secretScalar: ByteArray,
accountTimestamp: Long
handle: Long, secretScalar: ByteArray, accountTimestamp: Long
)
private external fun nativeSave(handle: Long, fd: Int): Boolean
private external fun nativeSetRefreshSince(handle: Long, heightOrTimestamp: Long)
private external fun nativeStopRefresh(handle: Long)
}

View File

@ -7,46 +7,22 @@ import android.content.ServiceConnection
import android.os.IBinder
import android.os.ParcelFileDescriptor
import kotlinx.coroutines.*
import okhttp3.OkHttpClient
import java.io.FileInputStream
import java.io.FileOutputStream
import java.time.Instant
import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.coroutineContext
// TODO: Rename to SandboxedWalletClient and extract interface, add InProcessWalletClient
class WalletClient private constructor(
// TODO: Rename to SandboxedWalletProvider and extract interface, add InProcessWalletProvider
class WalletProvider private constructor(
private val context: Context,
private val service: IWalletService,
private val serviceConnection: ServiceConnection,
private val network: MoneroNetwork,
private val remoteNodeClient: RemoteNodeClient?,
// TODO: Remove DataStore dependencies if unused
// private val dataStore: DataStore<WalletProto.State>,
) : AutoCloseable {
private val logger = loggerFor<WalletClient>()
) {
companion object {
/**
* Constructs a [WalletClient] to connect to the Monero network [network].
*
* @param context Calling application's [Context].
* @throws [ServiceNotBoundException] if the wallet service can not be bound.
*/
suspend fun forNetwork(
context: Context,
network: MoneroNetwork,
nodeSelector: RemoteNodeSelector? = null,
httpClient: OkHttpClient? = null,
ioDispatcher: CoroutineDispatcher = Dispatchers.IO,
): WalletClient {
val remoteNodeClient = nodeSelector?.let {
requireNotNull(httpClient)
RemoteNodeClient(it, httpClient, ioDispatcher)
}
suspend fun connect(context: Context): WalletProvider {
val (serviceConnection, service) = bindService(context)
return WalletClient(context, service, serviceConnection, network, remoteNodeClient)
return WalletProvider(context, service, serviceConnection)
}
private suspend fun bindService(context: Context): Pair<ServiceConnection, IWalletService> {
@ -71,25 +47,46 @@ class WalletClient private constructor(
}
}
/** Exception thrown by [WalletClient] if the remote service can't be bound. */
/** Exception thrown by [WalletProvider] if the remote service can't be bound. */
class ServiceNotBoundException : Exception()
fun createNewWallet(): MoneroWallet =
MoneroWallet(service.createWallet(buildConfig(), remoteNodeClient))
private val logger = loggerFor<WalletProvider>()
fun restoreWallet(secretSpendKey: SecretKey, accountCreationTime: Instant): MoneroWallet =
MoneroWallet(
fun createNewWallet(
network: MoneroNetwork,
client: RemoteNodeClient? = null,
): MoneroWallet {
require(client == null || client.network == network)
return MoneroWallet(
service.createWallet(buildConfig(network), client), client
)
}
fun restoreWallet(
network: MoneroNetwork,
client: RemoteNodeClient? = null,
secretSpendKey: SecretKey,
accountCreationTime: Instant,
): MoneroWallet {
require(client == null || client.network == network)
return MoneroWallet(
service.restoreWallet(
buildConfig(),
remoteNodeClient,
buildConfig(network),
client,
secretSpendKey,
accountCreationTime.epochSecond,
)
),
client,
)
}
fun openWallet(source: FileInputStream): MoneroWallet =
ParcelFileDescriptor.dup(source.fd).use {
MoneroWallet(service.openWallet(buildConfig(), remoteNodeClient, it))
fun openWallet(
network: MoneroNetwork,
client: RemoteNodeClient? = null,
source: FileInputStream,
): MoneroWallet =
ParcelFileDescriptor.dup(source.fd).use { fd ->
MoneroWallet(service.openWallet(buildConfig(network), client, fd), client)
}
fun saveWallet(wallet: MoneroWallet, destination: FileOutputStream) {
@ -98,11 +95,9 @@ class WalletClient private constructor(
}
}
private fun buildConfig() = WalletConfig(network.id)
private fun buildConfig(network: MoneroNetwork) = WalletConfig(network.id)
// private fun <R> callRemoteService(task: () -> R): R = task()
override fun close() {
fun disconnect() {
context.unbindService(serviceConnection)
}
}

View File

@ -0,0 +1,53 @@
package im.molly.monero.loadbalancer
import im.molly.monero.RemoteNode
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.collect
import kotlinx.coroutines.launch
import kotlin.time.Duration
class LoadBalancer(
remoteNodes: Flow<List<RemoteNode>>,
scope: CoroutineScope,
) {
var onlineNodes: List<RemoteNode> = emptyList()
init {
scope.launch {
remoteNodes.collect {
updateNodes(it)
}
}
}
private fun updateNodes(nodeList: List<RemoteNode>) {
onlineNodes = nodeList
}
fun onResponseTimeObservation(remoteNode: RemoteNode, responseTime: Duration) {
// TODO
}
}
sealed interface RemoteNodeState {
/**
* The remote node is currently online and able to handle requests.
*/
data class Online(val responseTime: Duration) : RemoteNodeState
/**
* The client's request has timed out and no response has been received.
*/
data class Timeout(val cause: Throwable?)
/**
* Indicates that an error occurred while processing the client's request to the remote node.
*/
// open data class Error(val message: String?) : RemoteNodeState {
/**
* Indicates that the client is unauthorized to access the remote node, i.e. the client's credentials were invalid.
*/
// data class Unauthorized(override val message: String?) : Error
}

View File

@ -0,0 +1,26 @@
package im.molly.monero.loadbalancer
import im.molly.monero.RemoteNode
import java.util.concurrent.atomic.AtomicInteger
interface Rule {
/**
* Returns one alive [RemoteNode] from the [loadBalancer] using its internal
* node selection rule, or null if none are available.
*/
fun chooseNode(loadBalancer: LoadBalancer): RemoteNode?
}
class RoundRobinRule : Rule {
private var currentIndex = AtomicInteger(0)
override fun chooseNode(loadBalancer: LoadBalancer): RemoteNode? {
val nodes = loadBalancer.onlineNodes
return if (nodes.isNotEmpty()) {
val index = currentIndex.getAndIncrement() % nodes.size
nodes[index]
} else {
null
}
}
}