connections.py

from pathlib import Path
from typing import Any, Mapping, MutableMapping, Optional
from urllib.parse import urlparse

import aiohttp
import requests

from aphrodite.version import __version__ as APHRODITE_VERSION


class HTTPConnection:
    """Helper class to send HTTP requests."""

    def __init__(self, *, reuse_client: bool = True) -> None:
        super().__init__()

        self.reuse_client = reuse_client

        self._sync_client: Optional[requests.Session] = None
        self._async_client: Optional[aiohttp.ClientSession] = None

    def get_sync_client(self) -> requests.Session:
        if self._sync_client is None or not self.reuse_client:
            self._sync_client = requests.Session()

        return self._sync_client

    # NOTE: We intentionally use an async function even though it is not
    # required, so that the client is only accessible inside an async
    # event loop.
    async def get_async_client(self) -> aiohttp.ClientSession:
        if self._async_client is None or not self.reuse_client:
            self._async_client = aiohttp.ClientSession()

        return self._async_client
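
    # Illustrative sketch (not part of this module): the shared aiohttp
    # session can only be obtained from inside a running event loop, e.g.
    #
    #     async def fetch(conn: "HTTPConnection", url: str) -> bytes:
    #         client = await conn.get_async_client()
    #         async with client.get(url) as resp:
    #             resp.raise_for_status()
    #             return await resp.read()
    #
    # where "conn" and "fetch" are placeholder names for this example.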

    def _validate_http_url(self, url: str):
        parsed_url = urlparse(url)

        if parsed_url.scheme not in ("http", "https"):
            raise ValueError("Invalid HTTP URL: A valid HTTP URL "
                             "must have scheme 'http' or 'https'.")

    def _headers(self, **extras: str) -> MutableMapping[str, str]:
        return {"User-Agent": f"Aphrodite/{APHRODITE_VERSION}", **extras}

    def get_response(
        self,
        url: str,
        *,
        stream: bool = False,
        timeout: Optional[float] = None,
        extra_headers: Optional[Mapping[str, str]] = None,
    ):
        self._validate_http_url(url)

        client = self.get_sync_client()
        extra_headers = extra_headers or {}

        return client.get(url,
                          headers=self._headers(**extra_headers),
                          stream=stream,
                          timeout=timeout)
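
    # Illustrative sketch (placeholder names and URL): the response body is
    # not consumed here, so callers are expected to use the returned
    # requests.Response as a context manager, e.g.
    #
    #     with conn.get_response("https://example.com/data.bin",
    #                            stream=True, timeout=10) as r:
    #         r.raise_for_status()
    #         for chunk in r.iter_content(chunk_size=8192):
    #             ...  # process each chunk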

    async def get_async_response(
        self,
        url: str,
        *,
        timeout: Optional[float] = None,
        extra_headers: Optional[Mapping[str, str]] = None,
    ):
        self._validate_http_url(url)

        client = await self.get_async_client()
        extra_headers = extra_headers or {}

        return client.get(url,
                          headers=self._headers(**extra_headers),
                          timeout=timeout)
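
    # Illustrative sketch (placeholder names): client.get() is returned
    # without being entered, so callers await this coroutine and then use
    # "async with" on the result, mirroring the helpers below, e.g.
    #
    #     async with await conn.get_async_response(url, timeout=10) as r:
    #         r.raise_for_status()
    #         data = await r.read()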

    def get_bytes(self, url: str, *, timeout: Optional[float] = None) -> bytes:
        with self.get_response(url, timeout=timeout) as r:
            r.raise_for_status()

            return r.content

    async def async_get_bytes(
        self,
        url: str,
        *,
        timeout: Optional[float] = None,
    ) -> bytes:
        async with await self.get_async_response(url, timeout=timeout) as r:
            r.raise_for_status()

            return await r.read()

    def get_text(self, url: str, *, timeout: Optional[float] = None) -> str:
        with self.get_response(url, timeout=timeout) as r:
            r.raise_for_status()

            return r.text

    async def async_get_text(
        self,
        url: str,
        *,
        timeout: Optional[float] = None,
    ) -> str:
        async with await self.get_async_response(url, timeout=timeout) as r:
            r.raise_for_status()

            return await r.text()

    def get_json(self, url: str, *, timeout: Optional[float] = None) -> Any:
        with self.get_response(url, timeout=timeout) as r:
            r.raise_for_status()

            return r.json()

    async def async_get_json(
        self,
        url: str,
        *,
        timeout: Optional[float] = None,
    ) -> Any:
        async with await self.get_async_response(url, timeout=timeout) as r:
            r.raise_for_status()

            return await r.json()

    def download_file(
        self,
        url: str,
        save_path: Path,
        *,
        timeout: Optional[float] = None,
        chunk_size: int = 128,
    ) -> Path:
        with self.get_response(url, timeout=timeout) as r:
            r.raise_for_status()

            with save_path.open("wb") as f:
                for chunk in r.iter_content(chunk_size):
                    f.write(chunk)

        return save_path

    async def async_download_file(
        self,
        url: str,
        save_path: Path,
        *,
        timeout: Optional[float] = None,
        chunk_size: int = 128,
    ) -> Path:
        async with await self.get_async_response(url, timeout=timeout) as r:
            r.raise_for_status()

            with save_path.open("wb") as f:
                async for chunk in r.content.iter_chunked(chunk_size):
                    f.write(chunk)

        return save_path


global_http_connection = HTTPConnection()
"""The global :class:`HTTPConnection` instance used by Aphrodite."""