Xiaopeng Zhang 3 hónapja
szülő
commit
8c00eb2954
2 módosított fájl, 444 hozzáadás és 19 törlés
  1. 409 0
      fastapi_login2.py
  2. 35 19
      main.py

+ 409 - 0
fastapi_login2.py

@@ -0,0 +1,409 @@
+import inspect
+from datetime import datetime, timedelta, timezone
+from typing import Any, Awaitable, Callable, Collection, Dict, Optional, Type, Union
+
+import jwt
+from anyio.to_thread import run_sync
+from fastapi import FastAPI, Request, Response
+from fastapi.security import OAuth2PasswordBearer, SecurityScopes
+from starlette.middleware.base import BaseHTTPMiddleware
+
+from fastapi_login.exceptions import InsufficientScopeException, InvalidCredentialsException
+from fastapi_login.secrets import to_secret
+from fastapi_login.utils import ordered_partial
+
+SECRET_TYPE = Union[str, bytes]
+CUSTOM_EXCEPTION = Union[Type[Exception], Exception]
+
+
+class LoginManager(OAuth2PasswordBearer):
+    def __init__(
+        self,
+        secret: Union[SECRET_TYPE, Dict[str, SECRET_TYPE]],
+        token_url: str,
+        algorithm="HS256",
+        use_cookie=False,
+        use_header=True,
+        cookie_name: str = "access-token",
+        not_authenticated_exception: CUSTOM_EXCEPTION = InvalidCredentialsException,
+        default_expiry: timedelta = timedelta(minutes=15),
+        scopes: Optional[Dict[str, str]] = None,
+        out_of_scope_exception: CUSTOM_EXCEPTION = InsufficientScopeException,
+    ):
+        """
+        Initializes LoginManager
+
+        Args:
+            algorithm (str): Should be "HS256" or "RS256" used to decrypt the JWT
+            token_url (str): The url where the user can login to get the token
+            use_cookie (bool): Set if cookies should be checked for the token
+            use_header (bool): Set if headers should be checked for the token
+            cookie_name (str): Name of the cookie to check for the token
+            not_authenticated_exception (Union[Type[Exception], Exception]): Exception to raise when the user is not authenticated
+                this defaults to `fastapi_login.exceptions.InvalidCredentialsException`
+            default_expiry (datetime.timedelta): The default expiry time of the token, defaults to 15 minutes
+            scopes (Dict[str, str]): Scopes argument of OAuth2PasswordBearer for more information see
+                `https://fastapi.tiangolo.com/advanced/security/oauth2-scopes/#oauth2-security-scheme`
+            out_of_scope_exception (Union[Type[Exception], Exception]): Exception to raise when the user is out of scopes,
+                if not set, default is `fastapi_login.exceptions.InsufficientScopeException`
+        """
+        if use_cookie is False and use_header is False:
+            raise AttributeError(
+                "use_cookie and use_header are both False one of them needs to be True"
+            )
+        if isinstance(secret, str):
+            secret = secret.encode()
+
+        self.secret = to_secret({"algorithms": algorithm, "secret": secret})
+        self.algorithm = algorithm
+        self.oauth_scheme = None
+        self.use_cookie = use_cookie
+        self.use_header = use_header
+        self.cookie_name = cookie_name
+        self.default_expiry = default_expiry
+
+        # private
+        self._user_callback: Optional[ordered_partial] = None
+        self._not_authenticated_exception = not_authenticated_exception
+        self._out_of_scope_exception = out_of_scope_exception
+
+        # we take over the exception raised possibly by setting auto_error to False
+        super().__init__(tokenUrl=token_url, auto_error=False, scopes=scopes)
+
+    @property
+    def out_of_scope_exception(self):
+        """
+        Exception raised when the user is out of scope.
+        Defaults to `fastapi_login.exceptions.InsufficientScopeException`
+        """
+        return self._out_of_scope_exception
+
+    @property
+    def not_authenticated_exception(self):
+        """
+        Exception raised when no (valid) token is present.
+        Defaults to `fastapi_login.exceptions.InvalidCredentialsException`
+        """
+        return self._not_authenticated_exception
+
+    def user_loader(self, *args, **kwargs) -> Union[Callable, Callable[..., Awaitable]]:
+        """
+        This sets the callback to retrieve the user.
+        The function should take an unique identifier like an email
+        and return the user object or None.
+
+        Basic usage:
+
+            >>> from fastapi import FastAPI
+            >>> from fastapi_login import LoginManager
+
+            >>> app = FastAPI()
+            >>> # use import secrets; print(secrets.token_hex(24)) to get a suitable secret key
+            >>> SECRET = "super-secret"
+
+            >>> manager = LoginManager(SECRET, token_url="Login")
+
+            >>> manager.user_loader()(get_user)
+
+            >>> @manager.user_loader(...)  # Arguments and keyword arguments declared here are passed on
+            >>> def get_user(user_identifier, ...):
+            ...     # get user logic here
+
+        Args:
+            args: Positional arguments to pass on to the decorated method
+            kwargs: Keyword arguments to pass on to the decorated method
+
+        Returns:
+            The callback
+        """
+
+        def decorator(callback: Union[Callable, Callable[..., Awaitable]]):
+            """
+            The actual setter of the load_user callback
+            Args:
+                callback (Callable or Awaitable): The callback which returns the user
+
+            Returns:
+                Partial of the callback with given args and keyword arguments already set
+            """
+            self._user_callback = ordered_partial(callback, *args, **kwargs)
+            return callback
+
+        return decorator
+
+    def _get_payload(self, token: str) -> Dict[str, Any]:
+        """
+        Returns the decoded token payload.
+        If failed, raises `LoginManager.not_authenticated_exception`
+
+        Args:
+            token (str): The token to decode
+
+        Returns:
+            Payload of the token
+
+        Raises:
+            LoginManager.not_authenticated_exception: The token is invalid or None was returned by `_load_user`
+        """
+        try:
+            payload = jwt.decode(
+                token, self.secret.secret_for_decode, algorithms=[self.algorithm]
+            )
+            return payload
+
+        # This includes all errors raised by pyjwt
+        except jwt.PyJWTError:
+            raise self.not_authenticated_exception
+
+    def _has_scopes(
+        self, payload: Dict[str, Any], required_scopes: Optional[SecurityScopes]
+    ) -> bool:
+        """
+        Returns true if the required scopes are present in the token
+
+        Args:
+            payload (Dict[str, Any]): The decoded JWT payload
+            required_scopes: The scopes required to access this route
+
+        Returns:
+            True if the required scopes are contained in the tokens payload
+        """
+        if required_scopes is None or not required_scopes.scopes:
+            # According to RFC 6749, the scopes are optional
+            return True
+
+        # when the manager was invoked using fastapi.Security(manager, scopes=[...])
+        # we have to check if all required scopes are contained in the token
+        provided_scopes = payload.get("scopes", [])
+        # Check if enough scopes are present
+        if len(provided_scopes) < len(required_scopes.scopes):
+            return False
+        # Check if all required scopes are present
+        elif any(scope not in provided_scopes for scope in required_scopes.scopes):
+            return False
+
+        return True
+
+    def has_scopes(self, token: str, required_scopes: SecurityScopes) -> bool:
+        """
+        Combines `_get_payload` and `_has_scopes` to check if the token has the required scopes
+
+        Args:
+            token (str): The token to decode
+            required_scopes: The scopes required to access this route
+
+        Returns:
+            True if the required scopes are contained in the tokens payload
+        """
+        payload = self._get_payload(token)
+        return self._has_scopes(payload, required_scopes)
+
+    async def _get_current_user(self, payload: Dict[str, Any]):
+        """
+        This decodes the jwt based on the secret and the algorithm set on the instance.
+        If the token is correctly formatted and the user is found the user object
+        is returned else this raises `LoginManager.not_authenticated_exception`
+
+        Args:
+            payload (Dict[str, Any]): The decoded JWT payload
+
+        Returns:
+            The user object returned by the instances `_user_callback`
+
+        Raises:
+            LoginManager.not_authenticated_exception: The token is invalid or None was returned by `_load_user`
+        """
+        # the identifier should be stored under the sub (subject) key
+        user_identifier = payload.get("sub")
+        if user_identifier is None:
+            raise self.not_authenticated_exception
+
+        user = await self._load_user(user_identifier)
+        if user is None:
+            raise self.not_authenticated_exception
+
+        return user
+
+    async def get_current_user(self, token: str) -> Any:
+        """
+        Combines `_get_payload` and `_get_current_user` to get the user object
+
+        Args:
+            token (str): The encoded jwt token
+
+        Returns:
+            The user object returned by the instances `_user_callback`
+
+        Raises:
+            LoginManager.not_authenticated_exception: The token is invalid or None was returned by `_load_user`
+        """
+        payload = self._get_payload(token)
+        return await self._get_current_user(payload)
+
+    async def _load_user(self, identifier: Any):
+        """
+        This loads the user using the user_callback
+
+        Args:
+            identifier (Any): The user identifier expected by `_user_callback`
+
+        Returns:
+            The user object returned by `_user_callback` or None
+
+        Raises:
+            Exception: When no ``user_loader`` has been set
+        """
+        if self._user_callback is None:
+            raise Exception("Missing user_loader callback")
+
+        if inspect.iscoroutinefunction(self._user_callback):
+            user = await self._user_callback(identifier)
+        else:
+            user = await run_sync(self._user_callback, identifier)
+
+        return user
+
+    def create_access_token(
+        self,
+        *,
+        data: dict,
+        expires: Optional[timedelta] = None,
+        scopes: Optional[Collection[str]] = None,
+    ) -> str:
+        """
+        Helper function to create the encoded access token using
+        the provided secret and the algorithm of the LoginManager instance
+
+        Args:
+            data (dict): The data which should be stored in the token
+            expires (datetime.timedelta):  An optional timedelta in which the token expires.
+                Defaults to 15 minutes
+            scopes (Collection): Optional scopes the token user has access to.
+
+        Returns:
+            The encoded JWT with the data and the expiry. The expiry is
+            available under the 'exp' key
+        """
+
+        to_encode = data.copy()
+
+        if expires:
+            expires_in = datetime.now(timezone.utc) + expires
+        else:
+            expires_in = datetime.now(timezone.utc) + self.default_expiry
+
+        to_encode.update({"exp": expires_in})
+
+        if scopes is not None:
+            unique_scopes = set(scopes)
+            to_encode.update({"scopes": list(unique_scopes)})
+
+        return jwt.encode(to_encode, self.secret.secret_for_encode, self.algorithm)
+
+    def set_cookie(self, response: Response, token: str) -> None:
+        """
+        Utility function to set a cookie containing token on the response
+
+        Args:
+            response (fastapi.Response): The response which is send back
+            token (str): The created JWT
+        """
+        response.set_cookie(key=self.cookie_name, value=token, httponly=True)
+
+    def _token_from_cookie(self, request: Request) -> Optional[str]:
+        """
+        Checks the requests cookies for cookies with the value of`self.cookie_name` as name
+
+        Args:
+            request (fastapi.Request): The request to the route, normally filled in automatically
+
+        Returns:
+            The access token found in the cookies of the request or None
+        """
+        return request.cookies.get(self.cookie_name) or None
+
+    async def _get_token(self, request: Request):
+        """
+        Tries to extract the token from the request, based on self.use_header and self.use_cookie
+
+        Args:
+            request: The request containing the token
+
+        Returns:
+            The in the request contained encoded JWT token
+
+        Raises:
+            LoginManager.not_authenticated_exception if no token is present
+        """
+        token = None
+        if self.use_cookie:
+            token = self._token_from_cookie(request)
+
+        if not token and self.use_header:
+            token = await super(LoginManager, self).__call__(request)
+
+        if not token:
+            raise self.not_authenticated_exception
+
+        return token
+
+    async def __call__(
+        self,
+        request: Request,
+        security_scopes: SecurityScopes = None,  # type: ignore
+    ) -> Any:
+        """
+        Provides the functionality to act as a Dependency
+
+        Args:
+            request (fastapi.Request):The incoming request, this is set automatically
+                by FastAPI
+
+        Returns:
+            The user object or None
+
+        Raises:
+            LoginManager.not_authenticated_exception: If set by the user and `self.auto_error` is set to False
+
+        """
+        token = await self._get_token(request)
+        payload = self._get_payload(token)
+
+        if not self._has_scopes(payload, security_scopes):
+            raise self._out_of_scope_exception
+
+        return await self._get_current_user(payload)
+
+    async def optional(self, request: Request, security_scopes: SecurityScopes = None):  # type: ignore
+        """
+        Acts as a dependency which catches all errors and returns `None` instead
+        """
+        try:
+            user = await self.__call__(request, security_scopes)
+        except Exception:
+            return None
+        else:
+            return user
+
+    def attach_middleware(self, app: FastAPI):
+        """
+        Add the instance as a middleware, which adds the user object, if present,
+        to the request state
+
+        Args:
+            app (fastapi.FastAPI): FastAPI application
+        """
+
+        async def __set_user(request: Request, call_next):
+            try:
+                request.state.user = await self.__call__(request)
+            except Exception:
+                # An error occurred while getting the user
+                # as middlewares are called for every incoming request
+                # it's not a good idea to return the Exception
+                # so we set the user to None
+                request.state.user = None
+
+            return await call_next(request)
+
+        app.add_middleware(BaseHTTPMiddleware, dispatch=__set_user)

