diff --git a/agent/component/switch.py b/agent/component/switch.py index 315b43f9ab..20e41b7a1c 100644 --- a/agent/component/switch.py +++ b/agent/component/switch.py @@ -97,6 +97,9 @@ class Switch(ComponentBase, ABC): self.set_output("_next", self._param.end_cpn_ids) def process_operator(self, input: Any, operator: str, value: Any) -> bool: + if operator in ("contains", "not contains", "start with", "end with"): + input = "" if input is None else str(input) + value = "" if value is None else str(value) if operator == "contains": return True if value.lower() in input.lower() else False elif operator == "not contains": diff --git a/test/unit_test/agent/component/test_switch.py b/test/unit_test/agent/component/test_switch.py index 10dbfeafcf..a6516317ac 100644 --- a/test/unit_test/agent/component/test_switch.py +++ b/test/unit_test/agent/component/test_switch.py @@ -1,3 +1,5 @@ +import pytest + from agent.component.switch import Switch, SwitchParam @@ -57,3 +59,41 @@ def test_switch_non_empty_and_condition_still_matches(): assert cpn.output("_next") == ["case_target"] assert cpn.output("next") == ["case_target"] + + +@pytest.mark.p1 +def test_switch_none_input_contains_falls_through_to_else(): + param = SwitchParam() + param.conditions = [ + { + "logical_operator": "and", + "items": [{"cpn_id": "answer", "operator": "contains", "value": "foo"}], + "to": ["case_target"], + } + ] + param.end_cpn_ids = ["else_target"] + + cpn = _switch(param, {"answer": None}) + cpn._invoke() + + assert cpn.output("_next") == ["else_target"] + assert cpn.output("next") == ["else_target"] + + +@pytest.mark.p1 +def test_switch_none_value_contains_does_not_raise(): + param = SwitchParam() + param.conditions = [ + { + "logical_operator": "and", + "items": [{"cpn_id": "answer", "operator": "contains", "value": None}], + "to": ["case_target"], + } + ] + param.end_cpn_ids = ["else_target"] + + cpn = _switch(param, {"answer": "foobar"}) + cpn._invoke() + + assert cpn.output("_next") == ["case_target"] + assert cpn.output("next") == ["case_target"]