Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions marble_api/versions/v1/data_request/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Field,
FieldSerializationInfo,
ValidationInfo,
computed_field,
field_serializer,
field_validator,
model_validator,
Expand All @@ -30,7 +31,7 @@
collapse_geometries,
validate_collapsible,
)
from marble_api.utils.models import partial_model
from marble_api.utils.models import object_id, partial_model

PyObjectId = Annotated[str, BeforeValidator(str)]
Temporal = Annotated[list[AwareDatetime], Field(..., min_length=1, max_length=2), AfterValidator(sorted)]
Expand All @@ -53,6 +54,7 @@ class DataRequest(BaseModel):

id: SkipJsonSchema[PyObjectId | None] = Field(default=None, validation_alias="_id", exclude=True)
user: SkipJsonSchema[str | None] = None # user is set by the route after the model is first validated
updated: SkipJsonSchema[AwareDatetime | None] = None # updated is set by the route
title: str
description: str | None = None
authors: list[Author]
Expand Down Expand Up @@ -97,7 +99,7 @@ def convert_from_utc(self, value: Temporal, info: FieldSerializationInfo) -> lis

@field_serializer("user")
def require_user_set(self, value: str, info: FieldSerializationInfo) -> str:
"""Require that the user_name is set when the model is serialized."""
"""Require that the user name be set when the model is serialized."""
assert value, f"{info.field_name} must be set and non-empty"
return value

Expand All @@ -123,8 +125,14 @@ class DataRequestPublic(DataRequest):

id: Annotated[str, BeforeValidator(str)] = Field(..., validation_alias="_id")
user: str # user is required to be set in the database
updated: AwareDatetime
model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True, extra="allow")

@computed_field
def created(self) -> AwareDatetime:
"""Set the created time based on the object id."""
return object_id(self.id, None).generation_time

@property
def stac_item(self) -> Item:
"""Dynamically create a STAC item representation of this data."""
Expand Down
4 changes: 4 additions & 0 deletions marble_api/versions/v1/data_request/routes.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import datetime
from collections.abc import AsyncGenerator
from typing import Annotated

