diff --git a/marble_api/versions/v1/data_request/models.py b/marble_api/versions/v1/data_request/models.py index 0f39c9f..45b0547 100644 --- a/marble_api/versions/v1/data_request/models.py +++ b/marble_api/versions/v1/data_request/models.py @@ -13,6 +13,7 @@ Field, FieldSerializationInfo, ValidationInfo, + computed_field, field_serializer, field_validator, model_validator, @@ -30,7 +31,7 @@ collapse_geometries, validate_collapsible, ) -from marble_api.utils.models import partial_model +from marble_api.utils.models import object_id, partial_model PyObjectId = Annotated[str, BeforeValidator(str)] Temporal = Annotated[list[AwareDatetime], Field(..., min_length=1, max_length=2), AfterValidator(sorted)] @@ -53,6 +54,7 @@ class DataRequest(BaseModel): id: SkipJsonSchema[PyObjectId | None] = Field(default=None, validation_alias="_id", exclude=True) user: SkipJsonSchema[str | None] = None # user is set by the route after the model is first validated + updated: SkipJsonSchema[AwareDatetime | None] = None # updated is set by the route title: str description: str | None = None authors: list[Author] @@ -97,7 +99,7 @@ def convert_from_utc(self, value: Temporal, info: FieldSerializationInfo) -> lis @field_serializer("user") def require_user_set(self, value: str, info: FieldSerializationInfo) -> str: - """Require that the user_name is set when the model is serialized.""" + """Require that the user name be set when the model is serialized.""" assert value, f"{info.field_name} must be set and non-empty" return value @@ -123,8 +125,14 @@ class DataRequestPublic(DataRequest): id: Annotated[str, BeforeValidator(str)] = Field(..., validation_alias="_id") user: str # user is required to be set in the database + updated: AwareDatetime model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True, extra="allow") + @computed_field + def created(self) -> AwareDatetime: + """Set the created time based on the object id.""" + return object_id(self.id, None).generation_time + @property def stac_item(self) -> Item: """Dynamically create a STAC item representation of this data.""" diff --git a/marble_api/versions/v1/data_request/routes.py b/marble_api/versions/v1/data_request/routes.py index e7f6c31..8d863f0 100644 --- a/marble_api/versions/v1/data_request/routes.py +++ b/marble_api/versions/v1/data_request/routes.py @@ -1,3 +1,4 @@ +import datetime from collections.abc import AsyncGenerator from typing import Annotated @@ -43,6 +44,7 @@ def _is_router_scope(request: Request, router: APIRouter) -> bool: async def post_data_request_user(user: str, data_request: DataRequest) -> DataRequestPublic: """Create a new data request and return the newly created data request.""" data_request.user = user + data_request.updated = datetime.datetime.now(tz=datetime.timezone.utc) new_data_request = data_request.model_dump(by_alias=True) result = await client.db["data-request"].insert_one(new_data_request) new_data_request["id"] = str(result.inserted_id) @@ -63,6 +65,8 @@ async def patch_data_request( if user: data_request.user = user selector = {"_id": _data_request_id(request_id)} + # updated timestamps are handled automatically + updated_fields["updated"] = datetime.datetime.now(tz=datetime.timezone.utc) if updated_fields: result = await client.db["data-request"].find_one_and_update( selector, {"$set": updated_fields}, return_document=ReturnDocument.AFTER diff --git a/pyproject.toml b/pyproject.toml index 57943b9..d1b94ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ classifiers = [ "Programming Language :: Python :: 3 :: Only", ] dependencies = [ - "fastapi~=0.115", + "fastapi<0.137", # TODO: handle breaking changes https://github.com/fastapi/fastapi/pull/15745 "pymongo~=4.14", "geojson-pydantic~=2.0", "stac-pydantic~=3.4", diff --git a/test/faker_providers.py b/test/faker_providers.py index 06f343a..513310d 100644 --- a/test/faker_providers.py +++ b/test/faker_providers.py @@ -216,13 +216,14 @@ def asset(self): "type": self.generator.mime_type(), "title": self.generator.word(), "description": self.generator.sentence(), - "roles": self.generator.get_words_list(), + "roles": self.generator.words(nb=self.generator.random.randint(0, 10)), } def _data_request_inputs(self, unset=None): inputs = dict( id=bson.ObjectId(), user=self.generator.profile("username")["username"], + updated=self.generator.tz_aware_date_time_seconds_precision(), title=self.generator.sentence(), description=(None if self.generator.pybool(30) else self.generator.paragraph()), authors=[self.author() for _ in range(self.generator.random.randint(1, 10))], diff --git a/test/integration/versions/v1/data_request/test_routes.py b/test/integration/versions/v1/data_request/test_routes.py index 93301ff..5debf33 100644 --- a/test/integration/versions/v1/data_request/test_routes.py +++ b/test/integration/versions/v1/data_request/test_routes.py @@ -1,3 +1,4 @@ +import datetime import inspect import json from urllib.parse import parse_qs, urlparse @@ -13,6 +14,13 @@ pytestmark = pytest.mark.anyio +def compare_no_timestamps(dict1, dict2, message=""): + timestamps = {"created", "updated"} + assert {k: v for k, v in dict1.items() if k not in timestamps} == { + k: v for k, v in dict2.items() if k not in timestamps + }, message + + class _TestUser: @pytest.fixture def member_route(self, data_requests): @@ -62,7 +70,7 @@ class _TestGetOne(_TestGet): async def test_get(self, async_client, data_requests, member_route): resp = await async_client.get(member_route) assert resp.status_code == 200 - assert DataRequestPublic(**data_requests[0]) == DataRequestPublic(**resp.json()) + assert DataRequestPublic(**data_requests[0]).model_dump() == DataRequestPublic(**resp.json()).model_dump() async def test_get_stac(self, async_client, member_route): resp = await async_client.get(f"{member_route}?stac=true") @@ -95,9 +103,9 @@ class _TestGetMany(_TestGet): async def test_get(self, async_client, data_requests, collection_route): response = await async_client.get(collection_route) - models = {str(req["_id"]): DataRequestPublic(**req) for req in data_requests} + models = {str(req["_id"]): DataRequestPublic(**req).model_dump() for req in data_requests} for req in response.json()["data_requests"]: - assert DataRequestPublic(**req) == models[req["id"]] + assert DataRequestPublic(**req).model_dump() == models[req["id"]] async def test_get_stac(self, async_client, collection_route): resp = await async_client.get(f"{collection_route}?stac=true") @@ -229,7 +237,7 @@ async def test_valid(self, fake, async_client, collection_route, data_requests): response_data = response.json() assert (id_ := response_data.pop("id", None)) bson.ObjectId(id_) # check that the id is a valid object id - assert {"user": data_requests[0]["user"], **json.loads(data)} == response_data + compare_no_timestamps({"user": data_requests[0]["user"], **json.loads(data)}, response_data) async def test_invalid_authors(self, fake, async_client, collection_route): data = json.loads(fake.data_request().model_dump_json()) @@ -276,7 +284,7 @@ async def test_valid(self, loaded_data, async_client, fake, member_route): response = await async_client.patch(member_route, json=update) assert response.status_code == 200 loaded_data.update(update) - assert loaded_data == response.json() + compare_no_timestamps(loaded_data, response.json()) async def test_valid_multiple(self, loaded_data, async_client, fake, member_route): title = fake.sentence() @@ -285,12 +293,12 @@ async def test_valid_multiple(self, loaded_data, async_client, fake, member_rout response = await async_client.patch(member_route, json=update) assert response.status_code == 200 loaded_data.update(update) - assert loaded_data == response.json() + compare_no_timestamps(loaded_data, response.json()) async def test_update_nothing(self, loaded_data, async_client, member_route): response = await async_client.patch(member_route, json={}) assert response.status_code == 200 - assert loaded_data == response.json() + compare_no_timestamps(loaded_data, response.json()) async def test_no_id_update(self, loaded_data, async_client, member_route): update = {"id": str(bson.ObjectId())} @@ -298,7 +306,7 @@ async def test_no_id_update(self, loaded_data, async_client, member_route): assert response.status_code == 200 assert response.json()["id"] == loaded_data["id"] assert response.json()["id"] != update["id"] - assert loaded_data == response.json() + compare_no_timestamps(loaded_data, response.json()) async def test_invalid_unset_value(self, async_client, member_route): response = await async_client.patch(member_route, json={"title": None}) @@ -319,6 +327,27 @@ async def test_bad_id(self, async_client, collection_route): resp = await async_client.patch(f"{collection_route}/id-does-not-exist", json={}) assert resp.status_code == 404, resp.json() + async def test_created_in_response(self, fake, async_client, member_route): + title = fake.sentence() + update = {"title": title} + response = await async_client.patch(member_route, json=update) + assert response.status_code == 200 + assert response.json()["created"] + + async def test_updated_updated(self, loaded_data, fake, async_client, member_route): + title = fake.sentence() + update = {"title": title} + response = await async_client.patch(member_route, json=update) + assert response.status_code == 200 + assert loaded_data["updated"] != response.json()["updated"] + + @pytest.mark.parametrize("field", ["created", "updated"]) + async def test_no_updatable_timestamps(self, loaded_data, async_client, member_route, field): + new_date = datetime.datetime.now(tz=datetime.timezone.utc) - datetime.timedelta(days=10) + response = await async_client.patch(member_route, json={field: new_date.isoformat()}) + assert response.status_code == 200 + assert response.json()[field] != new_date.isoformat() + class TestPatchUser(_TestPatch, _TestUser): async def test_update_everything(self, loaded_data, async_client, fake, member_route): @@ -327,7 +356,7 @@ async def test_update_everything(self, loaded_data, async_client, fake, member_r assert response.status_code == 200 update["id"] = loaded_data["id"] update["user"] = loaded_data["user"] - assert update == response.json() + compare_no_timestamps(update, response.json()) async def test_no_update_user(self, loaded_data, async_client, member_route): new_user = loaded_data["user"] + "suffix" @@ -341,7 +370,7 @@ async def test_update_everything(self, loaded_data, async_client, fake, member_r response = await async_client.patch(member_route, json=update) assert response.status_code == 200 update["id"] = loaded_data["id"] - assert update == response.json() + compare_no_timestamps(update, response.json()) async def test_update_user(self, loaded_data, async_client, member_route): new_user = loaded_data["user"] + "suffix" diff --git a/test/unit/versions/v1/data_request/test_models.py b/test/unit/versions/v1/data_request/test_models.py index e038d55..2447430 100644 --- a/test/unit/versions/v1/data_request/test_models.py +++ b/test/unit/versions/v1/data_request/test_models.py @@ -6,6 +6,7 @@ from pystac import Item from marble_api.utils.geojson import collapse_geometries +from marble_api.utils.models import object_id from marble_api.versions.v1.data_request.models import Author, DataRequestUpdate @@ -99,6 +100,11 @@ def fake_class(self, fake): def test_id_dumped(self, fake_class): assert "id" in fake_class().model_dump() + def test_created(self, fake_class): + model = fake_class() + assert model.created + assert object_id(model.id, None).generation_time == model.created + class TestStacItem: def test_valid(self, fake_class): assert Item.from_dict(fake_class().stac_item)