Study/fastapi

[udemy] FastAPI - The Complete Course 2025 (Beginner + Advanced) - 학습 정리 2

bluebamus 2025. 1. 10.

1. Authentication & Authorization

   1) bcrypt로 비밀번호 암호화 하기

from passlib.context import CryptContext

bcrypt_context = CryptContext(schemes=['bcrypt'], deprecated='auto')

# 암호화
hashed_password=bcrypt_context.hash(create_user_request.password)

# 검증
if not bcrypt_context.verify(password, user.hashed_password):
        return False

 

   2) OAuth2PasswordRequestForm을 사용해 인증하기

      - 설명:
         - OAuth2PasswordRequestForm는 OAuth2의 Password Grant 흐름에서 사용되는 폼 데이터를 처리하는 클래스이다. 클라이언트가 사용자 인증을 위해 사용자 이름과 비밀번호를 제출할 때 이를 쉽게 처리할 수 있도록 설계되었다.

      - 주요 필드:
         - username: 사용자의 사용자 이름.
         - password: 사용자의 비밀번호.
         - scope: 인증 범위를 지정. (선택 사항, 기본값은 빈 문자열)
         - client_id: 클라이언트 ID (선택 사항).
         - client_secret: 클라이언트 비밀키 (선택 사항).


      -  사용 예: FastAPI 애플리케이션에서 OAuth2PasswordRequestForm을 의존성으로 사용하여 요청에서 사용자 이름과 비밀번호를 추출한다.

         - ChatGPT 코드 :

            - OAuth2PasswordRequestForm으로 인증하여 token 생성 후 token 검증

from datetime import datetime, timedelta
from typing import Optional

from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
from passlib.context import CryptContext

# FastAPI 애플리케이션 생성
app = FastAPI()

# JWT 설정
SECRET_KEY = "your_secret_key"  # 보안용 비밀키. 환경 변수로 관리하는 것이 안전합니다.
ALGORITHM = "HS256"  # 사용할 암호화 알고리즘
ACCESS_TOKEN_EXPIRE_MINUTES = 30  # 액세스 토큰의 만료 시간 (30분)

# 가상 사용자 데이터베이스 (실제로는 DB와 연동해야 함)
fake_users_db = {
    "user1": {
        "username": "user1",
        "full_name": "John Doe",
        "email": "johndoe@example.com",
        "hashed_password": "$2b$12$KIXQG5G3Hp5sBW.KJxG4WeXQ56AfQOiAIkLPPBo5VrVVGPugMox4u",  # "password"의 bcrypt 해시값
        "disabled": False,
    }
}

# 비밀번호 해싱 및 검증을 위한 설정
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# OAuth2PasswordBearer: 인증 토큰을 `Authorization: Bearer <token>` 헤더에서 추출
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")


# 비밀번호 검증 함수
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """
    사용자가 입력한 비밀번호와 저장된 해시 비밀번호를 비교.
    """
    return pwd_context.verify(plain_password, hashed_password)


# 비밀번호 해싱 함수
def get_password_hash(password: str) -> str:
    """
    비밀번호를 해싱하여 반환.
    """
    return pwd_context.hash(password)


# 사용자 조회 함수
def get_user(db: dict, username: str) -> Optional[dict]:
    """
    데이터베이스에서 사용자 이름으로 사용자 정보를 검색.
    """
    return db.get(username)


# 사용자 인증 함수
def authenticate_user(db: dict, username: str, password: str) -> Optional[dict]:
    """
    사용자를 인증. 비밀번호와 사용자 정보를 검증.
    """
    user = get_user(db, username)
    if not user or not verify_password(password, user["hashed_password"]):
        return None
    return user


# JWT 토큰 생성 함수
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """
    JWT 액세스 토큰을 생성.
    - `data`: 토큰에 포함할 사용자 데이터 (예: {"sub": username}).
    - `expires_delta`: 토큰 만료 시간 설정.
    """
    to_encode = data.copy()
    expire = datetime.utcnow() + (expires_delta or timedelta(minutes=15))
    to_encode.update({"exp": expire})  # 만료 시간 추가
    encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)  # JWT 생성
    return encoded_jwt


# 토큰 발급 엔드포인트
@app.post("/token")
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
    """
    사용자 이름과 비밀번호를 받아 JWT 토큰을 발급하는 엔드포인트.
    - `form_data`: OAuth2PasswordRequestForm으로 폼 데이터(username, password)를 추출.
    """
    user = authenticate_user(fake_users_db, form_data.username, form_data.password)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    # 토큰 생성
    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    access_token = create_access_token(data={"sub": user["username"]}, expires_delta=access_token_expires)
    return {"access_token": access_token, "token_type": "bearer"}