Expand Down Expand Up @@ -43,6 +44,7 @@ def _is_router_scope(request: Request, router: APIRouter) -> bool:
async def post_data_request_user(user: str, data_request: DataRequest) -> DataRequestPublic:
"""Create a new data request and return the newly created data request."""
data_request.user = user
data_request.updated = datetime.datetime.now(tz=datetime.timezone.utc)
new_data_request = data_request.model_dump(by_alias=True)
result = await client.db["data-request"].insert_one(new_data_request)
new_data_request["id"] = str(result.inserted_id)
Expand All @@ -63,6 +65,8 @@ async def patch_data_request(
if user:
data_request.user = user
selector = {"_id": _data_request_id(request_id)}
# updated timestamps are handled automatically
updated_fields["updated"] = datetime.datetime.now(tz=datetime.timezone.utc)
if updated_fields:
result = await client.db["data-request"].find_one_and_update(
selector, {"$set": updated_fields}, return_document=ReturnDocument.AFTER
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ classifiers = [
"Programming Language :: Python :: 3 :: Only",
]
dependencies = [
"fastapi~=0.115",
"fastapi<0.137", # TODO: handle breaking changes https://github.com/fastapi/fastapi/pull/15745
"pymongo~=4.14",
"geojson-pydantic~=2.0",
"stac-pydantic~=3.4",
Expand Down
3 changes: 2 additions & 1 deletion test/faker_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,13 +216,14 @@ def asset(self):
"type": self.generator.mime_type(),
"title": self.generator.word(),
"description": self.generator.sentence(),
"roles": self.generator.get_words_list(),
"roles": self.generator.words(nb=self.generator.random.randint(0, 10)),
}

def _data_request_inputs(self, unset=None):
inputs = dict(
id=bson.ObjectId(),
user=self.generator.profile("username")["username"],
updated=self.generator.tz_aware_date_time_seconds_precision(),
title=self.generator.sentence(),
description=(None if self.generator.pybool(30) else self.generator.paragraph()),
authors=[self.author() for _ in range(self.generator.random.randint(1, 10))],
Expand Down
49 changes: 39 additions & 10 deletions test/integration/versions/v1/data_request/test_routes.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import datetime
import inspect
import json
from urllib.parse import parse_qs, urlparse
Expand All @@ -13,6 +14,13 @@
pytestmark = pytest.mark.anyio


def compare_no_timestamps(dict1, dict2, message=""):
timestamps = {"created", "updated"}
assert {k: v for k, v in dict1.items() if k not in timestamps} == {
k: v for k, v in dict2.items() if k not in timestamps
}, message


class _TestUser:
@pytest.fixture
def member_route(self, data_requests):
Expand Down Expand Up @@ -62,7 +70,7 @@ class _TestGetOne(_TestGet):
async def test_get(self, async_client, data_requests, member_route):
resp = await async_client.get(member_route)
assert resp.status_code == 200
assert DataRequestPublic(**data_requests[0]) == DataRequestPublic(**resp.json())
assert DataRequestPublic(**data_requests[0]).model_dump() == DataRequestPublic(**resp.json()).model_dump()

async def test_get_stac(self, async_client, member_route):
resp = await async_client.get(f"{member_route}?stac=true")
Expand Down Expand Up @@ -95,9 +103,9 @@ class _TestGetMany(_TestGet):

async def test_get(self, async_client, data_requests, collection_route):
response = await async_client.get(collection_route)
models = {str(req["_id"]): DataRequestPublic(**req) for req in data_requests}
models = {str(req["_id"]): DataRequestPublic(**req).model_dump() for req in data_requests}
for req in response.json()["data_requests"]:
assert DataRequestPublic(**req) == models[req["id"]]
assert DataRequestPublic(**req).model_dump() == models[req["id"]]

async def test_get_stac(self, async_client, collection_route):
resp = await async_client.get(f"{collection_route}?stac=true")
Expand Down Expand Up @@ -229,7 +237,7 @@ async def test_valid(self, fake, async_client, collection_route, data_requests):
response_data = response.json()
assert (id_ := response_data.pop("id", None))
bson.ObjectId(id_) # check that the id is a valid object id
assert {"user": data_requests[0]["user"], **json.loads(data)} == response_data
compare_no_timestamps({"user": data_requests[0]["user"], **json.loads(data)}, response_data)

async def test_invalid_authors(self, fake, async_client, collection_route):
data = json.loads(fake.data_request().model_dump_json())
Expand Down Expand Up @@ -276,7 +284,7 @@ async def test_valid(self, loaded_data, async_client, fake, member_route):
response = await async_client.patch(member_route, json=update)
assert response.status_code == 200
loaded_data.update(update)
assert loaded_data == response.json()
compare_no_timestamps(loaded_data, response.json())

async def test_valid_multiple(self, loaded_data, async_client, fake, member_route):
title = fake.sentence()
Expand All @@ -285,20 +293,20 @@ async def test_valid_multiple(self, loaded_data, async_client, fake, member_rout
response = await async_client.patch(member_route, json=update)
assert response.status_code == 200
loaded_data.update(update)
assert loaded_data == response.json()
compare_no_timestamps(loaded_data, response.json())

async def test_update_nothing(self, loaded_data, async_client, member_route):
response = await async_client.patch(member_route, json={})
assert response.status_code == 200
assert loaded_data == response.json()
compare_no_timestamps(loaded_data, response.json())

async def test_no_id_update(self, loaded_data, async_client, member_route):
update = {"id": str(bson.ObjectId())}
response = await async_client.patch(member_route, json=update)
assert response.status_code == 200
assert response.json()["id"] == loaded_data["id"]
assert response.json()["id"] != update["id"]
assert loaded_data == response.json()
compare_no_timestamps(loaded_data, response.json())

async def test_invalid_unset_value(self, async_client, member_route):
response = await async_client.patch(member_route, json={"title": None})
Expand All @@ -319,6 +327,27 @@ async def test_bad_id(self, async_client, collection_route):
resp = await async_client.patch(f"{collection_route}/id-does-not-exist", json={})
assert resp.status_code == 404, resp.json()

async def test_created_in_response(self, fake, async_client, member_route):
title = fake.sentence()
update = {"title": title}
response = await async_client.patch(member_route, json=update)
assert response.status_code == 200
assert response.json()["created"]

async def test_updated_updated(self, loaded_data, fake, async_client, member_route):
title = fake.sentence()
update = {"title": title}
response = await async_client.patch(member_route, json=update)
assert response.status_code == 200
assert loaded_data["updated"] != response.json()["updated"]

@pytest.mark.parametrize("field", ["created", "updated"])
async def test_no_updatable_timestamps(self, loaded_data, async_client, member_route, field):
new_date = datetime.datetime.now(tz=datetime.timezone.utc) - datetime.timedelta(days=10)
response = await async_client.patch(member_route, json={field: new_date.isoformat()})
assert response.status_code == 200
assert response.json()[field] != new_date.isoformat()


class TestPatchUser(_TestPatch, _TestUser):
async def test_update_everything(self, loaded_data, async_client, fake, member_route):
Expand All @@ -327,7 +356,7 @@ async def test_update_everything(self, loaded_data, async_client, fake, member_r
assert response.status_code == 200
update["id"] = loaded_data["id"]
update["user"] = loaded_data["user"]
assert update == response.json()
compare_no_timestamps(update, response.json())

async def test_no_update_user(self, loaded_data, async_client, member_route):
new_user = loaded_data["user"] + "suffix"
Expand All @@ -341,7 +370,7 @@ async def test_update_everything(self, loaded_data, async_client, fake, member_r
response = await async_client.patch(member_route, json=update)
assert response.status_code == 200
update["id"] = loaded_data["id"]
assert update == response.json()
compare_no_timestamps(update, response.json())

async def test_update_user(self, loaded_data, async_client, member_route):
new_user = loaded_data["user"] + "suffix"
Expand Down
6 changes: 6 additions & 0 deletions test/unit/versions/v1/data_request/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pystac import Item

from marble_api.utils.geojson import collapse_geometries
from marble_api.utils.models import object_id
from marble_api.versions.v1.data_request.models import Author, DataRequestUpdate


Expand Down Expand Up @@ -99,6 +100,11 @@ def fake_class(self, fake):
def test_id_dumped(self, fake_class):
assert "id" in fake_class().model_dump()

def test_created(self, fake_class):
model = fake_class()
assert model.created
assert object_id(model.id, None).generation_time == model.created

class TestStacItem:
def test_valid(self, fake_class):
assert Item.from_dict(fake_class().stac_item)
Expand Down
Loading