diff --git a/README.md b/README.md index f5b27c6..a461048 100644 --- a/README.md +++ b/README.md @@ -305,6 +305,55 @@ Options: - `--function-invoke-opt TEXT`: Currently we support only `UnstructuredChunking` for functions. +## Testing LLM Gateway + +You can use AI models configured in Salesforce to generate responses while transforming your data. Below is a sample code example: + +``` +from datacustomcode.client import Client, llm_gateway_generate_text_col + + +def main(): + client = Client() + df = client.read_dlo("Input__dll") + # llm_gateway_generate_text_col returns a struct + # {status, response, error_code, error_message} per row, so per-row + # failures don't abort the Spark job. Pick the field you want with []. + df_generated = df.withColumn( + "greeting__c", + llm_gateway_generate_text_col( + "In one sentence, greet {name} from {city}.", + {"name": col("name__c"), "city": col("homecity__c")}, + model_id="sfdc_ai__DefaultGPT4Omni", # An AI model in your org + )["response"], + ) + + dlo_name = "Output_dll" + client.write_to_dlo(dlo_name, df_upper1, write_mode=WriteMode.APPEND) + + greeting = client.llm_gateway_generate_text("In one sentence, generate a greeting message", "sfdc_ai__DefaultGPT52") + +if __name__ == "__main__": + main() +``` + +In order to test this code on your local machine before deploying it to Data Cloud, you must first set up an External Client App that allows access to the Agent API. Follow this guide to create the ECA https://developer.salesforce.com/docs/ai/agentforce/guide/agent-api-get-started.html#create-a-salesforce-app. You must use `http://localhost:1717/OauthRedirect` as the callback URL. + +Once the ECA is set up, log in to your org using this ECA +``` +sf org login web \ + --alias myorg \ + --instance-url https://{MY_DOMAIN_URL} \ + --client-id {CONSUMER_KEY} \ + --scopes "sfap_api api" +``` + +then you can test your code using `myorg` alias +``` +datacustomcode run ./payload/entrypoint.py --sf-cli-org myorg +``` + + ## Docker usage The SDK provides Docker-based development options that allow you to test your code in an environment that closely resembles Data Cloud's execution environment. diff --git a/src/datacustomcode/__init__.py b/src/datacustomcode/__init__.py index 85cfa54..76e24b7 100644 --- a/src/datacustomcode/__init__.py +++ b/src/datacustomcode/__init__.py @@ -17,8 +17,11 @@ "AuthType", "Client", "Credentials", + "DefaultSparkLLMGateway", "PrintDataCloudWriter", "QueryAPIDataCloudReader", + "SparkLLMGateway", + "llm_gateway_generate_text_col", ] @@ -44,4 +47,16 @@ def __getattr__(name: str): from datacustomcode.io.reader.query_api import QueryAPIDataCloudReader return QueryAPIDataCloudReader + elif name == "SparkLLMGateway": + from datacustomcode.llm_gateway import SparkLLMGateway + + return SparkLLMGateway + elif name == "DefaultSparkLLMGateway": + from datacustomcode.llm_gateway import DefaultSparkLLMGateway + + return DefaultSparkLLMGateway + elif name == "llm_gateway_generate_text_col": + from datacustomcode.client import llm_gateway_generate_text_col + + return llm_gateway_generate_text_col raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/datacustomcode/client.py b/src/datacustomcode/client.py index 9ad95be..a5303a4 100644 --- a/src/datacustomcode/client.py +++ b/src/datacustomcode/client.py @@ -18,24 +18,87 @@ from typing import ( TYPE_CHECKING, ClassVar, + Dict, Optional, + Union, ) from datacustomcode.config import config from datacustomcode.file.path.default import DefaultFindFilePath from datacustomcode.io.reader.base import BaseDataCloudReader +from datacustomcode.llm_gateway_config import spark_llm_gateway_config from datacustomcode.spark.default import DefaultSparkSessionProvider if TYPE_CHECKING: from pathlib import Path - from pyspark.sql import DataFrame as PySparkDataFrame + from pyspark.sql import Column, DataFrame as PySparkDataFrame from datacustomcode.io.reader.base import BaseDataCloudReader from datacustomcode.io.writer.base import BaseDataCloudWriter, WriteMode + from datacustomcode.llm_gateway.spark_base import SparkLLMGateway from datacustomcode.spark.base import BaseSparkSessionProvider +def _build_spark_llm_gateway() -> "SparkLLMGateway": + """Instantiate the SDK-configured :class:`SparkLLMGateway`. + + Raises: + RuntimeError: If no ``spark_llm_gateway_config`` has been loaded. + """ + cfg = spark_llm_gateway_config.spark_llm_gateway_config + if cfg is None: + raise RuntimeError( + "spark_llm_gateway_config is not configured. Add a " + "'spark_llm_gateway_config' section to config.yaml." + ) + return cfg.to_object() + + +def llm_gateway_generate_text_col( + template: str, + values: Union[Dict[str, "Column"], "Column"], + model_id: Optional[str] = None, +) -> "Column": + """Build a Spark Column that runs the LLM Gateway per row. + + The returned Column yields a struct ``{status, response, error_code, + error_message}`` for each row. Use ``[...]`` (or ``getField``) to pick the + field you want, e.g. ``llm_gateway_generate_text_col(...)["response"]``. + Per-row failures populate ``status`` / ``error_code`` / ``error_message`` + so a single bad row does not abort the whole Spark job. + + Example: + + >>> result = llm_gateway_generate_text_col( + ... "In one sentence, greet {name} from {city}.", + ... {"name": col("name__c"), "city": col("homecity__c")}, + ... model_id="sfdc_ai__DefaultGPT4Omni", + ... ) + >>> df.withColumn("greeting__c", result["response"]) + >>> # …or keep the struct around and inspect failures: + >>> df.withColumn("llm", result).select( + ... "llm.status", "llm.response", "llm.error_message" + ... ) + + Args: + template: The prompt template, with ``{field}`` placeholders matching + keys in ``values``. Substitution uses ``str.format``. + values: Either a mapping from placeholder name to Spark ``Column``, or + a single ``Column`` whose value is already a struct. + model_id: LLM model id. Defaults to ``sfdc_ai__DefaultGPT4Omni``. + + Returns: + A Spark ``Column`` of ``StructType`` with fields ``status``, + ``response``, ``error_code``, and ``error_message`` (all nullable + strings). On success, ``status == "SUCCESS"`` and ``response`` holds + the generated text; on failure, ``status == "ERROR"`` and the + ``error_*`` fields carry diagnostic detail. + """ + gateway = Client()._get_spark_llm_gateway() + return gateway.llm_gateway_generate_text_col(template, values, model_id=model_id) + + class DataCloudObjectType(Enum): DLO = "dlo" DMO = "dmo" @@ -94,18 +157,21 @@ class Client: finder: Find a file path reader: A custom reader to use for reading Data Cloud objects. writer: A custom writer to use for writing Data Cloud objects. + spark_llm_gateway: Optional custom :class:`SparkLLMGateway`. Example: >>> client = Client() >>> file_path = client.find_file_path("data.csv") >>> dlo = client.read_dlo("my_dlo") >>> client.write_to_dmo("my_dmo", dlo) + >>> answer = client.llm_gateway_generate_text("Generate a greeting message") """ _instance: ClassVar[Optional[Client]] = None _reader: BaseDataCloudReader _writer: BaseDataCloudWriter _file: DefaultFindFilePath + _spark_llm_gateway: Optional[SparkLLMGateway] _data_layer_history: dict[DataCloudObjectType, set[str]] _code_type: str @@ -114,11 +180,13 @@ def __new__( reader: Optional[BaseDataCloudReader] = None, writer: Optional[BaseDataCloudWriter] = None, spark_provider: Optional[BaseSparkSessionProvider] = None, + spark_llm_gateway: Optional[SparkLLMGateway] = None, code_type: str = "script", ) -> Client: if cls._instance is None: cls._instance = super().__new__(cls) + cls._instance._spark_llm_gateway = spark_llm_gateway # Initialize Readers and Writers from config # and/or provided reader and writer if reader is None or writer is None: @@ -225,6 +293,41 @@ def find_file_path(self, file_name: str) -> Path: return self._file.find_file_path(file_name) # type: ignore[no-any-return] + def llm_gateway_generate_text( + self, + prompt: str, + model_id: Optional[str] = None, + ) -> str: + """Issue a one-shot LLM Gateway call. This is the scalar counterpart to + :func:`llm_gateway_generate_text_col`: it runs **once** — not per row. + Use the column helper method instead when you want to fan a prompt out across + every row of a DataFrame. + + Example: + + >>> response = Client().llm_gateway_generate_text( + ... "Generate a greeting message" + ... ) + + Args: + prompt: The literal prompt to send. Plain text — no + ``{field}`` substitution is performed on this string. + model_id: LLM model id to target. Defaults to + ``sfdc_ai__DefaultGPT4Omni`` when ``None``. + + Returns: + The generated text as a plain Python ``str``; empty when the + gateway response carries no generated text. + """ + return self._get_spark_llm_gateway().llm_gateway_generate_text( + prompt, model_id=model_id + ) + + def _get_spark_llm_gateway(self) -> SparkLLMGateway: + if self._spark_llm_gateway is None: + self._spark_llm_gateway = _build_spark_llm_gateway() + return self._spark_llm_gateway + def _validate_data_layer_history_does_not_contain( self, data_cloud_object_type: DataCloudObjectType ) -> None: diff --git a/src/datacustomcode/config.yaml b/src/datacustomcode/config.yaml index 8a6c334..25d233c 100644 --- a/src/datacustomcode/config.yaml +++ b/src/datacustomcode/config.yaml @@ -28,3 +28,6 @@ llm_gateway_config: type_config_name: DefaultLLMGateway options: credentials_profile: default + +spark_llm_gateway_config: + type_config_name: DefaultSparkLLMGateway diff --git a/src/datacustomcode/einstein_platform_config.py b/src/datacustomcode/einstein_platform_config.py index 135809d..1be4c4d 100644 --- a/src/datacustomcode/einstein_platform_config.py +++ b/src/datacustomcode/einstein_platform_config.py @@ -15,20 +15,23 @@ from typing import ( ClassVar, + Generic, Optional, Type, - cast, + TypeVar, ) from datacustomcode.common_config import BaseObjectConfig +_T = TypeVar("_T") -class CredentialsObjectConfig(BaseObjectConfig): + +class CredentialsObjectConfig(BaseObjectConfig, Generic[_T]): type_to_create: ClassVar[Type] credentials_profile: Optional[str] = None sf_cli_org: Optional[str] = None - def to_object(self): + def to_object(self) -> _T: """Create an object instance, automatically including credentials in options""" options = self.options.copy() @@ -38,4 +41,5 @@ def to_object(self): options["sf_cli_org"] = self.sf_cli_org type_ = self.type_to_create.subclass_from_config_name(self.type_config_name) - return cast(type_, type_(**options)) + instance: _T = type_(**options) + return instance diff --git a/src/datacustomcode/einstein_predictions_config.py b/src/datacustomcode/einstein_predictions_config.py index 1b4758f..cfd12d5 100644 --- a/src/datacustomcode/einstein_predictions_config.py +++ b/src/datacustomcode/einstein_predictions_config.py @@ -28,7 +28,7 @@ _E = TypeVar("_E", bound=EinsteinPredictions) -class EinsteinPredictionsObjectConfig(CredentialsObjectConfig, Generic[_E]): +class EinsteinPredictionsObjectConfig(CredentialsObjectConfig[_E], Generic[_E]): type_to_create: ClassVar[Type[EinsteinPredictions]] = EinsteinPredictions # type: ignore[type-abstract] diff --git a/src/datacustomcode/function/feature_types/chunking.py b/src/datacustomcode/function/feature_types/chunking.py index 31a1ccf..994489e 100644 --- a/src/datacustomcode/function/feature_types/chunking.py +++ b/src/datacustomcode/function/feature_types/chunking.py @@ -16,6 +16,7 @@ """ Pydantic models for Search Index Chunking V1 """ + from enum import Enum from typing import ( Dict, diff --git a/src/datacustomcode/llm_gateway/__init__.py b/src/datacustomcode/llm_gateway/__init__.py index ea8b4af..89d8f59 100644 --- a/src/datacustomcode/llm_gateway/__init__.py +++ b/src/datacustomcode/llm_gateway/__init__.py @@ -15,8 +15,14 @@ from datacustomcode.llm_gateway.base import LLMGateway from datacustomcode.llm_gateway.default import DefaultLLMGateway +from datacustomcode.llm_gateway.errors import LLMGatewayCallError +from datacustomcode.llm_gateway.spark_base import SparkLLMGateway +from datacustomcode.llm_gateway.spark_default import DefaultSparkLLMGateway __all__ = [ "DefaultLLMGateway", + "DefaultSparkLLMGateway", "LLMGateway", + "LLMGatewayCallError", + "SparkLLMGateway", ] diff --git a/src/datacustomcode/proxy/base.py b/src/datacustomcode/llm_gateway/errors.py similarity index 55% rename from src/datacustomcode/proxy/base.py rename to src/datacustomcode/llm_gateway/errors.py index 71cf314..0d2070b 100644 --- a/src/datacustomcode/proxy/base.py +++ b/src/datacustomcode/llm_gateway/errors.py @@ -12,13 +12,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Exceptions raised by LLM Gateway implementations.""" + from __future__ import annotations -from abc import ABC +from typing import Optional -from datacustomcode.mixin import UserExtendableNamedConfigMixin +class LLMGatewayCallError(RuntimeError): + """Raised when an LLM Gateway call returns an error.""" -class BaseProxyAccessLayer(ABC, UserExtendableNamedConfigMixin): - def __init__(self): - pass + def __init__( + self, + message: str, + *, + status: Optional[object] = None, + error_code: Optional[str] = None, + error_message: Optional[str] = None, + ) -> None: + super().__init__(message) + self.status = status + self.error_code = error_code + self.error_message = error_message diff --git a/src/datacustomcode/llm_gateway/spark_base.py b/src/datacustomcode/llm_gateway/spark_base.py new file mode 100644 index 0000000..ab3fb7b --- /dev/null +++ b/src/datacustomcode/llm_gateway/spark_base.py @@ -0,0 +1,62 @@ +# Copyright (c) 2025, Salesforce, Inc. +# SPDX-License-Identifier: Apache-2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Optional, + Union, +) + +from datacustomcode.mixin import UserExtendableNamedConfigMixin + +if TYPE_CHECKING: + from pyspark.sql import Column + + +class SparkLLMGateway(ABC, UserExtendableNamedConfigMixin): + CONFIG_NAME: str + + def __init__(self, **kwargs: Any) -> None: + pass + + @abstractmethod + def llm_gateway_generate_text( + self, + prompt: str, + model_id: Optional[str] = None, + ) -> str: + """Issue a one-shot LLM Gateway call and return the generated text.""" + + @abstractmethod + def llm_gateway_generate_text_col( + self, + template: str, + values: Union[Dict[str, "Column"], "Column"], + model_id: Optional[str] = None, + ) -> "Column": + """Build a Spark ``Column`` that invokes the LLM Gateway per row and + yields a struct ``{status, response, error_code, error_message}``. + + Select an individual field, e.g. + ``llm_gateway_generate_text_col(...)["response"]``. Returning a struct + means a single failing row doesn't abort the Spark job. + Failing row leaves the rest of the DataFrame intact — callers can + inspect ``status`` / ``error_code`` per row instead of having the + Spark job abort. + """ diff --git a/src/datacustomcode/llm_gateway/spark_default.py b/src/datacustomcode/llm_gateway/spark_default.py new file mode 100644 index 0000000..922d434 --- /dev/null +++ b/src/datacustomcode/llm_gateway/spark_default.py @@ -0,0 +1,184 @@ +# Copyright (c) 2025, Salesforce, Inc. +# SPDX-License-Identifier: Apache-2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Optional, + Union, +) + +from datacustomcode.llm_gateway.spark_base import SparkLLMGateway + +if TYPE_CHECKING: + from pyspark.sql import Column + + from datacustomcode.llm_gateway.base import LLMGateway + from datacustomcode.llm_gateway.types.generate_text_response import ( + GenerateTextResponse, + ) + + +_DEFAULT_LLM_MODEL_ID = "sfdc_ai__DefaultGPT4Omni" + +_STATUS_SUCCESS = "SUCCESS" +_STATUS_ERROR = "ERROR" + + +class DefaultSparkLLMGateway(SparkLLMGateway): + + CONFIG_NAME = "DefaultSparkLLMGateway" + + def __init__( + self, + llm_gateway: Optional["LLMGateway"] = None, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + if llm_gateway is None: + llm_gateway = _build_underlying_gateway() + self._llm_gateway: "LLMGateway" = llm_gateway + + def llm_gateway_generate_text( + self, + prompt: str, + model_id: Optional[str] = None, + ) -> str: + return _invoke_llm_gateway(self._llm_gateway, prompt, model_id) + + def llm_gateway_generate_text_col( + self, + template: str, + values: Union[Dict[str, "Column"], "Column"], + model_id: Optional[str] = None, + ) -> "Column": + """Build a per-row UDF that returns a struct ``{status, response, + error_code, error_message}`` so per-row failures do not abort the + Spark job. Callers select the field they want, e.g. + ``llm_gateway_generate_text_col(...)["response"]``. + """ + from pyspark.sql.functions import struct, udf + from pyspark.sql.types import ( + StringType, + StructField, + StructType, + ) + + if isinstance(values, dict): + values_col = struct(*[v.alias(k) for k, v in values.items()]) + else: + values_col = values + + gateway = self._llm_gateway + result_schema = StructType( + [ + StructField("status", StringType(), True), + StructField("response", StringType(), True), + StructField("error_code", StringType(), True), + StructField("error_message", StringType(), True), + ] + ) + + def _generate(values_row: Any) -> Dict[str, Optional[str]]: + if values_row is None: + return { + "status": _STATUS_ERROR, + "response": None, + "error_code": None, + "error_message": "values column was null for this row", + } + subs = ( + values_row.asDict() + if hasattr(values_row, "asDict") + else dict(values_row) + ) + prompt = template.format(**subs) + return _invoke_llm_gateway_as_struct(gateway, prompt, model_id) + + return udf(_generate, result_schema)(values_col) + + +def _build_underlying_gateway() -> "LLMGateway": + from datacustomcode.llm_gateway_config import llm_gateway_config + + cfg = llm_gateway_config.llm_gateway_config + if cfg is None: + raise RuntimeError( + "llm_gateway_config is not configured. Add an 'llm_gateway_config' " + "section to config.yaml." + ) + return cfg.to_object() + + +def _call_llm_gateway( + gateway: "LLMGateway", + prompt: str, + model_id: Optional[str], +) -> "GenerateTextResponse": + """Build the request and dispatch it to the underlying gateway.""" + from datacustomcode.llm_gateway.types.generate_text_request_builder import ( + GenerateTextRequestBuilder, + ) + + request = ( + GenerateTextRequestBuilder() + .set_prompt(prompt) + .set_model(model_id or _DEFAULT_LLM_MODEL_ID) + .build() + ) + return gateway.generate_text(request) + + +def _invoke_llm_gateway( + gateway: "LLMGateway", + prompt: str, + model_id: Optional[str], +) -> str: + from datacustomcode.llm_gateway.errors import LLMGatewayCallError + + response = _call_llm_gateway(gateway, prompt, model_id) + if response.is_error: + raise LLMGatewayCallError( + f"LLM Gateway call failed: status_code={response.status_code}, " + f"error_code={response.error_code!r}, " + f"message={response.data!r}", + status=response.status_code, + error_code=response.error_code or None, + error_message=str(response.data) if response.data else None, + ) + return response.text + + +def _invoke_llm_gateway_as_struct( + gateway: "LLMGateway", + prompt: str, + model_id: Optional[str], +) -> Dict[str, Optional[str]]: + response = _call_llm_gateway(gateway, prompt, model_id) + if response.is_error: + return { + "status": _STATUS_ERROR, + "response": None, + "error_code": response.error_code or None, + "error_message": str(response.data) if response.data else None, + } + return { + "status": _STATUS_SUCCESS, + "response": response.text, + "error_code": None, + "error_message": None, + } diff --git a/src/datacustomcode/llm_gateway_config.py b/src/datacustomcode/llm_gateway_config.py index a65d0eb..3b4ebb5 100644 --- a/src/datacustomcode/llm_gateway_config.py +++ b/src/datacustomcode/llm_gateway_config.py @@ -21,14 +21,20 @@ Union, ) -from datacustomcode.common_config import BaseConfig, default_config_file +from datacustomcode.common_config import ( + BaseConfig, + BaseObjectConfig, + default_config_file, +) from datacustomcode.einstein_platform_config import CredentialsObjectConfig from datacustomcode.llm_gateway.base import LLMGateway +from datacustomcode.llm_gateway.spark_base import SparkLLMGateway _E = TypeVar("_E", bound=LLMGateway) +_S = TypeVar("_S", bound=SparkLLMGateway) -class LLMGatewayObjectConfig(CredentialsObjectConfig, Generic[_E]): +class LLMGatewayObjectConfig(CredentialsObjectConfig[_E], Generic[_E]): type_to_create: ClassVar[Type[LLMGateway]] = LLMGateway # type: ignore[type-abstract] @@ -52,6 +58,41 @@ def merge( return self +class SparkLLMGatewayObjectConfig(BaseObjectConfig, Generic[_S]): + type_to_create: ClassVar[Type[SparkLLMGateway]] = SparkLLMGateway # type: ignore[type-abstract] + + def to_object(self) -> SparkLLMGateway: + type_ = self.type_to_create.subclass_from_config_name(self.type_config_name) + return type_(**self.options) + + +class SparkLLMGatewayConfig(BaseConfig): + spark_llm_gateway_config: Union[ + SparkLLMGatewayObjectConfig[SparkLLMGateway], None + ] = None + + def update(self, other: "SparkLLMGatewayConfig") -> "SparkLLMGatewayConfig": + def merge( + config_a: Union[SparkLLMGatewayObjectConfig, None], + config_b: Union[SparkLLMGatewayObjectConfig, None], + ) -> Union[SparkLLMGatewayObjectConfig, None]: + if config_a is not None and config_a.force: + return config_a + if config_b: + return config_b + return config_a + + self.spark_llm_gateway_config = merge( + self.spark_llm_gateway_config, other.spark_llm_gateway_config + ) + return self + + # Global LLM Gateway config instance llm_gateway_config = LLMGatewayConfig() llm_gateway_config.load(default_config_file()) + + +# Global Spark LLM Gateway config instance +spark_llm_gateway_config = SparkLLMGatewayConfig() +spark_llm_gateway_config.load(default_config_file()) diff --git a/src/datacustomcode/proxy/__init__.py b/src/datacustomcode/proxy/__init__.py deleted file mode 100644 index 93988ff..0000000 --- a/src/datacustomcode/proxy/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright (c) 2025, Salesforce, Inc. -# SPDX-License-Identifier: Apache-2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/src/datacustomcode/proxy/client/__init__.py b/src/datacustomcode/proxy/client/__init__.py deleted file mode 100644 index 93988ff..0000000 --- a/src/datacustomcode/proxy/client/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright (c) 2025, Salesforce, Inc. -# SPDX-License-Identifier: Apache-2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/src/datacustomcode/proxy/client/base.py b/src/datacustomcode/proxy/client/base.py deleted file mode 100644 index 85e304a..0000000 --- a/src/datacustomcode/proxy/client/base.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright (c) 2025, Salesforce, Inc. -# SPDX-License-Identifier: Apache-2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from __future__ import annotations - -from abc import abstractmethod - -from datacustomcode.proxy.base import BaseProxyAccessLayer - - -class BaseProxyClient(BaseProxyAccessLayer): - def __init__(self): - pass - - @abstractmethod - def call_llm_gateway(self, llmModelId: str, prompt: str, maxTokens: int) -> str: ... - - @abstractmethod - def llm_gateway_generate_text( - self, template, values, llmModelId: str, maxTokens: int - ): ... diff --git a/src/datacustomcode/templates/function/example/chunking_with_llm/entrypoint.py b/src/datacustomcode/templates/function/example/chunking_with_llm/entrypoint.py index 0a5dbb3..2502cdf 100644 --- a/src/datacustomcode/templates/function/example/chunking_with_llm/entrypoint.py +++ b/src/datacustomcode/templates/function/example/chunking_with_llm/entrypoint.py @@ -7,6 +7,9 @@ - Requires Runtime parameter (for agentic capabilities) - Type-safe with direct field access (no wrappers) - Automatic validation and conversion + +You can use your AI models configured in Salesforce to generate texts. +See README.md for how to test locally before deploying to Data Cloud. """ import logging diff --git a/src/datacustomcode/templates/function/example/chunking_with_prediction/entrypoint.py b/src/datacustomcode/templates/function/example/chunking_with_prediction/entrypoint.py index df9b780..d7fe2fa 100644 --- a/src/datacustomcode/templates/function/example/chunking_with_prediction/entrypoint.py +++ b/src/datacustomcode/templates/function/example/chunking_with_prediction/entrypoint.py @@ -12,6 +12,9 @@ Type: Regression Input: Year_Built__c (numeric) Output: Predicted_SalePrice + +You can use your AI models configured in Salesforce to make predictions. +See README.md for how to test locally before deploying to Data Cloud. """ import logging diff --git a/src/datacustomcode/templates/function/payload/entrypoint.py b/src/datacustomcode/templates/function/payload/entrypoint.py index a1cd685..5231174 100644 --- a/src/datacustomcode/templates/function/payload/entrypoint.py +++ b/src/datacustomcode/templates/function/payload/entrypoint.py @@ -65,13 +65,9 @@ def make_einstein_prediction(runtime: Runtime) -> None: ) -def generate_text(runtime: Runtime): +def generate_text(runtime: Runtime, prompt: str, model: str = "sfdc_ai__DefaultGPT52"): builder = GenerateTextRequestBuilder() - llm_request = ( - builder.set_prompt("Generate 2 dog names") - .set_model("sfdc_ai__DefaultGPT52") - .build() - ) + llm_request = builder.set_prompt(prompt).set_model(model).build() llm_response = runtime.llm_gateway.generate_text(llm_request) logger.info( f"LLM Gateway generate text results - success: [{llm_response.is_success}] " @@ -88,13 +84,16 @@ def function(request: dict, runtime: Runtime) -> dict: current_seq_no = 1 # Start sequence number from 1 """ - You can use your AI models configured in Salesforce - to generate texts or predict an outcome. - First configure an external client app before using these AI APIs - https://developer.salesforce.com/docs/ai/agentforce/guide/agent-api-get-started.html#create-a-salesforce-app" + You can use your AI models configured in Salesforce to generate texts + or predict an outcome. See README.md for how to test locally before + deploying to Data Cloud. + + Example: + + >>> generated_text = generate_text(runtime, "Generate a greeting message") + ... prediction = make_einstein_prediction(runtime) + """ - # generate_text(runtime) - # make_einstein_prediction(runtime) for item in items: # Item is DocElement as dict diff --git a/src/datacustomcode/templates/script/payload/entrypoint.py b/src/datacustomcode/templates/script/payload/entrypoint.py index 10ba1d7..ea61096 100644 --- a/src/datacustomcode/templates/script/payload/entrypoint.py +++ b/src/datacustomcode/templates/script/payload/entrypoint.py @@ -12,6 +12,34 @@ def main(): # Perform transformations on the DataFrame df_upper1 = df.withColumn("description__c", upper(col("description__c"))) + """ + You can use your AI models configured in Salesforce to generate column + values. See README.md for how to test locally before deploying to Data Cloud. + + Example (the per-row helper returns a struct + ``{status, response, error_code, error_message}`` — pick the field you + want with ``[...]``): + + >>> from datacustomcode.client import llm_gateway_generate_text_col + df_generated = df.withColumn( + ... "greeting__c", + ... llm_gateway_generate_text_col( + ... "In one sentence, greet {name} from {city}.", + ... {"name": col("name__c"), "city": col("homecity__c")}, + ... model_id="sfdc_ai__DefaultGPT4Omni", + ... )["response"], + ... ) + + You can also invoke the LLM with a literal plain text prompt — no + ``{field}`` substitution is performed on this string. + + Example: + + >>> generated_text = client.llm_gateway_generate_text( + ... prompt, model_id + ... ) + """ + # Drop specific columns related to relationships df_upper1 = df_upper1.drop("sfdcorganizationid__c") df_upper1 = df_upper1.drop("kq_id__c") diff --git a/tests/test_client.py b/tests/test_client.py index c2cf46a..dd8041b 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -9,6 +9,7 @@ Client, DataCloudAccessLayerException, DataCloudObjectType, + llm_gateway_generate_text_col, ) from datacustomcode.config import ( AccessLayerObjectConfig, @@ -253,6 +254,88 @@ def test_read_pattern_flow(self, reset_client, mock_spark): assert "source_dmo" in client._data_layer_history[DataCloudObjectType.DMO] +class TestClientLlmGatewayGenerateText: + + @patch("datacustomcode.client._build_spark_llm_gateway") + def test_forwards_args_to_spark_llm_gateway(self, mock_build_gateway, reset_client): + mock_spark_gateway = MagicMock() + mock_spark_gateway.llm_gateway_generate_text.return_value = "reply" + mock_build_gateway.return_value = mock_spark_gateway + + reader = MagicMock(spec=BaseDataCloudReader) + writer = MagicMock(spec=BaseDataCloudWriter) + client = Client(reader=reader, writer=writer) + + result = client.llm_gateway_generate_text("ping", model_id="test-model") + + assert result == "reply" + mock_spark_gateway.llm_gateway_generate_text.assert_called_once_with( + "ping", model_id="test-model" + ) + + @patch("datacustomcode.client._build_spark_llm_gateway") + def test_gateway_is_built_lazily_and_cached(self, mock_build_gateway, reset_client): + mock_spark_gateway = MagicMock() + mock_spark_gateway.llm_gateway_generate_text.return_value = "ok" + mock_build_gateway.return_value = mock_spark_gateway + + reader = MagicMock(spec=BaseDataCloudReader) + writer = MagicMock(spec=BaseDataCloudWriter) + client = Client(reader=reader, writer=writer) + + mock_build_gateway.assert_not_called() + + client.llm_gateway_generate_text("a") + client.llm_gateway_generate_text("b") + + mock_build_gateway.assert_called_once_with() + assert mock_spark_gateway.llm_gateway_generate_text.call_count == 2 + + @patch("datacustomcode.client._build_spark_llm_gateway") + def test_uses_injected_spark_llm_gateway_without_config_lookup( + self, mock_build_gateway, reset_client + ): + injected = MagicMock() + injected.llm_gateway_generate_text.return_value = "from-injected" + + reader = MagicMock(spec=BaseDataCloudReader) + writer = MagicMock(spec=BaseDataCloudWriter) + client = Client(reader=reader, writer=writer, spark_llm_gateway=injected) + + result = client.llm_gateway_generate_text("hello") + + assert result == "from-injected" + injected.llm_gateway_generate_text.assert_called_once_with( + "hello", model_id=None + ) + mock_build_gateway.assert_not_called() + + +class TestLLMGatewayGenerateTextCol: + """The module-level ``llm_gateway_generate_text_col`` is a thin wrapper + that resolves the client-owned :class:`SparkLLMGateway` and delegates. + """ + + @patch("datacustomcode.client._build_spark_llm_gateway") + def test_delegates_to_spark_llm_gateway(self, mock_build_gateway): + mock_spark_gateway = MagicMock() + sentinel_col = MagicMock(name="col") + mock_spark_gateway.llm_gateway_generate_text_col.return_value = sentinel_col + mock_build_gateway.return_value = mock_spark_gateway + + reader = MagicMock(spec=BaseDataCloudReader) + writer = MagicMock(spec=BaseDataCloudWriter) + Client(reader=reader, writer=writer) + + values = {"name": MagicMock()} + result = llm_gateway_generate_text_col("Greet {name}", values, model_id="m") + + assert result is sentinel_col + mock_spark_gateway.llm_gateway_generate_text_col.assert_called_once_with( + "Greet {name}", values, model_id="m" + ) + + # Add tests for DefaultSparkSessionProvider class TestDefaultSparkSessionProvider: diff --git a/tests/test_spark_llm_gateway.py b/tests/test_spark_llm_gateway.py new file mode 100644 index 0000000..090ef85 --- /dev/null +++ b/tests/test_spark_llm_gateway.py @@ -0,0 +1,300 @@ +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from datacustomcode.llm_gateway import DefaultSparkLLMGateway, LLMGatewayCallError +from datacustomcode.llm_gateway.spark_default import ( + _STATUS_ERROR, + _STATUS_SUCCESS, + _build_underlying_gateway, + _invoke_llm_gateway, + _invoke_llm_gateway_as_struct, +) +from datacustomcode.llm_gateway.types.generate_text_response import GenerateTextResponse + + +def _success_response(text: str = "ok") -> GenerateTextResponse: + return GenerateTextResponse( + status_code=200, data={"generation": {"generatedText": text}} + ) + + +def _error_response( + status_code: int = 500, error_code: str = "INTERNAL_ERROR" +) -> GenerateTextResponse: + return GenerateTextResponse(status_code=status_code, data={"errorCode": error_code}) + + +class TestDefaultSparkLLMGatewayConstruction: + """Construction wires an underlying ``LLMGateway``.""" + + def test_uses_injected_llm_gateway_when_provided(self): + injected = MagicMock() + gateway = DefaultSparkLLMGateway(llm_gateway=injected) + assert gateway._llm_gateway is injected + + @patch("datacustomcode.llm_gateway.spark_default._build_underlying_gateway") + def test_falls_back_to_config_when_no_gateway_injected(self, mock_build): + config_built = MagicMock() + mock_build.return_value = config_built + + gateway = DefaultSparkLLMGateway() + + mock_build.assert_called_once_with() + assert gateway._llm_gateway is config_built + + +class TestBuildUnderlyingGateway: + """``_build_underlying_gateway`` resolves the config-defined ``LLMGateway``.""" + + def test_returns_object_from_config(self): + with patch( + "datacustomcode.llm_gateway_config.llm_gateway_config" + ) as mock_obj_config: + mock_gateway = MagicMock() + mock_obj_config.llm_gateway_config.to_object.return_value = mock_gateway + + assert _build_underlying_gateway() is mock_gateway + mock_obj_config.llm_gateway_config.to_object.assert_called_once_with() + + def test_raises_when_config_missing(self): + with patch( + "datacustomcode.llm_gateway_config.llm_gateway_config" + ) as mock_obj_config: + mock_obj_config.llm_gateway_config = None + with pytest.raises(RuntimeError, match="llm_gateway_config"): + _build_underlying_gateway() + + +class TestDefaultSparkLLMGatewayGenerateText: + + def test_forwards_prompt_and_model(self): + mock_inner = MagicMock() + mock_inner.generate_text.return_value = _success_response("hello back") + gateway = DefaultSparkLLMGateway(llm_gateway=mock_inner) + + result = gateway.llm_gateway_generate_text("hello", model_id="m1") + + assert result == "hello back" + sent = mock_inner.generate_text.call_args.args[0] + assert sent.prompt == "hello" + assert sent.model_name == "m1" + + def test_applies_default_model_when_omitted(self): + mock_inner = MagicMock() + mock_inner.generate_text.return_value = _success_response("ok") + gateway = DefaultSparkLLMGateway(llm_gateway=mock_inner) + + gateway.llm_gateway_generate_text("just a prompt") + + sent = mock_inner.generate_text.call_args.args[0] + assert sent.model_name == "sfdc_ai__DefaultGPT4Omni" + + +class TestDefaultSparkLLMGatewayGenerateTextCol: + + @patch("pyspark.sql.functions.udf") + @patch("pyspark.sql.functions.struct") + def test_dict_values_built_into_struct_and_wrapped_in_udf( + self, mock_struct, mock_udf + ): + sentinel_struct_col = MagicMock(name="struct_col") + mock_struct.return_value = sentinel_struct_col + sentinel_udf = MagicMock(name="udf") + sentinel_applied = MagicMock(name="udf_applied") + sentinel_udf.return_value = sentinel_applied + mock_udf.return_value = sentinel_udf + + mock_inner = MagicMock() + mock_inner.generate_text.return_value = _success_response("row-out") + gateway = DefaultSparkLLMGateway(llm_gateway=mock_inner) + + name_col, city_col = MagicMock(name="name_col"), MagicMock(name="city_col") + name_aliased, city_aliased = ( + MagicMock(name="name_aliased"), + MagicMock(name="city_aliased"), + ) + name_col.alias.return_value = name_aliased + city_col.alias.return_value = city_aliased + + result = gateway.llm_gateway_generate_text_col( + "Greet {name} from {city}.", + {"name": name_col, "city": city_col}, + model_id="test-model", + ) + + name_col.alias.assert_called_once_with("name") + city_col.alias.assert_called_once_with("city") + mock_struct.assert_called_once_with(name_aliased, city_aliased) + mock_udf.assert_called_once() + sentinel_udf.assert_called_once_with(sentinel_struct_col) + assert result is sentinel_applied + + udf_fn = mock_udf.call_args.args[0] + row = MagicMock() + row.asDict.return_value = {"name": "Ada", "city": "London"} + out = udf_fn(row) + + assert out == { + "status": _STATUS_SUCCESS, + "response": "row-out", + "error_code": None, + "error_message": None, + } + sent = mock_inner.generate_text.call_args.args[0] + assert sent.prompt == "Greet Ada from London." + assert sent.model_name == "test-model" + + @patch("pyspark.sql.functions.udf") + @patch("pyspark.sql.functions.struct") + def test_column_values_passed_through_without_struct(self, mock_struct, mock_udf): + from pyspark.sql import Column + + existing_col = MagicMock(spec=Column) + sentinel_udf = MagicMock(name="udf") + sentinel_udf.return_value = MagicMock(name="udf_applied") + mock_udf.return_value = sentinel_udf + + gateway = DefaultSparkLLMGateway(llm_gateway=MagicMock()) + + gateway.llm_gateway_generate_text_col("Greet {name}", existing_col) + + mock_struct.assert_not_called() + sentinel_udf.assert_called_once_with(existing_col) + + @patch("pyspark.sql.functions.udf") + @patch("pyspark.sql.functions.struct") + def test_udf_returns_error_struct_for_null_row(self, mock_struct, mock_udf): + mock_struct.return_value = MagicMock() + mock_udf.return_value = MagicMock() + mock_inner = MagicMock() + gateway = DefaultSparkLLMGateway(llm_gateway=mock_inner) + + gateway.llm_gateway_generate_text_col("template", {"placeholder": MagicMock()}) + + udf_fn = mock_udf.call_args.args[0] + out = udf_fn(None) + assert out["status"] == _STATUS_ERROR + assert out["response"] is None + assert "null" in out["error_message"].lower() + mock_inner.generate_text.assert_not_called() + + @patch("pyspark.sql.functions.udf") + @patch("pyspark.sql.functions.struct") + def test_udf_returns_error_struct_on_http_error(self, mock_struct, mock_udf): + """Per-row HTTP errors are returned as ``status="ERROR"`` structs so + one bad row does not abort the Spark job.""" + mock_struct.return_value = MagicMock() + mock_udf.return_value = MagicMock() + mock_inner = MagicMock() + mock_inner.generate_text.return_value = _error_response( + status_code=503, error_code="UNAVAILABLE" + ) + gateway = DefaultSparkLLMGateway(llm_gateway=mock_inner) + + gateway.llm_gateway_generate_text_col("Greet {name}", {"name": MagicMock()}) + + udf_fn = mock_udf.call_args.args[0] + row = MagicMock() + row.asDict.return_value = {"name": "Ada"} + out = udf_fn(row) + + assert out["status"] == _STATUS_ERROR + assert out["response"] is None + assert out["error_code"] == "UNAVAILABLE" + assert out["error_message"] is not None + + +class TestInvokeLLMGateway: + + def test_returns_response_text(self): + mock_inner = MagicMock() + mock_inner.generate_text.return_value = _success_response("done") + + assert _invoke_llm_gateway(mock_inner, "prompt", "model") == "done" + sent = mock_inner.generate_text.call_args.args[0] + assert sent.prompt == "prompt" + assert sent.model_name == "model" + + def test_uses_default_model_when_none(self): + mock_inner = MagicMock() + mock_inner.generate_text.return_value = _success_response("ok") + + _invoke_llm_gateway(mock_inner, "prompt", None) + sent = mock_inner.generate_text.call_args.args[0] + assert sent.model_name == "sfdc_ai__DefaultGPT4Omni" + + def test_raises_llm_gateway_call_error_on_error_response(self): + """``is_error`` responses surface as ``LLMGatewayCallError`` with the + status code and error code attached for programmatic inspection.""" + mock_inner = MagicMock() + mock_inner.generate_text.return_value = _error_response( + status_code=503, error_code="UNAVAILABLE" + ) + + with pytest.raises(LLMGatewayCallError) as excinfo: + _invoke_llm_gateway(mock_inner, "prompt", "model") + + assert excinfo.value.status == 503 + assert excinfo.value.error_code == "UNAVAILABLE" + assert "503" in str(excinfo.value) + assert "UNAVAILABLE" in str(excinfo.value) + + +class TestInvokeLLMGatewayAsStruct: + """Non-raising variant of ``_invoke_llm_gateway`` used by the per-row UDF. + Both SUCCESS and ERROR cases land in the same struct shape so callers can + select fields uniformly.""" + + def test_success_returns_success_struct(self): + mock_inner = MagicMock() + mock_inner.generate_text.return_value = _success_response("howdy") + + out = _invoke_llm_gateway_as_struct(mock_inner, "prompt", "model") + + assert out == { + "status": _STATUS_SUCCESS, + "response": "howdy", + "error_code": None, + "error_message": None, + } + + def test_error_returns_error_struct_without_raising(self): + mock_inner = MagicMock() + mock_inner.generate_text.return_value = _error_response( + status_code=503, error_code="UNAVAILABLE" + ) + + out = _invoke_llm_gateway_as_struct(mock_inner, "prompt", "model") + + assert out["status"] == _STATUS_ERROR + assert out["response"] is None + assert out["error_code"] == "UNAVAILABLE" + assert out["error_message"] is not None + + def test_uses_default_model_when_none(self): + mock_inner = MagicMock() + mock_inner.generate_text.return_value = _success_response("ok") + + _invoke_llm_gateway_as_struct(mock_inner, "prompt", None) + sent = mock_inner.generate_text.call_args.args[0] + assert sent.model_name == "sfdc_ai__DefaultGPT4Omni" + + +class TestDefaultSparkLLMGatewayGenerateTextErrorHandling: + """The scalar generate_text path raises when the underlying gateway errors.""" + + def test_raises_on_error_response(self): + mock_inner = MagicMock() + mock_inner.generate_text.return_value = _error_response( + status_code=429, error_code="RATE_LIMITED" + ) + gateway = DefaultSparkLLMGateway(llm_gateway=mock_inner) + + with pytest.raises(LLMGatewayCallError) as excinfo: + gateway.llm_gateway_generate_text("hello") + + assert excinfo.value.status == 429 + assert excinfo.value.error_code == "RATE_LIMITED"