# 현재 사용자 정보 가져오기 함수
async def get_current_user(token: str = Depends(oauth2_scheme)):
    """
    요청에서 JWT 토큰을 추출하고, 이를 검증하여 현재 사용자를 반환.
    - `token`: `Depends(oauth2_scheme)`로 Authorization 헤더에서 Bearer 토큰을 추출.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        # 토큰 디코딩 및 검증
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        username: str = payload.get("sub")  # 토큰에서 사용자 이름(sub) 추출
        if username is None:
            raise credentials_exception
    except JWTError:
        # JWT 검증 실패 시 예외 발생
        raise credentials_exception
    # 사용자 정보 검색
    user = get_user(fake_users_db, username)
    if user is None:
        raise credentials_exception
    return user  # 인증된 사용자 반환


# 보호된 엔드포인트: 현재 사용자 정보 반환
@app.get("/users/me")
async def read_users_me(current_user: dict = Depends(get_current_user)):
    """
    인증된 사용자만 접근할 수 있는 엔드포인트.
    - `current_user`: `Depends(get_current_user)`로 현재 사용자 정보를 추출.
    - 반환 값: 인증된 사용자의 정보를 JSON으로 반환.
    """
    return {
        "username": current_user["username"],
        "full_name": current_user["full_name"],
        "email": current_user["email"],
    }

 

         - 강의 코드 :

def authenticate_user(username: str, password: str, db):
    user = db.query(Users).filter(Users.username == username).first()
    if not user:
        return False
    if not bcrypt_context.verify(password, user.hashed_password):
        return False
    return user
    
    
@router.post("/token", response_model=Token)
async def login_for_access_token(form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
                                 db: db_dependency):
    user = authenticate_user(form_data.username, form_data.password, db)
    if not user:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
                            detail='Could not validate user.')
    token = create_access_token(user.username, user.id, user.role, timedelta(minutes=20))

    return {'access_token': token, 'token_type': 'bearer'}

 

      - oauth2.py의 OAuth2PasswordRequestForm 클래스 주석 번역:

"""
이것은 OAuth2 패스워드 플로우에서 폼 데이터로 `username`과 `password`를 수집하기 위한 
의존성 클래스입니다.

OAuth2 명세는 패스워드 플로우에서 데이터가 (JSON 대신) 폼 데이터를 사용하여 수집되어야 하며,
`username`과 `password`라는 특정 필드를 가져야 한다고 규정합니다.

모든 초기화 매개변수는 요청에서 추출됩니다.

