fastapi_login2.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409
  1. import inspect
  2. from datetime import datetime, timedelta, timezone
  3. from typing import Any, Awaitable, Callable, Collection, Dict, Optional, Type, Union
  4. import jwt
  5. from anyio.to_thread import run_sync
  6. from fastapi import FastAPI, Request, Response
  7. from fastapi.security import OAuth2PasswordBearer, SecurityScopes
  8. from starlette.middleware.base import BaseHTTPMiddleware
  9. from fastapi_login.exceptions import InsufficientScopeException, InvalidCredentialsException
  10. from fastapi_login.secrets import to_secret
  11. from fastapi_login.utils import ordered_partial
  # Type aliases used in LoginManager's constructor signature.
# An exception to raise may be supplied either as a class or as a ready-made instance.
CUSTOM_EXCEPTION = Union[Type[Exception], Exception]
# Secrets are accepted as text or as raw bytes.
SECRET_TYPE = Union[str, bytes]
  # LoginManager: JWT-based authentication dependency for FastAPI.
class LoginManager(OAuth2PasswordBearer):
    """
    JWT based authentication helper usable as a FastAPI dependency.

    Extends ``OAuth2PasswordBearer`` but passes ``auto_error=False`` to it, so
    missing or invalid credentials are reported through this instance's
    configurable ``not_authenticated_exception`` / ``out_of_scope_exception``
    instead of the parent's default error.
    """

    def __init__(
        self,
        secret: Union[SECRET_TYPE, Dict[str, SECRET_TYPE]],
        token_url: str,
        algorithm="HS256",
        use_cookie=False,
        use_header=True,
        cookie_name: str = "access-token",
        not_authenticated_exception: CUSTOM_EXCEPTION = InvalidCredentialsException,
        default_expiry: timedelta = timedelta(minutes=15),
        scopes: Optional[Dict[str, str]] = None,
        out_of_scope_exception: CUSTOM_EXCEPTION = InsufficientScopeException,
    ):
        """
        Initializes LoginManager

        Args:
            secret (Union[str, bytes, Dict[str, Union[str, bytes]]]): Secret used to
                sign and verify tokens. A ``str`` is encoded to ``bytes`` first; the
                value is then handed to ``fastapi_login.secrets.to_secret`` together
                with the algorithm.
            token_url (str): The url where the user can login to get the token
            algorithm (str): The JWT signing algorithm, e.g. "HS256" or "RS256"
            use_cookie (bool): Set if cookies should be checked for the token
            use_header (bool): Set if headers should be checked for the token
            cookie_name (str): Name of the cookie to check for the token
            not_authenticated_exception (Union[Type[Exception], Exception]): Exception
                to raise when the user is not authenticated, defaults to
                `fastapi_login.exceptions.InvalidCredentialsException`
            default_expiry (datetime.timedelta): The default expiry time of the token,
                defaults to 15 minutes
            scopes (Dict[str, str]): Scopes argument of OAuth2PasswordBearer, for more
                information see
                `https://fastapi.tiangolo.com/advanced/security/oauth2-scopes/#oauth2-security-scheme`
            out_of_scope_exception (Union[Type[Exception], Exception]): Exception to
                raise when the token lacks a required scope, defaults to
                `fastapi_login.exceptions.InsufficientScopeException`

        Raises:
            AttributeError: If neither cookies nor headers are enabled as token source
        """
        # The flags are consumed by truthiness in `_get_token`, so validate them the
        # same way; otherwise e.g. (None, None) would pass here yet never find a token.
        if not use_cookie and not use_header:
            raise AttributeError(
                "use_cookie and use_header are both False one of them needs to be True"
            )
        if isinstance(secret, str):
            secret = secret.encode()
        self.secret = to_secret({"algorithms": algorithm, "secret": secret})
        self.algorithm = algorithm
        self.oauth_scheme = None
        self.use_cookie = use_cookie
        self.use_header = use_header
        self.cookie_name = cookie_name
        self.default_expiry = default_expiry

        # private
        self._user_callback: Optional[ordered_partial] = None
        self._not_authenticated_exception = not_authenticated_exception
        self._out_of_scope_exception = out_of_scope_exception

        # we take over the exception raised possibly by setting auto_error to False
        super().__init__(tokenUrl=token_url, auto_error=False, scopes=scopes)

    @property
    def out_of_scope_exception(self):
        """
        Exception raised when the user is out of scope.
        Defaults to `fastapi_login.exceptions.InsufficientScopeException`
        """
        return self._out_of_scope_exception

    @property
    def not_authenticated_exception(self):
        """
        Exception raised when no (valid) token is present.
        Defaults to `fastapi_login.exceptions.InvalidCredentialsException`
        """
        return self._not_authenticated_exception

    def user_loader(self, *args, **kwargs) -> Union[Callable, Callable[..., Awaitable]]:
        """
        This sets the callback to retrieve the user.
        The function should take an unique identifier like an email
        and return the user object or None.

        Basic usage:

            >>> from fastapi import FastAPI
            >>> from fastapi_login import LoginManager

            >>> app = FastAPI()
            >>> # use import secrets; print(secrets.token_hex(24)) to get a suitable secret key
            >>> SECRET = "super-secret"

            >>> manager = LoginManager(SECRET, token_url="Login")

            >>> @manager.user_loader(...)  # Arguments and keyword arguments declared here are passed on
            >>> def get_user(user_identifier, ...):
            ...     # get user logic here

            >>> # equivalent, without decorator syntax:
            >>> manager.user_loader()(get_user)

        Args:
            args: Positional arguments to pass on to the decorated method
            kwargs: Keyword arguments to pass on to the decorated method

        Returns:
            A decorator which registers the callback and returns it unchanged
        """

        def decorator(callback: Union[Callable, Callable[..., Awaitable]]):
            """
            The actual setter of the load_user callback

            Args:
                callback (Callable or Awaitable): The callback which returns the user

            Returns:
                The callback itself; a partial of it with the given args and keyword
                arguments bound is stored on the manager
            """
            self._user_callback = ordered_partial(callback, *args, **kwargs)
            return callback

        return decorator

    def _get_payload(self, token: str) -> Dict[str, Any]:
        """
        Returns the decoded token payload.
        If failed, raises `LoginManager.not_authenticated_exception`

        Args:
            token (str): The token to decode

        Returns:
            Payload of the token

        Raises:
            LoginManager.not_authenticated_exception: The token could not be decoded
                or verified
        """
        try:
            payload = jwt.decode(
                token, self.secret.secret_for_decode, algorithms=[self.algorithm]
            )
            return payload
        # This includes all errors raised by pyjwt
        except jwt.PyJWTError:
            raise self.not_authenticated_exception

    def _has_scopes(
        self, payload: Dict[str, Any], required_scopes: Optional[SecurityScopes]
    ) -> bool:
        """
        Returns true if the required scopes are present in the token

        Args:
            payload (Dict[str, Any]): The decoded JWT payload
            required_scopes: The scopes required to access this route

        Returns:
            True if every required scope is contained in the token's ``scopes`` claim
        """
        if required_scopes is None or not required_scopes.scopes:
            # According to RFC 6749, the scopes are optional
            return True

        # when the manager was invoked using fastapi.Security(manager, scopes=[...])
        # we have to check if all required scopes are contained in the token
        provided = payload.get("scopes")
        if provided is None:
            return False
        if isinstance(provided, str):
            # RFC 6749 section 3.3 encodes scopes as a space-delimited string; without
            # splitting, `in` would do substring matching ("read" in "readonly").
            provided = provided.split()
        try:
            provided_scopes = set(provided)
        except TypeError:
            # A malformed (non-iterable / unhashable) claim cannot grant any scope
            return False
        # Set containment: the same scope requested more than once must not cause
        # a spurious rejection (a plain length comparison would).
        return set(required_scopes.scopes) <= provided_scopes

    def has_scopes(self, token: str, required_scopes: SecurityScopes) -> bool:
        """
        Combines `_get_payload` and `_has_scopes` to check if the token has the required scopes

        Args:
            token (str): The token to decode
            required_scopes: The scopes required to access this route

        Returns:
            True if the required scopes are contained in the tokens payload

        Raises:
            LoginManager.not_authenticated_exception: The token is invalid
        """
        payload = self._get_payload(token)
        return self._has_scopes(payload, required_scopes)

    async def _get_current_user(self, payload: Dict[str, Any]):
        """
        Looks up the user identified by the ``sub`` claim of an already decoded payload.

        Args:
            payload (Dict[str, Any]): The decoded JWT payload

        Returns:
            The user object returned by the instances `_user_callback`

        Raises:
            LoginManager.not_authenticated_exception: The payload has no ``sub`` claim
                or None was returned by `_load_user`
        """
        # the identifier should be stored under the sub (subject) key
        user_identifier = payload.get("sub")
        if user_identifier is None:
            raise self.not_authenticated_exception

        user = await self._load_user(user_identifier)

        if user is None:
            raise self.not_authenticated_exception

        return user

    async def get_current_user(self, token: str) -> Any:
        """
        Combines `_get_payload` and `_get_current_user` to get the user object

        Args:
            token (str): The encoded jwt token

        Returns:
            The user object returned by the instances `_user_callback`

        Raises:
            LoginManager.not_authenticated_exception: The token is invalid or None was returned by `_load_user`
        """
        payload = self._get_payload(token)
        return await self._get_current_user(payload)

    async def _load_user(self, identifier: Any):
        """
        This loads the user using the user_callback.

        Coroutine callbacks are awaited directly; other callables are run via
        ``anyio.to_thread.run_sync`` so a blocking lookup does not run on the
        event loop.

        Args:
            identifier (Any): The user identifier expected by `_user_callback`

        Returns:
            The user object returned by `_user_callback` or None

        Raises:
            Exception: When no ``user_loader`` has been set
        """
        if self._user_callback is None:
            raise Exception("Missing user_loader callback")

        if inspect.iscoroutinefunction(self._user_callback):
            user = await self._user_callback(identifier)
        else:
            user = await run_sync(self._user_callback, identifier)

        return user

    def create_access_token(
        self,
        *,
        data: dict,
        expires: Optional[timedelta] = None,
        scopes: Optional[Collection[str]] = None,
    ) -> str:
        """
        Helper function to create the encoded access token using
        the provided secret and the algorithm of the LoginManager instance

        Args:
            data (dict): The data which should be stored in the token; it is copied,
                not modified. An ``exp`` key in it is overwritten.
            expires (datetime.timedelta): An optional timedelta in which the token
                expires. Defaults to ``self.default_expiry`` (15 minutes unless
                configured otherwise). ``timedelta(0)`` is honoured, not replaced
                by the default.
            scopes (Collection): Optional scopes the token user has access to.
                Duplicates are dropped while keeping first-seen order; a single
                ``str`` is treated as space-delimited scopes.

        Returns:
            The encoded JWT with the data and the expiry. The expiry is
            available under the 'exp' key
        """
        to_encode = data.copy()

        # `is not None` rather than truthiness: timedelta(0) is falsy but is an
        # explicit request, not a request for the default lifetime.
        lifetime = self.default_expiry if expires is None else expires
        to_encode.update({"exp": datetime.now(timezone.utc) + lifetime})

        if scopes is not None:
            if isinstance(scopes, str):
                # A bare string would otherwise be split into single characters
                scopes = scopes.split()
            # dict.fromkeys de-duplicates while preserving order, so the encoded
            # claim is deterministic (iteration order of a set is not).
            to_encode.update({"scopes": list(dict.fromkeys(scopes))})

        return jwt.encode(to_encode, self.secret.secret_for_encode, self.algorithm)

    def set_cookie(self, response: Response, token: str) -> None:
        """
        Utility function to set an HttpOnly cookie named `self.cookie_name`
        containing the token on the response

        Args:
            response (fastapi.Response): The response which is send back
            token (str): The created JWT
        """
        response.set_cookie(key=self.cookie_name, value=token, httponly=True)

    def _token_from_cookie(self, request: Request) -> Optional[str]:
        """
        Checks the requests cookies for cookies with the value of `self.cookie_name` as name

        Args:
            request (fastapi.Request): The request to the route, normally filled in automatically

        Returns:
            The access token found in the cookies of the request, or None if the
            cookie is missing or empty
        """
        return request.cookies.get(self.cookie_name) or None

    async def _get_token(self, request: Request):
        """
        Tries to extract the token from the request, based on self.use_header and self.use_cookie.
        The cookie is checked first; the header is only consulted if no cookie token was found.

        Args:
            request: The request containing the token

        Returns:
            The in the request contained encoded JWT token

        Raises:
            LoginManager.not_authenticated_exception if no token is present
        """
        token = None
        if self.use_cookie:
            token = self._token_from_cookie(request)

        if not token and self.use_header:
            # auto_error=False makes the parent return None instead of raising
            token = await super(LoginManager, self).__call__(request)

        if not token:
            raise self.not_authenticated_exception

        return token

    async def __call__(
        self,
        request: Request,
        security_scopes: SecurityScopes = None,  # type: ignore
    ) -> Any:
        """
        Provides the functionality to act as a Dependency

        Args:
            request (fastapi.Request): The incoming request, this is set automatically
                by FastAPI
            security_scopes (SecurityScopes): Scopes required by the route, set by
                FastAPI when the manager is used via ``fastapi.Security``

        Returns:
            The user object returned by the registered user loader (never None)

        Raises:
            LoginManager.not_authenticated_exception: No token, an invalid token, or
                no user was found for it
            LoginManager.out_of_scope_exception: The token lacks a required scope
        """
        token = await self._get_token(request)
        payload = self._get_payload(token)

        if not self._has_scopes(payload, security_scopes):
            raise self._out_of_scope_exception

        return await self._get_current_user(payload)

    async def optional(self, request: Request, security_scopes: SecurityScopes = None):  # type: ignore
        """
        Acts as a dependency which catches all errors and returns `None` instead.

        Note that this also hides configuration errors such as a missing
        ``user_loader``, since those are raised as plain ``Exception`` too.
        """
        try:
            user = await self.__call__(request, security_scopes)
        except Exception:
            return None
        else:
            return user

    def attach_middleware(self, app: FastAPI):
        """
        Add the instance as a middleware, which stores the user object on
        ``request.state.user`` (or None if no user could be determined)

        Args:
            app (fastapi.FastAPI): FastAPI application
        """

        async def _set_user(request: Request, call_next):
            try:
                request.state.user = await self.__call__(request)
            except Exception:
                # An error occurred while getting the user
                # as middlewares are called for every incoming request
                # it's not a good idea to return the Exception
                # so we set the user to None
                request.state.user = None

            return await call_next(request)

        app.add_middleware(BaseHTTPMiddleware, dispatch=_set_user)