+ 35 - 19
main.py

@@ -1,16 +1,17 @@
 import xml.etree.ElementTree as ET
 import xml.etree.ElementTree as ET
 from typing import Union
 from typing import Union
-from fastapi import FastAPI
+from fastapi import FastAPI, HTTPException, status, Depends
 import os
 import os
 import uuid
 import uuid
 import re
 import re
 import zipfile
 import zipfile
 import json
 import json
 import chromadb
 import chromadb
+from datetime import timedelta
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.gzip import GZipMiddleware
 from fastapi.middleware.gzip import GZipMiddleware
-from sentence_transformers  import SentenceTransformer
-model = SentenceTransformer("BAAI/bge-small-zh-v1.5")
+##from sentence_transformers  import SentenceTransformer
+model = None##SentenceTransformer("BAAI/bge-small-zh-v1.5")
 import base64
 import base64
 from pydantic import BaseModel
 from pydantic import BaseModel
 from subdir import service
 from subdir import service
@@ -21,8 +22,8 @@ import numpy as np
 from fastapi.staticfiles import StaticFiles
 from fastapi.staticfiles import StaticFiles
 from pymongo import AsyncMongoClient
 from pymongo import AsyncMongoClient
 client = AsyncMongoClient()
 client = AsyncMongoClient()
