Island: Use collection url placeholder for propagation credentials

This commit is contained in:
vakarisz 2022-07-12 12:39:16 +03:00
parent 037e6619dd
commit a9e2dd2d3d
2 changed files with 21 additions and 24 deletions

View File

@ -6,47 +6,44 @@ from common.credentials import Credentials
from monkey_island.cc.repository import ICredentialsRepository from monkey_island.cc.repository import ICredentialsRepository
from monkey_island.cc.resources.AbstractResource import AbstractResource from monkey_island.cc.resources.AbstractResource import AbstractResource
_configured_collection = "configured-credentials"
_stolen_collection = "stolen-credentials"
class PropagationCredentials(AbstractResource): class PropagationCredentials(AbstractResource):
urls = [ urls = ["/api/propagation-credentials/", "/api/propagation-credentials/<string:collection>"]
"/api/propagation-credentials",
"/api/propagation-credentials/configured-credentials",
"/api/propagation-credentials/stolen-credentials",
]
def __init__(self, credentials_repository: ICredentialsRepository): def __init__(self, credentials_repository: ICredentialsRepository):
self._credentials_repository = credentials_repository self._credentials_repository = credentials_repository
def get(self): def get(self, collection=None):
propagation_credentials = [] if collection == _configured_collection:
if request.url.endswith("/configured-credentials"):
propagation_credentials = self._credentials_repository.get_configured_credentials() propagation_credentials = self._credentials_repository.get_configured_credentials()
elif request.url.endswith("/stolen-credentials"): elif collection == _stolen_collection:
propagation_credentials = self._credentials_repository.get_stolen_credentials() propagation_credentials = self._credentials_repository.get_stolen_credentials()
else: else:
propagation_credentials = self._credentials_repository.get_all_credentials() propagation_credentials = self._credentials_repository.get_all_credentials()
return make_response(Credentials.to_json_array(propagation_credentials), HTTPStatus.OK) return make_response(Credentials.to_json_array(propagation_credentials), HTTPStatus.OK)
def post(self): def post(self, collection=None):
credentials = [Credentials.from_json(c) for c in request.json] credentials = [Credentials.from_json(c) for c in request.json]
if request.url.endswith("/configured-credentials"): if collection == _configured_collection:
self._credentials_repository.save_configured_credentials(credentials) self._credentials_repository.save_configured_credentials(credentials)
elif request.url.endswith("/stolen-credentials"): elif collection == _stolen_collection:
self._credentials_repository.save_stolen_credentials(credentials) self._credentials_repository.save_stolen_credentials(credentials)
else: else:
return {}, HTTPStatus.METHOD_NOT_ALLOWED return {}, HTTPStatus.METHOD_NOT_ALLOWED
return {}, HTTPStatus.NO_CONTENT return {}, HTTPStatus.NO_CONTENT
def delete(self): def delete(self, collection=None):
if request.url.endswith("/configured-credentials"): if collection == _configured_collection:
self._credentials_repository.remove_configured_credentials() self._credentials_repository.remove_configured_credentials()
elif request.url.endswith("/stolen-credentials"): elif collection == _stolen_collection:
self._credentials_repository.remove_stolen_credentials() self._credentials_repository.remove_stolen_credentials()
else: else:
return {}, HTTPStatus.METHOD_NOT_ALLOWED self._credentials_repository.remove_all_credentials()
return {}, HTTPStatus.NO_CONTENT return {}, HTTPStatus.NO_CONTENT

View File

@ -1,6 +1,7 @@
import json import json
from http import HTTPStatus from http import HTTPStatus
from typing import Sequence from typing import Sequence
from urllib.parse import urljoin
import pytest import pytest
from tests.common import StubDIContainer from tests.common import StubDIContainer
@ -15,10 +16,14 @@ from tests.monkey_island import InMemoryCredentialsRepository
from common.credentials import Credentials from common.credentials import Credentials
from monkey_island.cc.repository import ICredentialsRepository from monkey_island.cc.repository import ICredentialsRepository
from monkey_island.cc.resources import PropagationCredentials from monkey_island.cc.resources import PropagationCredentials
from monkey_island.cc.resources.propagation_credentials import (
_configured_collection,
_stolen_collection,
)
ALL_CREDENTIALS_URL = PropagationCredentials.urls[0] ALL_CREDENTIALS_URL = PropagationCredentials.urls[0]
CONFIGURED_CREDENTIALS_URL = PropagationCredentials.urls[1] CONFIGURED_CREDENTIALS_URL = urljoin(ALL_CREDENTIALS_URL, _configured_collection)
STOLEN_CREDENTIALS_URL = PropagationCredentials.urls[2] STOLEN_CREDENTIALS_URL = urljoin(ALL_CREDENTIALS_URL, _stolen_collection)
@pytest.fixture @pytest.fixture
@ -117,8 +122,3 @@ def test_stolen_propagation_credentials_endpoint_delete(flask_client, credential
def test_propagation_credentials_endpoint__propagation_credentials_post_not_allowed(flask_client): def test_propagation_credentials_endpoint__propagation_credentials_post_not_allowed(flask_client):
resp = flask_client.post(ALL_CREDENTIALS_URL, json=[]) resp = flask_client.post(ALL_CREDENTIALS_URL, json=[])
assert resp.status_code == HTTPStatus.METHOD_NOT_ALLOWED assert resp.status_code == HTTPStatus.METHOD_NOT_ALLOWED
def test_propagation_credentials_endpoint__propagation_credentials_delete_not_allowed(flask_client):
resp = flask_client.delete(ALL_CREDENTIALS_URL)
assert resp.status_code == HTTPStatus.METHOD_NOT_ALLOWED