From 35b724a0f98555f1f83945173871a19612e2c204 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=90=B4=E8=80=80=E5=A5=BD?= <2517926599@qq.com> Date: Wed, 24 Jun 2026 23:13:07 +0800 Subject: [PATCH] fix(schema): give Task.open_query a proper default and add schema tests - Replace the FIXME comment on Task.open_query with an explicit default=False so callers can construct a Task without specifying it. - Add unit tests for Task and TaskObjective covering defaults, overrides, optional fields, and serialization roundtrip. --- agentevolver/schema/task.py | 2 +- tests/test_schema.py | 101 ++++++++++++++++++++++++++++++++++++ 2 files changed, 102 insertions(+), 1 deletion(-) create mode 100644 tests/test_schema.py diff --git a/agentevolver/schema/task.py b/agentevolver/schema/task.py index 5418da28..b30dbfbb 100644 --- a/agentevolver/schema/task.py +++ b/agentevolver/schema/task.py @@ -11,7 +11,7 @@ class Task(BaseModel): env_type: str = Field(default="appworld") # whether this task is open query. open query has no clear stop condition. - open_query: bool = Field() # FIXME debug, check if every instance handles this new attr. default False. + open_query: bool = Field(default=False) metadata: dict = Field(default_factory=dict) diff --git a/tests/test_schema.py b/tests/test_schema.py new file mode 100644 index 00000000..fb750b91 --- /dev/null +++ b/tests/test_schema.py @@ -0,0 +1,101 @@ +import sys +from pathlib import Path + +import pytest + +ROOT_DIR = Path(__file__).resolve().parents[1] +if str(ROOT_DIR) not in sys.path: + sys.path.insert(0, str(ROOT_DIR)) + +from agentevolver.schema.task import Task, TaskObjective + + +class TestTask: + """Unit tests for the Task schema.""" + + def test_task_minimal_construction(self): + task = Task(task_id="task-1") + assert task.task_id == "task-1" + assert task.env_type == "appworld" + + def test_open_query_defaults_to_false(self): + """open_query should default to False so callers do not need to set it.""" + task = Task(task_id="task-1") + assert task.open_query is False + + def test_open_query_can_be_overridden(self): + task = Task(task_id="task-2", open_query=True) + assert task.open_query is True + + def test_ground_truth_is_optional(self): + task_with_gt = Task( + task_id="task-3", + ground_truth="4", + ) + assert task_with_gt.ground_truth == "4" + + task_without_gt = Task(task_id="task-4") + assert task_without_gt.ground_truth is None + + def test_query_defaults_to_none(self): + task = Task(task_id="task-5") + assert task.query is None + + +class TestTaskObjective: + """Unit tests for the TaskObjective wrapper.""" + + def test_task_objective_proxies_ground_truth(self): + task = Task( + task_id="task-6", + ground_truth="Roses are red.", + open_query=True, + ) + objective = TaskObjective(task=task) + + assert objective.ground_truth == "Roses are red." + assert objective.task.open_query is True + + def test_task_objective_exposes_query_as_objective(self): + task = Task( + task_id="task-7", + query="Please summarize.", + ) + objective = TaskObjective(task=task) + assert objective.objective == "Please summarize." + + def test_task_objective_objective_without_query(self): + task = Task( + task_id="task-8", + ) + objective = TaskObjective(task=task) + assert objective.objective is None + + def test_task_objective_ground_truth_setter(self): + task = Task(task_id="task-9") + objective = TaskObjective(task=task) + + objective.ground_truth = "Updated ground truth." + assert task.ground_truth == "Updated ground truth." + assert objective.ground_truth == "Updated ground truth." + + def test_task_objective_dict_roundtrip(self): + task = Task( + task_id="task-10", + ground_truth="Gravity is a force.", + open_query=True, + ) + objective = TaskObjective(task=task, confidence=0.9, reward=1.0) + + data = objective.model_dump() + restored = TaskObjective(**data) + + assert restored.task.task_id == objective.task.task_id + assert restored.task.ground_truth == objective.task.ground_truth + assert restored.task.open_query == objective.task.open_query + assert restored.confidence == 0.9 + assert restored.reward == 1.0 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])