This commit is contained in:
tieong 2024-08-29 16:46:01 +02:00 committed by GitHub
commit 7401098318
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 8 additions and 5 deletions

View File

@ -35,7 +35,7 @@ from .validators import ClickValidator, Required
def with_http(func): def with_http(func):
@functools.wraps(func) @functools.wraps(func)
async def wrapper(*args, **kwargs): async def wrapper(*args, **kwargs):
async with aiohttp.ClientSession() as sess: async with aiohttp.ClientSession(trust_env=True) as sess:
try: try:
return await func(*args, sess=sess, **kwargs) return await func(*args, sess=sess, **kwargs)
except aiohttp.ClientError as e: except aiohttp.ClientError as e:
@ -50,7 +50,7 @@ def with_authenticated_http(func):
server, token = get_token(server) server, token = get_token(server)
if not token: if not token:
return return
async with aiohttp.ClientSession(headers={"Authorization": f"Bearer {token}"}) as sess: async with aiohttp.ClientSession(headers={"Authorization": f"Bearer {token}"}, trust_env=True) as sess:
try: try:
return await func(*args, sess=sess, server=server, **kwargs) return await func(*args, sess=sess, server=server, **kwargs)
except aiohttp.ClientError as e: except aiohttp.ClientError as e:

View File

@ -139,7 +139,7 @@ class Client(DBClient):
self._postinited = True self._postinited = True
self.cache[self.id] = self self.cache[self.id] = self
self.log = self.log.getChild(self.id) self.log = self.log.getChild(self.id)
self.http_client = ClientSession(loop=self.maubot.loop) self.http_client = ClientSession(loop=self.maubot.loop, trust_env=True)
self.references = set() self.references = set()
self.started = False self.started = False
self.sync_ok = True self.sync_ok = True

View File

@ -14,6 +14,7 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from aiohttp import client as http, web from aiohttp import client as http, web
from urllib.request import getproxies
from ...client import Client from ...client import Client
from .base import routes from .base import routes
@ -45,8 +46,10 @@ async def proxy(request: web.Request) -> web.StreamResponse:
headers["X-Forwarded-For"] = f"{host}:{port}" headers["X-Forwarded-For"] = f"{host}:{port}"
data = await request.read() data = await request.read()
proxies = getproxies()
async with http.request( async with http.request(
request.method, f"{client.homeserver}/{path}", headers=headers, params=query, data=data request.method, f"{client.homeserver}/{path}", headers=headers, params=query, data=data,
proxy=proxies["https"] if "https" in proxies else None
) as proxy_resp: ) as proxy_resp:
response = web.StreamResponse(status=proxy_resp.status, headers=proxy_resp.headers) response = web.StreamResponse(status=proxy_resp.status, headers=proxy_resp.headers)
await response.prepare(request) await response.prepare(request)

View File

@ -235,7 +235,7 @@ if appservice_listener:
async def main(): async def main():
http_client = ClientSession(loop=loop) http_client = ClientSession(loop=loop, trust_env=True)
global client, bot global client, bot