diff --git a/monkey/monkey_island/cc/resources/propagation_credentials.py b/monkey/monkey_island/cc/resources/propagation_credentials.py index e60478930..97e8ecc3c 100644 --- a/monkey/monkey_island/cc/resources/propagation_credentials.py +++ b/monkey/monkey_island/cc/resources/propagation_credentials.py @@ -6,47 +6,44 @@ from common.credentials import Credentials from monkey_island.cc.repository import ICredentialsRepository from monkey_island.cc.resources.AbstractResource import AbstractResource +_configured_collection = "configured-credentials" +_stolen_collection = "stolen-credentials" + class PropagationCredentials(AbstractResource): - urls = [ - "/api/propagation-credentials", - "/api/propagation-credentials/configured-credentials", - "/api/propagation-credentials/stolen-credentials", - ] + urls = ["/api/propagation-credentials/", "/api/propagation-credentials/"] def __init__(self, credentials_repository: ICredentialsRepository): self._credentials_repository = credentials_repository - def get(self): - propagation_credentials = [] - - if request.url.endswith("/configured-credentials"): + def get(self, collection=None): + if collection == _configured_collection: propagation_credentials = self._credentials_repository.get_configured_credentials() - elif request.url.endswith("/stolen-credentials"): + elif collection == _stolen_collection: propagation_credentials = self._credentials_repository.get_stolen_credentials() else: propagation_credentials = self._credentials_repository.get_all_credentials() return make_response(Credentials.to_json_array(propagation_credentials), HTTPStatus.OK) - def post(self): + def post(self, collection=None): credentials = [Credentials.from_json(c) for c in request.json] - if request.url.endswith("/configured-credentials"): + if collection == _configured_collection: self._credentials_repository.save_configured_credentials(credentials) - elif request.url.endswith("/stolen-credentials"): + elif collection == _stolen_collection: self._credentials_repository.save_stolen_credentials(credentials) else: return {}, HTTPStatus.METHOD_NOT_ALLOWED return {}, HTTPStatus.NO_CONTENT - def delete(self): - if request.url.endswith("/configured-credentials"): + def delete(self, collection=None): + if collection == _configured_collection: self._credentials_repository.remove_configured_credentials() - elif request.url.endswith("/stolen-credentials"): + elif collection == _stolen_collection: self._credentials_repository.remove_stolen_credentials() else: - return {}, HTTPStatus.METHOD_NOT_ALLOWED + self._credentials_repository.remove_all_credentials() return {}, HTTPStatus.NO_CONTENT diff --git a/monkey/tests/unit_tests/monkey_island/cc/resources/test_propagation_credentials.py b/monkey/tests/unit_tests/monkey_island/cc/resources/test_propagation_credentials.py index d34a45a25..f717fd032 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/resources/test_propagation_credentials.py +++ b/monkey/tests/unit_tests/monkey_island/cc/resources/test_propagation_credentials.py @@ -1,6 +1,7 @@ import json from http import HTTPStatus from typing import Sequence +from urllib.parse import urljoin import pytest from tests.common import StubDIContainer @@ -15,10 +16,14 @@ from tests.monkey_island import InMemoryCredentialsRepository from common.credentials import Credentials from monkey_island.cc.repository import ICredentialsRepository from monkey_island.cc.resources import PropagationCredentials +from monkey_island.cc.resources.propagation_credentials import ( + _configured_collection, + _stolen_collection, +) ALL_CREDENTIALS_URL = PropagationCredentials.urls[0] -CONFIGURED_CREDENTIALS_URL = PropagationCredentials.urls[1] -STOLEN_CREDENTIALS_URL = PropagationCredentials.urls[2] +CONFIGURED_CREDENTIALS_URL = urljoin(ALL_CREDENTIALS_URL, _configured_collection) +STOLEN_CREDENTIALS_URL = urljoin(ALL_CREDENTIALS_URL, _stolen_collection) @pytest.fixture @@ -117,8 +122,3 @@ def test_stolen_propagation_credentials_endpoint_delete(flask_client, credential def test_propagation_credentials_endpoint__propagation_credentials_post_not_allowed(flask_client): resp = flask_client.post(ALL_CREDENTIALS_URL, json=[]) assert resp.status_code == HTTPStatus.METHOD_NOT_ALLOWED - - -def test_propagation_credentials_endpoint__propagation_credentials_delete_not_allowed(flask_client): - resp = flask_client.delete(ALL_CREDENTIALS_URL) - assert resp.status_code == HTTPStatus.METHOD_NOT_ALLOWED