Spaces:
Runtime error
Runtime error
| """Langchain Wrapper around Sambanova embedding APIs.""" | |
| import json | |
| from typing import Dict, Generator, List, Optional | |
| import requests | |
| from langchain_core.embeddings import Embeddings | |
| from langchain_core.pydantic_v1 import BaseModel | |
| from langchain_core.utils import get_from_dict_or_env, pre_init | |
class SambaStudioEmbeddings(BaseModel, Embeddings):
    """SambaNova embedding models.

    To use, you should have the environment variables
    ``SAMBASTUDIO_EMBEDDINGS_BASE_URL``, ``SAMBASTUDIO_EMBEDDINGS_BASE_URI``
    ``SAMBASTUDIO_EMBEDDINGS_PROJECT_ID``, ``SAMBASTUDIO_EMBEDDINGS_ENDPOINT_ID``,
    ``SAMBASTUDIO_EMBEDDINGS_API_KEY``
    set with your personal sambastudio variable or pass it as a named parameter
    to the constructor.

    Example:
        .. code-block:: python

            from langchain_community.embeddings import SambaStudioEmbeddings

            embeddings = SambaStudioEmbeddings(sambastudio_embeddings_base_url=base_url,
                                          sambastudio_embeddings_base_uri=base_uri,
                                          sambastudio_embeddings_project_id=project_id,
                                          sambastudio_embeddings_endpoint_id=endpoint_id,
                                          sambastudio_embeddings_api_key=api_key,
                                          batch_size=32)
            (or)

            embeddings = SambaStudioEmbeddings(batch_size=32)

            (or)

            # CoE example
            embeddings = SambaStudioEmbeddings(
                batch_size=1,
                model_kwargs={
                    'select_expert':'e5-mistral-7b-instruct'
                }
            )
    """

    sambastudio_embeddings_base_url: str = ''
    """Base url to use"""

    sambastudio_embeddings_base_uri: str = ''
    """endpoint base uri"""

    sambastudio_embeddings_project_id: str = ''
    """Project id on sambastudio for model"""

    sambastudio_embeddings_endpoint_id: str = ''
    """endpoint id on sambastudio for model"""

    sambastudio_embeddings_api_key: str = ''
    """sambastudio api key"""

    model_kwargs: dict = {}
    """Key word arguments to pass to the model."""

    batch_size: int = 32
    """Batch size for the embedding models"""

    # Without this decorator the method is a plain, never-invoked function and
    # the environment-variable fallback described in the class docstring would
    # silently not happen.
    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        """Fill the connection settings from ``values`` or the environment.

        Each setting is taken from ``values`` when present and truthy,
        otherwise from the matching ``SAMBASTUDIO_EMBEDDINGS_*`` environment
        variable. Only the base URI has a fallback default
        (``api/predict/generic``).

        Args:
            values: the raw keyword arguments given to the constructor.

        Returns:
            The same dict with the five connection settings resolved.
        """
        values['sambastudio_embeddings_base_url'] = get_from_dict_or_env(
            values, 'sambastudio_embeddings_base_url', 'SAMBASTUDIO_EMBEDDINGS_BASE_URL'
        )
        values['sambastudio_embeddings_base_uri'] = get_from_dict_or_env(
            values,
            'sambastudio_embeddings_base_uri',
            'SAMBASTUDIO_EMBEDDINGS_BASE_URI',
            default='api/predict/generic',
        )
        values['sambastudio_embeddings_project_id'] = get_from_dict_or_env(
            values,
            'sambastudio_embeddings_project_id',
            'SAMBASTUDIO_EMBEDDINGS_PROJECT_ID',
        )
        values['sambastudio_embeddings_endpoint_id'] = get_from_dict_or_env(
            values,
            'sambastudio_embeddings_endpoint_id',
            'SAMBASTUDIO_EMBEDDINGS_ENDPOINT_ID',
        )
        values['sambastudio_embeddings_api_key'] = get_from_dict_or_env(
            values, 'sambastudio_embeddings_api_key', 'SAMBASTUDIO_EMBEDDINGS_API_KEY'
        )
        return values

    def _get_tuning_params(self) -> str:
        """Get the tuning parameters to use when calling the model.

        For a base URI containing ``api/v2/predict/generic`` the model kwargs
        are sent as-is; for every other URI each value is wrapped as
        ``{'type': <type name>, 'value': <str(value)>}``.

        Returns:
            The tuning parameters as a JSON string.
        """
        if 'api/v2/predict/generic' in self.sambastudio_embeddings_base_uri:
            tuning_params_dict = self.model_kwargs
        else:
            tuning_params_dict = {
                k: {'type': type(v).__name__, 'value': str(v)}
                for k, v in self.model_kwargs.items()
            }
        return json.dumps(tuning_params_dict)

    def _get_full_url(self, path: str) -> str:
        """Return the full API URL for a given path.

        :param str path: the sub-path
        :returns: the full API URL for the sub-path
        :rtype: str
        """
        return f'{self.sambastudio_embeddings_base_url}/{self.sambastudio_embeddings_base_uri}/{path}'  # noqa: E501

    def _iterate_over_batches(self, texts: List[str], batch_size: int) -> Generator:
        """Generator for creating batches in the embed documents method.

        Args:
            texts (List[str]): list of strings to embed
            batch_size (int): number of strings per batch.

        Yields:
            List[str]: consecutive slices of ``texts`` holding at most
            ``batch_size`` strings each.
        """
        for i in range(0, len(texts), batch_size):
            yield texts[i : i + batch_size]

    def _resolve_api_flavor(self) -> str:
        """Return the supported API marker contained in the base URI.

        The markers are tried in the order nlp, v2 generic, generic.

        Raises:
            ValueError: if the base URI contains none of the supported markers.
        """
        supported = ('api/predict/nlp', 'api/v2/predict/generic', 'api/predict/generic')
        for marker in supported:
            if marker in self.sambastudio_embeddings_base_uri:
                return marker
        raise ValueError(
            f'handling of endpoint uri: {self.sambastudio_embeddings_base_uri} not implemented'  # noqa: E501
        )

    def _request_embeddings(
        self,
        http_session: requests.Session,
        url: str,
        api_flavor: str,
        params: Dict,
        batch: List[str],
    ) -> List[List[float]]:
        """Send one batch to the endpoint and return its embeddings.

        Args:
            http_session: open session used to issue the POST request.
            url: full endpoint URL.
            api_flavor: marker returned by ``_resolve_api_flavor``.
            params: tuning parameters sent under the ``params`` key.
            batch: the texts to embed in this request.

        Returns:
            The embeddings read from the response for this batch.

        Raises:
            RuntimeError: if the endpoint answers with a non-200 status code.
            KeyError: if the expected key is missing from the response body.
        """
        # Each API flavor names the request field and the response field
        # differently; everything else about the call is identical.
        if api_flavor == 'api/predict/nlp':
            data = {'inputs': batch, 'params': params}
            response_key = 'data'
        elif api_flavor == 'api/v2/predict/generic':
            items = [{'id': f'item{i}', 'value': item} for i, item in enumerate(batch)]
            data = {'items': items, 'params': params}
            response_key = 'items'
        else:
            data = {'instances': batch, 'params': params}
            response_key = 'predictions'

        response = http_session.post(
            url,
            headers={'key': self.sambastudio_embeddings_api_key},
            json=data,
        )
        if response.status_code != 200:
            raise RuntimeError(
                f'Sambanova /complete call failed with status code '
                f'{response.status_code}.\n Details: {response.text}'
            )

        # Parse the body once so the error path reports the same payload.
        payload = response.json()
        try:
            results = payload[response_key]
            if api_flavor == 'api/v2/predict/generic':
                # v2 wraps every embedding in an item object.
                results = [item['value'] for item in results]
        except KeyError as err:
            raise KeyError(
                f"'{response_key}' not found in endpoint response",
                payload,
            ) from err
        return results

    def embed_documents(
        self, texts: List[str], batch_size: Optional[int] = None
    ) -> List[List[float]]:
        """Return one embedding per input text.

        Args:
            texts (`List[str]`): List of texts to encode
            batch_size (`int`): Number of texts sent per request; defaults to
                the instance's ``batch_size`` when ``None``.

        Returns:
            `List[List[float]]`: the embeddings, in the order of ``texts``.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1 or the base URI is
                not one of the supported API flavors.
            RuntimeError: if the endpoint answers with a non-200 status code.
            KeyError: if the expected key is missing from a response body.
        """
        if batch_size is None:
            batch_size = self.batch_size
        if batch_size < 1:
            # A zero step makes range() fail obscurely and a negative one
            # would silently yield no batches at all.
            raise ValueError(f'batch_size must be a positive integer, got {batch_size}')

        endpoint_path = (
            f'{self.sambastudio_embeddings_project_id}/'
            f'{self.sambastudio_embeddings_endpoint_id}'
        )
        url = self._get_full_url(endpoint_path)
        params = json.loads(self._get_tuning_params())
        # Resolved before any request so an unsupported URI fails even when
        # ``texts`` is empty.
        api_flavor = self._resolve_api_flavor()

        embeddings: List[List[float]] = []
        # The context manager closes the session's pooled connections.
        with requests.Session() as http_session:
            for batch in self._iterate_over_batches(texts, batch_size):
                embeddings.extend(
                    self._request_embeddings(http_session, url, api_flavor, params, batch)
                )
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Return the embedding of a single text.

        Args:
            text (`str`): the text to encode

        Returns:
            `List[float]`: the first embedding found in the endpoint response.

        Raises:
            ValueError: if the base URI is not one of the supported API flavors.
            RuntimeError: if the endpoint answers with a non-200 status code.
            KeyError: if the expected key is missing from the response body.
        """
        endpoint_path = (
            f'{self.sambastudio_embeddings_project_id}/'
            f'{self.sambastudio_embeddings_endpoint_id}'
        )
        url = self._get_full_url(endpoint_path)
        params = json.loads(self._get_tuning_params())
        api_flavor = self._resolve_api_flavor()

        with requests.Session() as http_session:
            results = self._request_embeddings(http_session, url, api_flavor, params, [text])
        return results[0]