mirror of
https://github.com/matrix-org/pantalaimon.git
synced 2025-01-08 14:18:08 -05:00
main: Pass our store path to the nio client after logging in.
This commit is contained in:
parent
d51d47334b
commit
710939e5bd
92
main.py
92
main.py
@ -2,9 +2,14 @@
|
||||
|
||||
import attr
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import os
|
||||
import json
|
||||
|
||||
from aiohttp import web, ClientSession
|
||||
from nio import AsyncClient, LoginResponse
|
||||
from appdirs import user_data_dir
|
||||
from json import JSONDecodeError
|
||||
|
||||
HOMESERVER = "https://localhost:8448"
|
||||
|
||||
@ -18,37 +23,105 @@ class ProxyDaemon:
|
||||
client_sessions = attr.ib(init=False, default=attr.Factory(dict))
|
||||
default_session = attr.ib(init=False, default=None)
|
||||
|
||||
def get_access_token(self, request):
|
||||
# type: (aiohttp.BaseRequest) -> str
|
||||
"""Extract the access token from the request.
|
||||
|
||||
This method extracts the access token either from the query string or
|
||||
from the Authorization header of the request.
|
||||
|
||||
Returns the access token if it was found.
|
||||
"""
|
||||
access_token = request.query.get("access_token", "")
|
||||
|
||||
if not access_token:
|
||||
access_token = request.headers.get(
|
||||
"Authorization",
|
||||
""
|
||||
).strip("Bearer ")
|
||||
|
||||
return access_token
|
||||
|
||||
async def router(self, request):
|
||||
path = request.path
|
||||
method = request.method
|
||||
data = await request.text()
|
||||
headers = request.headers
|
||||
params = request.query
|
||||
|
||||
print(method, path, data)
|
||||
|
||||
if not self.default_session:
|
||||
self.default_session = ClientSession()
|
||||
session = None
|
||||
|
||||
async with self.default_session.request(
|
||||
token = self.get_access_token(request)
|
||||
client = self.client_sessions.get(token, None)
|
||||
|
||||
if client:
|
||||
session = client.client_session
|
||||
|
||||
if not session:
|
||||
if not self.default_session:
|
||||
self.default_session = ClientSession()
|
||||
session = self.default_session
|
||||
|
||||
async with session.request(
|
||||
method,
|
||||
HOMESERVER + path,
|
||||
data=data,
|
||||
params=params,
|
||||
headers=headers,
|
||||
proxy=self.proxy,
|
||||
ssl=False
|
||||
) as resp:
|
||||
print("Returning resp {}".format(resp))
|
||||
return(web.Response(text=await resp.text()))
|
||||
|
||||
async def login(self, request):
|
||||
json = await request.json()
|
||||
try:
|
||||
body = await request.json()
|
||||
except JSONDecodeError:
|
||||
# TODO what to do here, quaternion retries the login if we raise an
|
||||
# exception here, throws an error if we send out an 400 and hangs
|
||||
# if we forward it to the router() method.
|
||||
print("JSON ERROR IN LOGIN")
|
||||
raise
|
||||
# return web.Response(
|
||||
# status=400,
|
||||
# text=json.dumps({
|
||||
# "errcode": "M_NOT_JSON",
|
||||
# "error": "Request did not contain valid JSON."
|
||||
# })
|
||||
# )
|
||||
|
||||
user = json.get("user", "")
|
||||
password = json.get("password", "")
|
||||
device_id = json.get("device_id", "")
|
||||
device_name = json.get("initial_device_display_name", "")
|
||||
print("Login request")
|
||||
print(body)
|
||||
|
||||
identifier = body.get("identifier", None)
|
||||
|
||||
if identifier:
|
||||
user = identifier.get("user", None)
|
||||
|
||||
if not user:
|
||||
user = body.get("user", "")
|
||||
else:
|
||||
user = body.get("user", "")
|
||||
|
||||
password = body.get("password", "")
|
||||
device_id = body.get("device_id", "")
|
||||
device_name = body.get("initial_device_display_name", "")
|
||||
|
||||
store_path = user_data_dir("pantalaimon", "")
|
||||
|
||||
try:
|
||||
os.makedirs(store_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
client = AsyncClient(
|
||||
HOMESERVER,
|
||||
user,
|
||||
device_id,
|
||||
store_path=store_path,
|
||||
ssl=self.ssl,
|
||||
proxy=self.proxy
|
||||
)
|
||||
@ -59,6 +132,9 @@ class ProxyDaemon:
|
||||
|
||||
if isinstance(response, LoginResponse):
|
||||
self.client_sessions[response.access_token] = client
|
||||
else:
|
||||
# TODO close the client and its session.
|
||||
pass
|
||||
|
||||
return web.Response(
|
||||
status=response.transport_response.status,
|
||||
|
Loading…
Reference in New Issue
Block a user