-chroma_client = chromadb.HttpClient(host='localhost', port=8000)
-from fastapi.responses import FileResponse
+chroma_client = None##chromadb.HttpClient(host='localhost', port=8000)
+from fastapi.responses import FileResponse, RedirectResponse
 from fastapi_cache import FastAPICache
 from fastapi_cache import FastAPICache
 from inmemory import InMemoryBackend
 from inmemory import InMemoryBackend
 from fastapi_cache.decorator import cache
 from fastapi_cache.decorator import cache
@@ -30,6 +31,9 @@ from contextlib import asynccontextmanager
 from fastapi import WebSocket, WebSocketDisconnect
 from fastapi import WebSocket, WebSocketDisconnect
 from fastapi import UploadFile
 from fastapi import UploadFile
 from collections.abc import AsyncIterator
 from collections.abc import AsyncIterator
+
+from fastapi_login import LoginManager
+
 class ConnectionManager:
 class ConnectionManager:
     """Class defining socket events"""
     """Class defining socket events"""
     def __init__(self):
     def __init__(self):
@@ -56,6 +60,8 @@ async def lifespan(_: FastAPI) -> AsyncIterator[None]:
     yield
     yield
 
 
 
 
+SECRET = '400687251779e65aca22ef5bb52b9cc4c218c63571004aeb'
+
 app = FastAPI(lifespan=lifespan)
 app = FastAPI(lifespan=lifespan)
 
 
 
 