자세한 내용은 [FastAPI 문서의 패스워드와 Bearer를 사용한 간단한 OAuth2]
(https://fastapi.tiangolo.com/tutorial/security/simple-oauth2/)에서 확인하세요.
"""

 

      - 점프 투 FastAPI 문서의 코드

         - 참고 : https://wikidocs.net/176934

         - 라이브러리 설치

pip install python-multipart
pip install "python-jose[cryptography]"
# 로그인 스키마
class Token(BaseModel):
    access_token: str
    token_type: str
    username: str
    
    
# 로그인 CRUD
def get_user(db: Session, username: str):
    return db.query(User).filter(User.username == username).first()
    
    
# 로그인 라우터
from datetime import timedelta, datetime

from fastapi import APIRouter, HTTPException
from fastapi import Depends
from fastapi.security import OAuth2PasswordRequestForm
from jose import jwt
from sqlalchemy.orm import Session
from starlette import status

from database import get_db
from domain.user import user_crud, user_schema
from domain.user.user_crud import pwd_context

ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24
SECRET_KEY = "4ab2fce7a6bd79e1c014396315ed322dd6edb1c5d975c6b74a2904135172c03c"
ALGORITHM = "HS256"

(... 생략 ...)

@router.post("/login", response_model=user_schema.Token)
def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(),
                           db: Session = Depends(get_db)):

    # check user and password
    user = user_crud.get_user(db, form_data.username)
    if not user or not pwd_context.verify(form_data.password, user.password):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # make access token
    data = {
        "sub": user.username,
        "exp": datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    }
    access_token = jwt.encode(data, SECRET_KEY, algorithm=ALGORITHM)

    return {
        "access_token": access_token,
        "token_type": "bearer",
        "username": user.username
    }

 

   3) OAuth2PasswordBearer을 사용해 인증하기

      - 설명:
         - OAuth2PasswordBearer는 Bearer 토큰을 사용하는 OAuth2 인증 스킴을 구현하기 위해 사용된다. 주로 API의 인증을 처리할 때 사용되며, 클라이언트가 Authorization 헤더에 Bearer 토큰을 포함하여 요청을 보낼 것을 기대한다.

      - 초기화 매개변수:
         - tokenUrl: 토큰을 발급받는 엔드포인트의 URL 경로를 지정한다.

 

      - 동작 방식:
         - 클라이언트가 Authorization: Bearer <token> 헤더를 통해 요청을 보낸다.
         - OAuth2PasswordBearer는 헤더에서 토큰을 추출하고, 의존성으로 이를 처리한다.
      - 사용 예: OAuth2PasswordBearer를 사용하여 요청에서 토큰을 추출하고 이를 인증에 활용한다.

from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import OAuth2PasswordBearer

app = FastAPI()

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

@app.get("/users/me")
async def read_users_me(token: str = Depends(oauth2_scheme)):
    # 토큰을 검증하고 사용자 정보를 가져오는 로직 구현
    if token != "valid_token":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return {"user_id": "user123", "token": token}

 

   4) Refresh Token의 필요성

      1. 네트워크 탈취 vs. 로컬 컴퓨터 탈취

         - Refresh Token은 클라이언트와 서버 간 통신에서 탈취될 가능성을 줄이고, 로컬에서 안전하게 보관될 수 있도록 설계된다. 다음과 같은 이유로 네트워크 탈취를 더 큰 위협으로 간주한다:


         1) 네트워크 통신상의 탈취 위협
            - Refresh Token은 서버와 클라이언트 간에 전송될 때 탈취당할 위험이 있다.
            - 예시:
               - Man-in-the-Middle (MITM) 공격: 네트워크에서 공격자가 Refresh Token을 가로챌 수 있음
               - 네트워크 스니핑: 안전하지 않은 HTTP 요청을 통해 토큰이 노출될 수 있음

         2) 로컬 컴퓨터 탈취는 저장 방식에 따라 보호 가능
            - Refresh Token은 클라이언트 측에서 안전한 저장소에 보관된다:
               - 웹 브라우저: Secure, HttpOnly Cookie 사용
               - 모바일 앱: OS의 Keychain(Android: EncryptedSharedPreferences, iOS: Keychain) 사용

 

      2. Refresh Token 사용 시 보안 강화 방법

         - Refresh Token은 Access Token보다 더 높은 보안 수준으로 보호된다. 이를 위해 다음과 같은 방법들이 사용된다:
         1) Secure Storage
            - 클라이언트 측 보안: Refresh Token은 브라우저의 Secure Cookie나 모바일 앱의 Secure Storage에 저장한다.
            - 이렇게 하면 Refresh Token이 스크립트나 브라우저 개발자 도구를 통해 노출되지 않는다.

         2) Refresh Token Rotation (토큰 교체 방식)
            - 매번 새로운 Access Token과 Refresh Token을 발급할 때 기존 Refresh Token을 무효화한다.
            - 탈취된 Refresh Token이 공격자에 의해 사용되더라도, 원래 사용자가 새 토큰을 발급받으면 공격자가 가진 토큰은 더 이상 쓸 수 없다.

         3) IP 및 디바이스 검증
            - Refresh Token을 사용할 때, 요청이 동일한 IP 주소 또는 디바이스에서 오는지 확인한다.
            - 다른 디바이스나 의심스러운 IP에서 요청이 오면 차단하거나 추가 인증(예: 2단계 인증)을 요구한다.

         4) 짧은 Refresh Token 만료 + 사용자 재인증
            - Refresh Token도 너무 긴 만료 시간을 주지 않고, 일정 시간이 지나면 사용자가 다시 로그인하도록 요구한다.
            - 예: Refresh Token 만료 기간을 30일로 설정하고, 30일 후에는 재인증 필요

         5) 로그아웃 및 세션 종료
            - 서버에서 Refresh Token을 관리하면 특정 사용자가 로그아웃하거나 세션 종료 시 Refresh Token을 강제로 무효화할 수 있다.

 

   5) 강의 내 JWT 토큰을 이용한 발행 및 검증을 개선한 코드

      - 라이브러리 설치

pip install python-multipart
pip install "python-jose[cryptography]"

 

      - SECRET_KEY 생성

$ openssl rand -hex 32
4ab2fce7a6bd79e1c014396315ed322dd6edb1c5d975c6b74a2904135172c03c
>>> import secrets
>>> secrets.token_hex(32)
'344a451d26d1968c0cd4ca12613e5f61b0f71dafced442c730edba55bb9035bc'

 

      - 개선된 코드

from datetime import timedelta, datetime, timezone
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from sqlalchemy.orm import Session
from starlette import status
from database import SessionLocal
from models import Users
from passlib.context import CryptContext
from fastapi.security import OAuth2PasswordRequestForm, OAuth2PasswordBearer
from jose import jwt, JWTError

router = APIRouter(
    prefix='/auth',
    tags=['auth']
)

# JWT를 위한 비밀 키와 알고리즘 정의
SECRET_KEY = '197b2c37c391bed93fe80344fe73b806947a65e36206e05a1a23c2fa12702fe3'
ALGORITHM = 'HS256'

# 비밀번호 해싱 및 검증을 위한 bcrypt 설정
bcrypt_context = CryptContext(schemes=['bcrypt'], deprecated='auto')

# 토큰 인증 관리를 위한 OAuth2PasswordBearer 인스턴스 생성
oauth2_bearer = OAuth2PasswordBearer(tokenUrl='auth/token')

# 사용자 생성 요청을 위한 Pydantic 모델
class CreateUserRequest(BaseModel):
    username: str
    email: str
    first_name: str
    last_name: str
    password: str
    role: str
    phone_number: str

# 토큰 응답을 위한 Pydantic 모델
class Token(BaseModel):
    access_token: str
    token_type: str

# 블랙리스트 관리를 위한 임시 저장소
# 실제 운영 환경에서는 Redis와 같은 외부 저장소를 사용하는 것이 좋습니다.
TOKEN_BLACKLIST = set()

# 데이터베이스 세션 관리를 위한 의존성

def get_db():
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()

# 의존성 주입을 위한 Annotated 타입
db_dependency = Annotated[Session, Depends(get_db)]

# 사용자 이름과 비밀번호로 사용자 인증
def authenticate_user(username: str, password: str, db):
    user = db.query(Users).filter(Users.username == username).first()
    if not user:
        return False
    if not bcrypt_context.verify(password, user.hashed_password):
        return False
    return user

# JWT 액세스 토큰 생성 함수
def create_access_token(username: str, user_id: int, role: str, expires_delta: timedelta):
    # 사용자 정보와 만료 시간을 포함한 페이로드 정의
    encode = {'sub': username, 'id': user_id, 'role': role}
    expires = datetime.now(timezone.utc) + expires_delta
    encode.update({'exp': expires})
    # 비밀 키와 알고리즘을 사용하여 페이로드를 인코딩
    return jwt.encode(encode, SECRET_KEY, algorithm=ALGORITHM)

# 토큰에서 사용자 정보를 추출하고 검증하는 의존성
async def get_current_user(token: Annotated[str, Depends(oauth2_bearer)]):
    if token in TOKEN_BLACKLIST:
        # 블랙리스트에 포함된 토큰은 무효화 처리
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail='토큰이 무효화되었습니다.')
    try:
        # JWT 토큰을 디코딩하여 사용자 정보 추출
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        username: str = payload.get('sub')
        user_id: int = payload.get('id')
        user_role: str = payload.get('role')
        if username is None or user_id is None:
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
                                detail='사용자를 인증할 수 없습니다.')
        return {'username': username, 'id': user_id, 'user_role': user_role}
    except JWTError:
        # 토큰 디코딩 오류 처리
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
                            detail='사용자를 인증할 수 없습니다.')

# 새로운 사용자 생성 엔드포인트
@router.post("/", status_code=status.HTTP_201_CREATED)
async def create_user(db: db_dependency,
                      create_user_request: CreateUserRequest):
    # 사용자의 비밀번호를 해싱하고 새 사용자 인스턴스 생성
    create_user_model = Users(
        email=create_user_request.email,
        username=create_user_request.username,
        first_name=create_user_request.first_name,
        last_name=create_user_request.last_name,
        role=create_user_request.role,
        hashed_password=bcrypt_context.hash(create_user_request.password),
        is_active=True,
        phone_number=create_user_request.phone_number
    )

    # 사용자를 데이터베이스에 추가하고 트랜잭션 커밋
    db.add(create_user_model)
    db.commit()

# 인증된 사용자에게 JWT 토큰을 발급하는 엔드포인트
@router.post("/token", response_model=Token)
async def login_for_access_token(form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
                                 db: db_dependency):
    # 제공된 자격 증명으로 사용자 인증
    user = authenticate_user(form_data.username, form_data.password, db)
    if not user:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
                            detail='사용자를 인증할 수 없습니다.')
    # 20분 동안 유효한 JWT 토큰 생성
    token = create_access_token(user.username, user.id, user.role, timedelta(minutes=20))

    # 응답에 토큰 반환
    return {'access_token': token, 'token_type': 'bearer'}

# 토큰 폐기 또는 블랙리스트 추가 엔드포인트
@router.post("/blacklist")
async def logout_user(token: Annotated[str, Depends(oauth2_bearer)]):
    # 블랙리스트에 토큰 추가
    TOKEN_BLACKLIST.add(token)
    return {"detail": "토큰이 성공적으로 무효화되었습니다."}

# 개선 사항:
# 1. SECRET_KEY는 보안을 위해 환경 변수에서 로드해야 합니다.
# 2. 블랙리스트 저장소로 Redis와 같은 외부 캐시 시스템을 사용하는 것이 권장됩니다.
# 3. 인증 및 사용자 생성 작업에 대한 로깅을 추가하여 모니터링을 강화합니다.
# 4. 새 사용자 등록 시 이메일 인증을 추가하여 보안을 강화합니다.

 

2. Large Production Database Setup

   1) PostgreSQL Connect to FastAPI

      - 라이브러리 설치

pip install psycopg2-binary

 

      - 기본 설정

         - postgresql 설정시

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base

# SQLALCHEMY_DATABASE_URL = 'sqlite:///./todosapp.db'
SQLALCHEMY_DATABASE_URL = "postgresql://fastapi:test1324@localhost:5433/fastapi_cc"

engine = create_engine(SQLALCHEMY_DATABASE_URL)

SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

Base = declarative_base()

 

         - mysql 설정시

DATABASE_CONN = "mysql://root:root1234@localhost:3306/blog_db"

 

         - mysql 비동기 설정시

DATABASE_CONN = "mysql+aiomysql://root:root1234@localhost:3306/blog_db"

 

   2) ORM 사용 방법 - 역참조 방법

      - 예제: 사용자와 게시물 모델

         - 사용자(User)와 게시물(Post)의 관계에서 한 사용자가 여러 게시물을 작성할 수 있다고 가정한다. 이 경우, 역참조를 설정하면 사용자는 자신의 게시물 목록을 조회할 수 있고, 게시물은 작성자를 참조할 수 있다.

         - back_populates: 양방향 관계를 명시적으로 설정한다.

 

         - 데이터베이스 모델 정의 :

from sqlalchemy import Column, Integer, String, ForeignKey
from sqlalchemy.orm import relationship
from sqlalchemy.ext.declarative import declarative_base

Base = declarative_base()

class User(Base):
    __tablename__ = "users"
    id = Column(Integer, primary_key=True, index=True)
    name = Column(String, index=True)

    # 역참조: Post 모델과의 관계
    posts = relationship("Post", back_populates="owner")


class Post(Base):
    __tablename__ = "posts"
    id = Column(Integer, primary_key=True, index=True)
    title = Column(String, index=True)
    content = Column(String)
    owner_id = Column(Integer, ForeignKey("users.id"))

    # 역참조: User 모델과의 관계
    owner = relationship("User", back_populates="posts")

 

         - 데이터베이스 사용 예제 :

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

# 데이터베이스 연결
DATABASE_URL = "sqlite:///./test.db"
engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# 데이터베이스 테이블 생성
Base.metadata.create_all(bind=engine)

# 세션 생성
db = SessionLocal()

# 데이터 추가
new_user = User(name="John Doe")
db.add(new_user)
db.commit()
db.refresh(new_user)

new_post = Post(title="My First Post", content="Hello, world!", owner_id=new_user.id)
db.add(new_post)
db.commit()
db.refresh(new_post)

# 역참조로 데이터 접근
user = db.query(User).filter(User.id == new_user.id).first()
print(user.posts)  # User가 작성한 Post 리스트 출력

post = db.query(Post).filter(Post.id == new_post.id).first()
print(post.owner)  # Post의 작성자(User) 출력

 

   3) ORM 사용 방법 - M:N 관계를 secondary를 이용해 구현하기

      - M:N 관계는 두 테이블 간의 다대다 관계를 의미한다.

         - 이 관계를 처리하려면 보통 연결 테이블(association table)을 사용해야 한다. 이를 위해 SQLAlchemy에서는 별도의 테이블을 만들어 중간 역할을 하도록 설정한다.

         - secondary를 통해 중간 테이블을 지정한다.

      - 다음은 두 테이블 Student와 Course 사이의 M:N 관계를 설정한 예시이다. 중간 테이블로 student_course를 만들어 다대다 관계를 해결한다.

 

      1. 테이블 정의 : 

from sqlalchemy import Column, Integer, String, ForeignKey
from sqlalchemy.orm import relationship
from sqlalchemy.ext.declarative import declarative_base

Base = declarative_base()

# 중간 테이블을 위한 모델 정의
class StudentCourse(Base):
    __tablename__ = 'student_course'
    student_id = Column(Integer, ForeignKey('students.id'), primary_key=True)
    course_id = Column(Integer, ForeignKey('courses.id'), primary_key=True)

# Student 테이블
class Student(Base):
    __tablename__ = 'students'
    
    id = Column(Integer, primary_key=True, index=True)
    name = Column(String, index=True)
    
    # 다대다 관계 설정 (연결 테이블을 통해 Course와 관계 설정)
    courses = relationship('Course', secondary='student_course', back_populates='students')

# Course 테이블
class Course(Base):
    __tablename__ = 'courses'
    
    id = Column(Integer, primary_key=True, index=True)
    title = Column(String, index=True)
    
    # 다대다 관계 설정 (연결 테이블을 통해 Student와 관계 설정)
    students = relationship('Student', secondary='student_course', back_populates='courses')

 

      2. 학생이 수강하는 강좌 조회 :

# 학생의 수강 과목을 조회
student = session.query(Student).filter(Student.id == 1).first()
print(student.name)
for course in student.courses:
    print(course.title)

 

      3. 강좌를 수강하는 학생들 조회 :

# 강좌에 등록된 학생들을 조회
course = session.query(Course).filter(Course.id == 1).first()
print(course.title)
for student in course.students:
    print(student.name)

 

   4) ORM 사용 방법 - joinedload(django의 selected_prefetch)과 subqueryload(django의 related_prefetch)

      1. joinedload

         - joinedload는 SQL의 JOIN을 사용하여 관련된 데이터를 한 번의 쿼리로 로드한다.
         - 기본적으로, 두 테이블 간의 관계를 INNER JOIN 또는 LEFT OUTER JOIN을 통해 결합하여 데이터를 조회한다.
         - 여러 엔티티가 연결되어 있을 때 성능상 이점이 있을 수 있지만, 테이블이 매우 크거나 복잡한 관계에서는 불필요한 데이터 중복이 발생할 수 있다.

from sqlalchemy.orm import joinedload

query = session.query(Parent).options(joinedload(Parent.children))
results = query.all()

 

      2. subqueryload

         - subqueryload는 서브쿼리를 사용하여 관련된 데이터를 별도의 쿼리로 로드한다.
         - JOIN을 사용하지 않고, 관련된 데이터를 로드할 때 서브쿼리를 생성하여 이를 통해 자식 엔티티를 조회한다.
         - subqueryload는 복잡한 쿼리를 발생시킬 수 있지만, JOIN에서 발생할 수 있는 중복 데이터를 방지할 수 있다.

from sqlalchemy.orm import subqueryload

query = session.query(Parent).options(subqueryload(Parent.children))
results = query.all()

 

      3. 차이점

         - joinedload: 관계된 엔티티들을 하나의 쿼리에서 JOIN으로 가져옵니다. 이 방법은 데이터 중복이 발생할 수 있지만, 더 빠른 쿼리 실행 속도를 제공할 수 있다.
         - subqueryload: 관련된 데이터를 서브쿼리로 로드하여 중복을 방지하고, 복잡한 데이터가 많을 때 더 효율적으로 처리할 수 있다.

항목 joinedload subqueryload
쿼리 수 1개 여러 개 (단계별로 수행)
성능 작은 데이터셋에서 효율적 큰 데이터셋에서 중복을 줄이는 데 유리
데이터 중복 중복 가능 중복 없음

 

      4. 쇼핑몰 데이터베이스 구조에서의 joinedload와 subqueryload의 사용 예시

         - 테이블 관계
            - User (사용자 테이블)
               - 사용자의 기본 정보를 저장합니다.
            - Order (주문 테이블)
               - 각 사용자가 생성한 주문을 저장합니다.
            - Product (상품 테이블)
               - 각 주문에서 포함된 상품을 저장합니다.


         - 테이블 정의

from sqlalchemy import Column, Integer, String, ForeignKey, create_engine
from sqlalchemy.orm import relationship, sessionmaker, declarative_base, joinedload, subqueryload

Base = declarative_base()

# 사용자 테이블
class User(Base):
    __tablename__ = 'users'
    id = Column(Integer, primary_key=True)
    name = Column(String)
    orders = relationship('Order', back_populates='user')  # User와 Order의 관계 설정

# 주문 테이블
class Order(Base):
    __tablename__ = 'orders'
    id = Column(Integer, primary_key=True)
    user_id = Column(Integer, ForeignKey('users.id'))
    user = relationship('User', back_populates='orders')  # Order와 User의 관계 설정
    products = relationship('Product', back_populates='order')  # Order와 Product의 관계 설정

# 상품 테이블
class Product(Base):
    __tablename__ = 'products'
    id = Column(Integer, primary_key=True)
    order_id = Column(Integer, ForeignKey('orders.id'))
    name = Column(String)
    price = Column(Integer)
    order = relationship('Order', back_populates='products')  # Product와 Order의 관계 설정

# DB 및 세션 초기화
engine = create_engine('sqlite:///:memory:')
Base.metadata.create_all(engine)
Session = sessionmaker(bind=engine)
session = Session()

 

         - 데이터 삽입

# 샘플 데이터 삽입
user1 = User(name="Alice")
user2 = User(name="Bob")

order1 = Order(user=user1)
order2 = Order(user=user2)

product1 = Product(name="Laptop", price=1000, order=order1)
product2 = Product(name="Mouse", price=50, order=order1)
product3 = Product(name="Keyboard", price=100, order=order2)

session.add_all([user1, user2, order1, order2, product1, product2, product3])
session.commit()

 

         - joinedload 예제

            1. 사용자와 주문, 상품을 joinedload로 한 번에 가져오기

# joinedload를 사용하여 User, Order, Product 데이터를 한 번의 쿼리로 가져오기
users = session.query(User).options(joinedload(User.orders).joinedload(Order.products)).all()

# 데이터 출력
for user in users:
    print(f"User: {user.name}")
    for order in user.orders:
        print(f"  Order ID: {order.id}")
        for product in order.products:
            print(f"    Product: {product.name}, Price: {product.price}")
SELECT users.id AS users_id, users.name AS users_name, 
       orders.id AS orders_id, orders.user_id AS orders_user_id, 
       products.id AS products_id, products.order_id AS products_order_id, products.name AS products_name, products.price AS products_price
FROM users
LEFT OUTER JOIN orders ON users.id = orders.user_id
LEFT OUTER JOIN products ON orders.id = products.order_id
User: Alice
  Order ID: 1
    Product: Laptop, Price: 1000
    Product: Mouse, Price: 50
User: Bob
  Order ID: 2
    Product: Keyboard, Price: 100

 

            2. 특정 사용자(Bob)의 모든 상품 가져오기

# 특정 사용자의 상품 데이터를 가져오기
bob = session.query(User).options(joinedload(User.orders).joinedload(Order.products)).filter(User.name == "Bob").first()

print(f"User: {bob.name}")
for order in bob.orders:
    for product in order.products:
        print(f"Product: {product.name}, Price: {product.price}")
SELECT users.id AS users_id, users.name AS users_name, 
       orders.id AS orders_id, orders.user_id AS orders_user_id, 
       products.id AS products_id, products.order_id AS products_order_id, products.name AS products_name, products.price AS products_price
FROM users
LEFT OUTER JOIN orders ON users.id = orders.user_id
LEFT OUTER JOIN products ON orders.id = products.order_id
WHERE users.name = 'Bob'
User: Bob
Product: Keyboard, Price: 100

 

         - subqueryload 예제

            1. 사용자와 관련된 주문 및 상품을 subqueryload로 가져오기

# subqueryload를 사용하여 데이터를 가져오기
users = session.query(User).options(subqueryload(User.orders).subqueryload(Order.products)).all()

# 데이터 출력
for user in users:
    print(f"User: {user.name}")
    for order in user.orders:
        print(f"  Order ID: {order.id}")
        for product in order.products:
            print(f"    Product: {product.name}, Price: {product.price}")

 

            - 첫 번째 쿼리 (사용자 조회)

SELECT users.id AS users_id, users.name AS users_name 
FROM users

 

            - 두 번째 쿼리 (주문 조회)

SELECT orders.id AS orders_id, orders.user_id AS orders_user_id 
FROM orders 
WHERE orders.user_id IN (1, 2)

 

            - 세 번째 쿼리 (상품 조회)

SELECT products.id AS products_id, products.order_id AS products_order_id, products.name AS products_name, products.price AS products_price
FROM products 
WHERE products.order_id IN (1, 2)
User: Alice
  Order ID: 1
    Product: Laptop, Price: 1000
    Product: Mouse, Price: 50
User: Bob
  Order ID: 2
    Product: Keyboard, Price: 100

 

            2. 특정 주문의 상품 데이터 가져오기

# 특정 주문의 상품 데이터 가져오기
order = session.query(Order).options(subqueryload(Order.products)).filter(Order.id == 1).first()

print(f"Order ID: {order.id}")
for product in order.products:
    print(f"Product: {product.name}, Price: {product.price}")

 

            - 첫 번째 쿼리 (주문 조회)

SELECT orders.id AS orders_id, orders.user_id AS orders_user_id 
FROM orders 
WHERE orders.id = 1

 

            - 두 번째 쿼리 (상품 조회)

SELECT products.id AS products_id, products.order_id AS products_order_id, products.name AS products_name, products.price AS products_price
FROM products 
WHERE products.order_id = 1
Order ID: 1
Product: Laptop, Price: 1000
Product: Mouse, Price: 50

 

3. Alembic Data Migration

   1) 설치

pip install alembic

 

   2) 명령어 정리

명령어 설명
alembic init <directory> 지정된 디렉토리에 새로운 Alembic 환경을 초기화합니다. 필요한 파일과 디렉토리를 생성합니다.
alembic revision -m "<message>" 지정된 메시지를 가진 새로운 마이그레이션 스크립트를 생성합니다. 모델에서 감지된 변경 사항을 포함합니다.
alembic upgrade <revision> 데이터베이스 스키마를 지정된 버전으로 업그레이드합니다. 버전이 지정되지 않으면 최신 버전으로 업그레이드합니다.
alembic downgrade <revision> 데이터베이스 스키마를 지정된 버전으로 다운그레이드합니다. 버전 번호나 상대적인 변경 사항 (예: -1로 이전 버전) 사용 가능합니다.
alembic history 모든 마이그레이션의 리스트를 표시하고, 각 변경 사항과 버전 식별자를 보여줍니다.
alembic current 현재 데이터베이스 스키마의 버전 (또는 revision)을 표시합니다.
alembic show <revision> 특정 버전의 마이그레이션 파일에 대한 자세한 변경 사항을 표시합니다.
alembic stamp <revision> 마이그레이션 스크립트를 실행하지 않고 지정된 버전으로 데이터베이스를 마킹합니다.
alembic merge -m "<message>" <revision1> <revision2> 두 개 이상의 마이그레이션을 하나로 합칩니다. 마이그레이션 히스토리에서 분기된 변경 사항을 결합할 때 사용됩니다.
alembic upgrade head 마이그레이션 히스토리에서 최신 버전으로 업그레이드합니다.
alembic downgrade base 데이터베이스를 초기 상태 (기본 버전)로 다운그레이드합니다.
alembic revision --autogenerate -m "<message>" 모델의 변경 사항을 자동으로 감지하여 마이그레이션 스크립트를 생성합니다.

 

   3) 초기화

      - Alembic 프로젝트를 설정하려면 init 명령어를 사용한다.

alembic init alembic

 

      - 위 명령은 다음 디렉토리 구조를 생성한다.

         - env.py: Alembic 환경 설정 파일로, 데이터베이스 연결과 마이그레이션 로직이 정의된다.
         - versions/: 마이그레이션 파일이 저장되는 디렉토리입이다.
         - alembic.ini: Alembic 설정 파일이다.

alembic/
  env.py
  README
  script.py.mako
  versions/
alembic.ini

 

   3) 데이터베이스 연결 설정

      - alembic.ini 파일에서 데이터베이스 URL을 설정

[alembic]
sqlalchemy.url = sqlite:///./test.db
from logging.config import fileConfig
from sqlalchemy import engine_from_config, pool
from alembic import context
from myapp.models import Base  # SQLAlchemy 모델의 Base를 가져옴

# Alembic 설정 파일
config = context.config

# 로그 설정
fileConfig(config.config_file_name)

# SQLAlchemy 모델의 메타데이터 등록
target_metadata = Base.metadata

def run_migrations_offline():
    """오프라인 모드에서 마이그레이션 실행"""
    url = config.get_main_option("sqlalchemy.url")
    context.configure(url=url, target_metadata=target_metadata, literal_binds=True)
    with context.begin_transaction():
        context.run_migrations()

def run_migrations_online():
    """온라인 모드에서 마이그레이션 실행"""
    connectable = engine_from_config(
        config.get_section(config.config_ini_section),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )
    with connectable.connect() as connection:
        context.configure(connection=connection, target_metadata=target_metadata)
        with context.begin_transaction():
            context.run_migrations()

if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()

 

   4) Alembic 명령어

      1. 마이그레이션 스크립트 생성

         - 자동 생성

            - versions/ 디렉토리에 새로운 파일이 생성된다.
            - target_metadata에서 정의한 모델을 기준으로 변경 사항을 감지한다.

alembic revision --autogenerate -m "add user table"

 

         - 자동 생성된 스크립트

"""add user table

Revision ID: abc123
Revises: 
Create Date: 2025-01-10 10:00:00

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = 'abc123'
down_revision = None
branch_labels = None
depends_on = None


def upgrade():
    op.create_table(
        'user',
        sa.Column('id', sa.Integer(), primary_key=True),
        sa.Column('name', sa.String(), nullable=False),
        sa.Column('email', sa.String(), unique=True, nullable=False),
    )


def downgrade():
    op.drop_table('user')

 

         - 수동 생성

alembic revision -m "custom migration"

 

         - 생성된 파일에 코드를 작성

"""custom migration

Revision ID: def456
Revises: abc123
Create Date: 2025-01-10 10:15:00

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = 'def456'
down_revision = 'abc123'
branch_labels = None
depends_on = None


def upgrade():
    # 기존 테이블에 컬럼 추가
    op.add_column('user', sa.Column('age', sa.Integer(), nullable=True))


def downgrade():
    # 추가된 컬럼 삭제
    op.drop_column('user', 'age')

 

      2. 마이그레이션 적용

         - 마이그레이션 업그레이드

            - head는 가장 최신 리비전으로 이동

alembic upgrade head

 

             - 특정 리비전으로 업그레이드

alembic upgrade abc123

 

         - 마이그레이션 다운그레이드

            - 이전 리비전으로 이동

alembic downgrade -1

 

         - 특정 리비전으로 다운그레이드

alembic downgrade base

 

      3. 마이그레이션 예시

         - 이 마이그레이션 스크립트는 데이터베이스의 'user' 테이블에 'status'라는 새로운 컬럼을 추가하고, 기존 사용자들의 상태를 'active'로 설정하는 용도로 사용

 

            - upgrade 함수: 데이터베이스 스키마를 업그레이드하는 로직을 포함합니다.
               - op.add_column: 'user' 테이블에 'status'라는 새로운 문자열 컬럼을 추가한다. nullable=True는 이 컬럼이 NULL 값을 허용함을 의미한다.
               - op.execute: SQL 쿼리를 실행하여 기존의 모든 사용자 행에 대해 'status' 컬럼을 'active'로 업데이트한다.


            - downgrade 함수: 업그레이드를 되돌리는 로직을 포함한다.

               - op.drop_column: 'user' 테이블에서 'status' 컬럼을 삭제

"""Add status column to user table

Revision ID: <your_revision_id>
Revises: <previous_revision_id>
Create Date: <current_date>
"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = '<your_revision_id>'
down_revision = '<previous_revision_id>'
branch_labels = None
depends_on = None

def upgrade():
    """Apply the upgrade to the database schema."""
    # 'user' 테이블에 'status'라는 새로운 문자열 컬럼을 추가합니다.
    op.add_column('user', sa.Column('status', sa.String(), nullable=True))
    
    # 기존 'user' 테이블의 모든 행에 대해 'status' 컬럼을 'active'로 업데이트합니다.
    op.execute("UPDATE user SET status = 'active'")

def downgrade():
    """Revert the changes made in the upgrade."""
    # 'user' 테이블에서 'status' 컬럼을 삭제합니다.
    op.drop_column('user', 'status')

 

         - 이 마이그레이션 스크립트는 데이터베이스에 order 테이블을 생성하고, 기존 사용자들의 ID를 해당 테이블에 삽입하는 데 사용

 

         - upgrade 함수: 데이터베이스 스키마를 업그레이드하는 로직을 포함한다.
            - op.create_table: 'order'라는 새로운 테이블을 생성한다.
            - sa.Column('id', sa.Integer(), primary_key=True): 기본 키로 사용되는 정수형 'id' 컬럼을 정의한다.
            - sa.Column('user_id', sa.Integer(), sa.ForeignKey('user.id')): 'user' 테이블의 'id'를 참조하는 외래 키 'user_id' 컬럼을 정의한다.
            - op.execute: SQL 쿼리를 실행하여 기존의 모든 사용자 ID를 'order' 테이블에 삽입한다.

 

         -  downgrade 함수: 업그레이드를 되돌리는 로직을 포함한다.
            - op.drop_table: 'order' 테이블을 삭제하여 업그레이드를 되돌린다.

"""Create order table

Revision ID: <your_revision_id>
Revises: <previous_revision_id>
Create Date: <current_date>
"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = '<your_revision_id>'
down_revision = '<previous_revision_id>'
branch_labels = None
depends_on = None

def upgrade():
    """Apply the upgrade to the database schema."""
    # 'order'라는 새로운 테이블을 생성합니다.
    op.create_table(
        'order',
        # 'id' 컬럼: 기본 키로 사용되는 정수형 컬럼
        sa.Column('id', sa.Integer(), primary_key=True),
        # 'user_id' 컬럼: 'user' 테이블의 'id'를 참조하는 외래 키
        sa.Column('user_id', sa.Integer(), sa.ForeignKey('user.id')),
    )
    
    # 기존 'user' 테이블의 모든 사용자 ID를 'order' 테이블에 삽입합니다.
    op.execute("INSERT INTO order (user_id) SELECT id FROM user")

def downgrade():
    """Revert the changes made in the upgrade."""
    # 'order' 테이블을 삭제하여 업그레이드를 되돌립니다.
    op.drop_table('order')

댓글