Django REST Framework/DRF 일반

[DRF] TestCase로 Django-Rest-Framework를 위한 테스트 코드 만들기

bluebamus 2024. 1. 23. 22:47

[udemy] Build a Backend REST API with Python & Django - Advanced 학습 정리

 

이전에 학습한 강좌 중 TestCase의 코드 내용이 괜찮았기에 따로 분리하여 포스팅을 한다.

 

 1. 모든 테스트 케이스에 공통적으로 사용되는 코드는 전역변수 혹은 전역 함수로 만들어 사용하자.

CREATE_USER_URL = reverse('user:create')
TOKEN_URL = reverse('user:token')
ME_URL = reverse('user:me')


def create_user(**params):
    """Create and return a new user."""
    return get_user_model().objects.create_user(**params)

 

   - 위 코드를 사용하게 되면, 현재 개발되고 있는 프로젝트의 계획 변경으로 삭제, 수정 등으로 변경이 필요할 경우 쉽게 변경하여 적용할 수 있다.

   - 특히 함수의 경우 다양한 목적이 아닌, 하나의 동작을 위한 목적으로 만들어야 하며, 해당 동작에 변경이 있을 경우 쉽게 적용할 수 있다.

 

 2. setup(self)

   - test 클래스에는 setup과 teardown을 설정할 수 있다.

      - setup은 클래스에 포함된 모든 테스트케이스가 시작에 필요한 데이터 혹은 client, 인증 상태 등이 미리 준비된다.

    def setUp(self):
        self.user = create_user(
            email='test@example.com',
            password='testpass123',
            name='Test Name',
        )
        self.client = APIClient()
        self.client.force_authenticate(user=self.user)

 

      - self에 저장된 객체 변수를 이용해 모든 테스트들은 필요한 데이터 혹은 client를 호출하여 사용할 수 있다.

    def test_retrieve_profile_success(self):
        """Test retrieving profile for logged in user."""
        res = self.client.get(ME_URL)

        self.assertEqual(res.status_code, status.HTTP_200_OK)
        self.assertEqual(res.data, {
            'name': self.user.name,
            'email': self.user.email,
        })

 

 3. command 테스트

   - patch decorator로 check 함수에 대한 hook을 연결한다.

   - patched_check은 해당 함수의 객체를 return 하고, return_value는 return 값을 수동으로 설정한다.

      1. assert_called_once_with(*args, **kwargs): 이 메서드는 mock 객체가 정확히 한 번 호출되었으며, 호출 시 주어진 인자들로 호출되었음을 확인한다. 만약 mock 객체가 한 번 이상 호출되었거나, 호출 시 다른 인자들로 호출되었다면, 이 메서드는 AssertionError를 발생시킨다.

      2. assert_called_with(*args, **kwargs): 이 메서드는 mock 객체가 마지막으로 호출될 때, 주어진 인자들로 호출되었음을 확인한다. 호출 횟수는 확인하지 않는다. 만약 mock 객체가 마지막으로 호출될 때 다른 인자들로 호출되었다면, 이 메서드는 AssertionError를 발생시킨다.

   - assert_called_once_with는 호출 횟수까지 확인하고, assert_called_with는 마지막 호출의 인자만을 확인한다.

"""
Test custom Django management commands.
"""
from unittest.mock import patch

from psycopg2 import OperationalError as Psycopg2OpError

from django.core.management import call_command
from django.db.utils import OperationalError
from django.test import SimpleTestCase


@patch('core.management.commands.wait_for_db.Command.check')
class CommandTests(SimpleTestCase):
    """Test commands."""

    def test_wait_for_db_ready(self, patched_check):
        """Test waiting for database if database ready."""
        patched_check.return_value = True

        call_command('wait_for_db')

        patched_check.assert_called_once_with(databases=['default'])

    @patch('time.sleep')
    def test_wait_for_db_delay(self, patched_sleep, patched_check):
        """Test waiting for database when getting OperationalError."""
        patched_check.side_effect = [Psycopg2OpError] * 2 + \
            [OperationalError] * 3 + [True]

        call_command('wait_for_db')

        self.assertEqual(patched_check.call_count, 6)
        patched_check.assert_called_with(databases=['default'])

 

4. refresh_from_db()의 사용법

   - database에 데이터를 변경을 하는 경우 갱신된 데이터를 가져오기 위해 .refresh_from_db()를 사용한다.

def test_update_ingredient(self):
        """Test updating an ingredient."""
        ingredient = Ingredient.objects.create(user=self.user, name='Cilantro')

        payload = {'name': 'Coriander'}
        url = detail_url(ingredient.id)
        res = self.client.patch(url, payload)

        self.assertEqual(res.status_code, status.HTTP_200_OK)
        ingredient.refresh_from_db()
        self.assertEqual(ingredient.name, payload['name'])

 

 5. 저장한 데이터와 client로 요청받은 데이터의 비교 방법

    def test_create_recipe(self):
        """Test creating a recipe."""
        payload = {
            'title': 'Sample recipe',
            'time_minutes': 30,
            'price': Decimal('5.99'),
        }
        res = self.client.post(RECIPES_URL, payload)

        self.assertEqual(res.status_code, status.HTTP_201_CREATED)
        recipe = Recipe.objects.get(id=res.data['id'])
        for k, v in payload.items():
            self.assertEqual(getattr(recipe, k), v)
        self.assertEqual(recipe.user, self.user)

 

 6. 이미지 업로드 테스트 방법

class ImageUploadTests(TestCase):
    """Tests for the image upload API."""

    def setUp(self):
        self.client = APIClient()
        self.user = get_user_model().objects.create_user(
            'user@example.com',
            'password123',
        )
        self.client.force_authenticate(self.user)
        self.recipe = create_recipe(user=self.user)

    def tearDown(self):
        self.recipe.image.delete()

    def test_upload_image(self):
        """Test uploading an image to a recipe."""
        url = image_upload_url(self.recipe.id)
        with tempfile.NamedTemporaryFile(suffix='.jpg') as image_file:
            img = Image.new('RGB', (10, 10))
            img.save(image_file, format='JPEG')
            image_file.seek(0)
            payload = {'image': image_file}
            res = self.client.post(url, payload, format='multipart')

        self.recipe.refresh_from_db()
        self.assertEqual(res.status_code, status.HTTP_200_OK)
        self.assertIn('image', res.data)
        self.assertTrue(os.path.exists(self.recipe.image.path))

    def test_upload_image_bad_request(self):
        """Test uploading an invalid image."""
        url = image_upload_url(self.recipe.id)
        payload = {'image': 'notanimage'}
        res = self.client.post(url, payload, format='multipart')

        self.assertEqual(res.status_code, status.HTTP_400_BAD_REQUEST)