lib: add method to cancel in-flight requests

This commit is contained in:
Oscar Mira 2023-02-26 13:20:38 +01:00
parent 47efd85fd6
commit 56ecb91657
7 changed files with 65 additions and 29 deletions

View File

@ -6,4 +6,5 @@ import im.molly.monero.RemoteNode;
interface IRemoteNodeClient { interface IRemoteNodeClient {
// RemoteNode getRemoteNode(); // RemoteNode getRemoteNode();
HttpResponse makeRequest(String method, String path, String header, in byte[] body); HttpResponse makeRequest(String method, String path, String header, in byte[] body);
oneway void cancelAll();
} }

View File

@ -7,6 +7,7 @@ ScopedJvmGlobalRef<jclass> OwnedTxOut;
jmethodID HttpResponse_getBody; jmethodID HttpResponse_getBody;
jmethodID HttpResponse_getCode; jmethodID HttpResponse_getCode;
jmethodID HttpResponse_getContentType; jmethodID HttpResponse_getContentType;
jmethodID IRemoteNodeClient_cancelAll;
jmethodID IRemoteNodeClient_makeRequest; jmethodID IRemoteNodeClient_makeRequest;
jmethodID Logger_logFromNative; jmethodID Logger_logFromNative;
jmethodID OwnedTxOut_ctor; jmethodID OwnedTxOut_ctor;
@ -29,6 +30,8 @@ void initializeJniCache(JNIEnv* env) {
.getMethodId(env, "getCode", "()I"); .getMethodId(env, "getCode", "()I");
HttpResponse_getContentType = httpResponse HttpResponse_getContentType = httpResponse
.getMethodId(env, "getContentType", "()Ljava/lang/String;"); .getMethodId(env, "getContentType", "()Ljava/lang/String;");
IRemoteNodeClient_cancelAll = iRemoteNodeClient
.getMethodId(env, "cancelAll", "()V");
IRemoteNodeClient_makeRequest = iRemoteNodeClient IRemoteNodeClient_makeRequest = iRemoteNodeClient
.getMethodId(env, .getMethodId(env,
"makeRequest", "makeRequest",

View File

@ -13,6 +13,7 @@ extern ScopedJvmGlobalRef<jclass> OwnedTxOut;
extern jmethodID HttpResponse_getBody; extern jmethodID HttpResponse_getBody;
extern jmethodID HttpResponse_getCode; extern jmethodID HttpResponse_getCode;
extern jmethodID HttpResponse_getContentType; extern jmethodID HttpResponse_getContentType;
extern jmethodID IRemoteNodeClient_cancelAll;
extern jmethodID IRemoteNodeClient_makeRequest; extern jmethodID IRemoteNodeClient_makeRequest;
extern jmethodID Logger_logFromNative; extern jmethodID Logger_logFromNative;
extern jmethodID OwnedTxOut_ctor; extern jmethodID OwnedTxOut_ctor;

View File

@ -24,12 +24,13 @@ static_assert(CRYPTONOTE_MAX_BLOCK_NUMBER == 500000000,
Wallet::Wallet( Wallet::Wallet(
JNIEnv* env, JNIEnv* env,
int network_id, int network_id,
std::unique_ptr<HttpClientFactory> http_client_factory, const JvmRef<jobject>& remote_node_client,
const JvmRef<jobject>& callback) const JvmRef<jobject>& callback)
: m_wallet(static_cast<cryptonote::network_type>(network_id), : m_wallet(static_cast<cryptonote::network_type>(network_id),
0, /* kdf_rounds */ 0, /* kdf_rounds */
true, /* unattended */ true, /* unattended */
std::move(http_client_factory)), std::make_unique<RemoteNodeClientFactory>(env, remote_node_client)),
m_remote_node_client(env, remote_node_client),
m_callback(env, callback), m_callback(env, callback),
m_account_ready(false), m_account_ready(false),
m_blockchain_height(1), m_blockchain_height(1),
@ -212,10 +213,8 @@ Java_im_molly_monero_WalletNative_nativeCreate(
jobject thiz, jobject thiz,
jint network_id, jint network_id,
jobject p_remote_node_client) { jobject p_remote_node_client) {
auto wallet = new Wallet( auto wallet = new Wallet(env, network_id,
env, network_id, JvmParamRef<jobject>(p_remote_node_client),
std::make_unique<RemoteNodeClientFactory>(
env, JvmParamRef<jobject>(p_remote_node_client)),
JvmParamRef<jobject>(thiz)); JvmParamRef<jobject>(thiz));
return nativeToJvmPointer(wallet); return nativeToJvmPointer(wallet);
} }

View File

@ -24,7 +24,7 @@ class Wallet : tools::i_wallet2_callback {
Wallet(JNIEnv* env, Wallet(JNIEnv* env,
int network_id, int network_id,
std::unique_ptr<HttpClientFactory> http_client_factory, const JvmRef<jobject>& remote_node_client,
const JvmRef<jobject>& callback); const JvmRef<jobject>& callback);
void restoreAccount(const std::vector<char>& secret_scalar, uint64_t account_timestamp); void restoreAccount(const std::vector<char>& secret_scalar, uint64_t account_timestamp);

View File

@ -2,23 +2,23 @@ package im.molly.monero
import android.net.Uri import android.net.Uri
import android.os.ParcelFileDescriptor import android.os.ParcelFileDescriptor
import kotlinx.coroutines.CoroutineDispatcher import kotlinx.coroutines.*
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.launch
import okhttp3.* import okhttp3.*
import java.io.FileOutputStream import java.io.FileOutputStream
import java.io.IOException import java.io.IOException
import kotlin.coroutines.resumeWithException
internal class RemoteNodeClient( internal class RemoteNodeClient(
private val nodeSelector: RemoteNodeSelector, private val nodeSelector: RemoteNodeSelector,
private val httpClient: OkHttpClient, private val httpClient: OkHttpClient,
private val scope: CoroutineScope, ioDispatcher: CoroutineDispatcher,
private val ioDispatcher: CoroutineDispatcher,
) : IRemoteNodeClient.Stub() { ) : IRemoteNodeClient.Stub() {
private val logger = loggerFor<RemoteNodeClient>() private val logger = loggerFor<RemoteNodeClient>()
/** Disable connecting to the Monero network */ private val requestsScope = CoroutineScope(ioDispatcher + SupervisorJob())
// /** Disable connecting to the Monero network */
// var offline = false // var offline = false
private fun selectedNode() = nodeSelector.select() private fun selectedNode() = nodeSelector.select()
@ -47,6 +47,11 @@ internal class RemoteNodeClient(
} }
} }
@CalledByNative("http_client.cc")
override fun cancelAll() {
requestsScope.coroutineContext.cancelChildren()
}
private fun execute( private fun execute(
method: String?, method: String?,
uri: Uri, uri: Uri,
@ -56,10 +61,12 @@ internal class RemoteNodeClient(
password: String?, password: String?,
): HttpResponse { ): HttpResponse {
logger.d("HTTP: $method $uri, header_len=${header?.length}, body_size=${body?.size}") logger.d("HTTP: $method $uri, header_len=${header?.length}, body_size=${body?.size}")
val headers = header?.parseHttpHeader() val headers = header?.parseHttpHeader()
val contentType = headers?.get("Content-Type")?.let { value -> val contentType = headers?.get("Content-Type")?.let { value ->
MediaType.get(value) MediaType.get(value)
} }
val request = with(Request.Builder()) { val request = with(Request.Builder()) {
when (method) { when (method) {
"GET" -> {} "GET" -> {}
@ -71,16 +78,28 @@ internal class RemoteNodeClient(
url(uri.toString()) url(uri.toString())
build() build()
} }
val response = httpClient.newCall(request).execute()
return if (response.isSuccessful) { val response = runBlocking(requestsScope.coroutineContext) {
val call = httpClient.newCall(request)
try {
call.await()
} catch (ioe: IOException) {
if (!call.isCanceled) {
throw ioe
}
null
}
}
return if (response == null) {
HttpResponse(code = 499)
} else if (response.isSuccessful) {
val responseBody = requireNotNull(response.body()) val responseBody = requireNotNull(response.body())
val pipe = ParcelFileDescriptor.createPipe() val pipe = ParcelFileDescriptor.createPipe()
scope.launch(ioDispatcher) { requestsScope.launch {
pipe[1].use { readFd -> pipe[1].use { writeSide ->
FileOutputStream(readFd.fileDescriptor).use { outputStream -> FileOutputStream(writeSide.fileDescriptor).use { outputStream ->
responseBody.use { responseBody.byteStream().copyTo(outputStream)
it.byteStream().copyTo(outputStream)
}
} }
} }
} }
@ -90,13 +109,28 @@ internal class RemoteNodeClient(
body = pipe[0], body = pipe[0],
) )
} else { } else {
HttpResponse(code = response.code()).also { HttpResponse(code = response.code())
response.close()
}
} }
} }
private fun String.parseHttpHeader(): Headers = @OptIn(ExperimentalCoroutinesApi::class)
private suspend fun Call.await() = suspendCancellableCoroutine { continuation ->
enqueue(object : Callback {
override fun onResponse(call: Call, response: Response) {
continuation.resume(response) {
response.close()
}
}
override fun onFailure(call: Call, e: IOException) {
continuation.resumeWithException(e)
}
})
continuation.invokeOnCancellation { cancel() }
}
fun String.parseHttpHeader(): Headers =
with(Headers.Builder()) { with(Headers.Builder()) {
splitToSequence("\r\n") splitToSequence("\r\n")
.filter { line -> line.isNotEmpty() } .filter { line -> line.isNotEmpty() }

View File

@ -39,13 +39,11 @@ class WalletClient private constructor(
network: MoneroNetwork, network: MoneroNetwork,
nodeSelector: RemoteNodeSelector? = null, nodeSelector: RemoteNodeSelector? = null,
httpClient: OkHttpClient? = null, httpClient: OkHttpClient? = null,
coroutineContext: CoroutineContext = Dispatchers.Default + SupervisorJob(),
ioDispatcher: CoroutineDispatcher = Dispatchers.IO, ioDispatcher: CoroutineDispatcher = Dispatchers.IO,
): WalletClient { ): WalletClient {
val scope = CoroutineScope(coroutineContext)
val remoteNodeClient = nodeSelector?.let { val remoteNodeClient = nodeSelector?.let {
requireNotNull(httpClient) requireNotNull(httpClient)
RemoteNodeClient(it, httpClient, scope, ioDispatcher) RemoteNodeClient(it, httpClient, ioDispatcher)
} }
val (serviceConnection, service) = bindService(context) val (serviceConnection, service) = bindService(context)
return WalletClient(context, service, serviceConnection, network, remoteNodeClient) return WalletClient(context, service, serviceConnection, network, remoteNodeClient)