@@ -79,9 +85,12 @@ app.add_middleware(
 ##app.mount("/static", StaticFiles(directory="front/dist"), name="static")
 ##app.mount("/static", StaticFiles(directory="front/dist"), name="static")
 ##manager = ConnectionManager()
 ##manager = ConnectionManager()
 
 
+manager = LoginManager(SECRET, token_url='/token')
 
 
-
-
+@manager.user_loader()
+def load_user(id: str):  # could also be an asynchronous function
+    
+    return id
 
 
 
 
 
 
@@ -147,21 +156,21 @@ class Zjcs(BaseModel):
 
 
 
 
 @app.post("/outline2")
 @app.post("/outline2")
-async def read_root2(info: Info):
+async def read_root2(info: Info, user=Depends(manager)):
     
     
     return await db.getOutline(client, info.name)
     return await db.getOutline(client, info.name)
 
 
 
 
 
 
 @app.post("/detail2")
 @app.post("/detail2")
-async def read_detail2(info: Info):
+async def read_detail2(info: Info, user=Depends(manager)):
     
     
     return await db.getDetail(client, info.name)
     return await db.getDetail(client, info.name)
 
 
 
 
 
 
 @app.post("/baojiahuizong2/")
 @app.post("/baojiahuizong2/")
-async def read_bjhz2(info: InfoWithID):
+async def read_bjhz2(info: InfoWithID, user=Depends(manager)):
    
    
     raw = await db.getBjhz(client, info.name, info.id)
     raw = await db.getBjhz(client, info.name, info.id)
     raw2 = []
     raw2 = []
@@ -179,7 +188,7 @@ async def read_bjhz2(info: InfoWithID):
 
 
 
 
 @app.post("/guifeishuijin2/")
 @app.post("/guifeishuijin2/")
-async def read_gfsj2(info: InfoWithID):
+async def read_gfsj2(info: InfoWithID, user=Depends(manager)):
     raw = await db.getGfsj(client, info.name, info.id)
     raw = await db.getGfsj(client, info.name, info.id)
     raw2 = []
     raw2 = []
     for entry in raw:
     for entry in raw:
@@ -193,7 +202,7 @@ async def read_gfsj2(info: InfoWithID):
 
 
 
 
 @app.post("/qitaxiangmu2/")
 @app.post("/qitaxiangmu2/")
-async def read_qtxm2(info: InfoWithID):
+async def read_qtxm2(info: InfoWithID, user=Depends(manager)):
     raw = await db.getQtxm(client, info.name, info.id)
     raw = await db.getQtxm(client, info.name, info.id)
     raw2 = []
     raw2 = []
     for entry in raw:
     for entry in raw:
@@ -236,14 +245,14 @@ async def read_fbrgycl2(info: InfoWithID):
 
 
 
 
 @app.post("/rencaijihuizong2/")
 @app.post("/rencaijihuizong2/")
-async def read_rcjhz2(info: InfoWithID):
+async def read_rcjhz2(info: InfoWithID, user=Depends(manager)):
     return await db.getRcjhz(client, info.name, info.id)
     return await db.getRcjhz(client, info.name, info.id)
 
 
 
 
 
 
 
 
 @app.post("/qingdanxiangmu2/")
 @app.post("/qingdanxiangmu2/")
-async def read_qdxm2(info: InfoWithID):
+async def read_qdxm2(info: InfoWithID, user=Depends(manager)):
     
     
     return await db.getQdxm(client, info.name, info.id)
     return await db.getQdxm(client, info.name, info.id)
 
 
@@ -283,7 +292,7 @@ async def read_dercj(item : Dercj):
 
 
 
 
 @app.post("/zjcs/")
 @app.post("/zjcs/")
-async def read_zjcs(item : Zjcs):
+async def read_zjcs(item : Zjcs, user=Depends(manager)):
     raw = await db.getZjcs(client, item.name, item.bh)
     raw = await db.getZjcs(client, item.name, item.bh)
     raw2 = []
     raw2 = []
     for entry in raw:
     for entry in raw:
@@ -295,7 +304,7 @@ async def read_zjcs(item : Zjcs):
 
 
 
 
 @app.post("/djcs/")
 @app.post("/djcs/")
-async def read_djcs(item : Zjcs):
+async def read_djcs(item : Zjcs, user=Depends(manager)):
     raw = await db.getDjcs(client, item.name, item.bh)
     raw = await db.getDjcs(client, item.name, item.bh)
     raw2 = []
     raw2 = []
     for entry in raw:
     for entry in raw:
@@ -308,7 +317,7 @@ async def read_djcs(item : Zjcs):
 
 
 
 
 @app.post("/files2/")
 @app.post("/files2/")
-async def read_files2():
+async def read_files2(user=Depends(manager)):
     result = await db.list_files(client)
     result = await db.list_files(client)
     return result
     return result
 
 
@@ -542,7 +551,7 @@ async def cankao():
     return result
     return result
 
 
 @app.post("/qufei/")
 @app.post("/qufei/")
-async def read_qufei(r: Info):
+async def read_qufei(r: Info, user=Depends(manager)):
     return await db.getQufei(client, r.name)
     return await db.getQufei(client, r.name)
 
 
 @app.post("/jiagongcai/")
 @app.post("/jiagongcai/")
@@ -915,7 +924,7 @@ async def save(r: Info):
     return await db.save(client, data)
     return await db.save(client, data)
 
 
 @app.post("/statistics/")
 @app.post("/statistics/")
-async def statistics(r: Info):
+async def statistics(r: Info, user=Depends(manager)):
     
     
     ##print(data)
     ##print(data)
     return await db.statistics(client, r.name)
     return await db.statistics(client, r.name)
@@ -1060,3 +1069,10 @@ async def resolve(data):
 ##    except WebSocketDisconnect:
 ##    except WebSocketDisconnect:
 ##        manager.disconnect(websocket)
 ##        manager.disconnect(websocket)
 ##        ##await manager.broadcast(f"Client  left the chat")
 ##        ##await manager.broadcast(f"Client  left the chat")
+    
+
+@app.post("/token/")
+async def token(r : Info):
+    token = manager.create_access_token(data=dict(sub=r.name), expires=timedelta(hours=12))
+    return {'token' : token, 'token_type' : 'bearer'}
+