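// Provider-agnostic LLM helpers: configuration, chat completions (streaming and
// non-streaming), embeddings, and moderation, with retry/backoff and stop-word
// handling. Defaults to the Hugging Face Inference API; Ollama, Together.ai, and
// OpenAI configurations are included below, commented out.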
import { HfInference } from "@huggingface/inference";

export const LLM_CONFIG = {
  /* Hugging Face config: */
  ollama: false,
  huggingface: true,
  url: "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-8B-Instruct",
  chatModel: "meta-llama/Meta-Llama-3-8B-Instruct",
  embeddingModel:
    "https://api-inference.huggingface.co/models/mixedbread-ai/mxbai-embed-large-v1",
  embeddingDimension: 1024,

  /* Ollama (local) config: */
  // ollama: true,
  // url: 'http://127.0.0.1:11434',
  // chatModel: 'llama3' as const,
  // embeddingModel: 'mxbai-embed-large',
  // embeddingDimension: 1024,
  // embeddingModel: 'llama3',
  // embeddingDimension: 4096,

  /* Together.ai config:
  ollama: false,
  url: 'https://api.together.xyz',
  chatModel: 'meta-llama/Llama-3-8b-chat-hf',
  embeddingModel: 'togethercomputer/m2-bert-80M-8k-retrieval',
  embeddingDimension: 768,
  */

  /* OpenAI config:
  ollama: false,
  url: 'https://api.openai.com',
  chatModel: 'gpt-3.5-turbo-16k',
  embeddingModel: 'text-embedding-ada-002',
  embeddingDimension: 1536,
  */
};
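// To switch providers, toggle the flags and url/model fields above. At runtime,
// the LLM_API_URL, LLM_API_KEY, and LLM_MODEL environment variables (and their
// legacy OLLAMA_*/OPENAI_* equivalents) take precedence; see apiUrl(), apiKey(),
// and chatCompletion() below.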
function apiUrl(path: string) {
  // OPENAI_API_BASE and OLLAMA_HOST are legacy
  const host =
    process.env.LLM_API_URL ??
    process.env.OLLAMA_HOST ??
    process.env.OPENAI_API_BASE ??
    LLM_CONFIG.url;
  if (host.endsWith("/") && path.startsWith("/")) {
    return host + path.slice(1);
  } else if (!host.endsWith("/") && !path.startsWith("/")) {
    return host + "/" + path;
  } else {
    return host + path;
  }
}
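// For example, with the default Hugging Face config above,
// apiUrl("/v1/chat/completions") resolves to
// "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-8B-Instruct/v1/chat/completions".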
function apiKey() {
  return process.env.LLM_API_KEY ?? process.env.OPENAI_API_KEY;
}

const AuthHeaders = (): Record<string, string> =>
  apiKey()
    ? {
        Authorization: "Bearer " + apiKey(),
      }
    : {};

// Overload for non-streaming
export async function chatCompletion(
  body: Omit<CreateChatCompletionRequest, "model"> & {
    model?: CreateChatCompletionRequest["model"];
  } & {
    stream?: false | null | undefined;
  }
): Promise<{ content: string; retries: number; ms: number }>;
// Overload for streaming
export async function chatCompletion(
  body: Omit<CreateChatCompletionRequest, "model"> & {
    model?: CreateChatCompletionRequest["model"];
  } & {
    stream?: true;
  }
): Promise<{ content: ChatCompletionContent; retries: number; ms: number }>;
export async function chatCompletion(
  body: Omit<CreateChatCompletionRequest, "model"> & {
    model?: CreateChatCompletionRequest["model"];
  }
) {
  assertApiKey();
  // OLLAMA_MODEL is legacy
  body.model =
    body.model ??
    process.env.LLM_MODEL ??
    process.env.OLLAMA_MODEL ??
    LLM_CONFIG.chatModel;
  const stopWords = body.stop
    ? typeof body.stop === "string"
      ? [body.stop]
      : body.stop
    : [];
  if (LLM_CONFIG.ollama || LLM_CONFIG.huggingface) stopWords.push("<|eot_id|>");
  const {
    result: content,
    retries,
    ms,
  } = await retryWithBackoff(async () => {
    const hf = new HfInference(apiKey());
    const model = hf.endpoint(apiUrl("/v1/chat/completions"));
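    // The endpoint targets the OpenAI-compatible /v1/chat/completions route, so
    // `body` follows the OpenAI request shape (CreateChatCompletionRequest below).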
    if (body.stream) {
      const completion = model.chatCompletionStream({
        ...body,
      });
      return new ChatCompletionContent(completion, stopWords);
    } else {
      const completion = await model.chatCompletion({
        ...body,
      });
      const content = completion.choices[0].message?.content;
      if (content === undefined) {
        throw new Error(
          "Unexpected chat completion result: " + JSON.stringify(completion)
        );
      }
      return content;
    }
  });
  return {
    content,
    retries,
    ms,
  };
}
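// Example usage (sketch; the message and token limit are placeholders):
// const { content } = await chatCompletion({
//   messages: [{ role: "user", content: "Say hello." }],
//   max_tokens: 64,
// });
// With `stream: true`, `content` is a ChatCompletionContent; consume it with
// `for await (const chunk of content.read()) { ... }` or `await content.readAll()`.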
export async function tryPullOllama(model: string, error: string) {
  if (error.includes("try pulling")) {
| console.error("Embedding model not found, pulling from Ollama"); | |
    const pullResp = await fetch(apiUrl("/api/pull"), {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
      },
      body: JSON.stringify({ name: model }),
    });
    console.log("Pull response", await pullResp.text());
    throw {
      retry: true,
      error: `Dynamically pulled model. Original error: ${error}`,
    };
  }
}

export async function fetchEmbeddingBatch(texts: string[]) {
  if (LLM_CONFIG.ollama) {
    return {
      ollama: true as const,
      embeddings: await Promise.all(
        texts.map(async (t) => (await ollamaFetchEmbedding(t)).embedding)
      ),
    };
  }
  assertApiKey();
  if (LLM_CONFIG.huggingface) {
    const result = await fetch(LLM_CONFIG.embeddingModel, {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
        "X-Wait-For-Model": "true",
        ...AuthHeaders(),
      },
      body: JSON.stringify({
        inputs: texts.map((text) => text.replace(/\n/g, " ")),
      }),
    });
    const embeddings = await result.json();
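    // Note: this reuses the `ollama: true` return shape (raw embeddings, no
    // usage/retry stats).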
    return {
      ollama: true as const,
      embeddings: embeddings,
    };
  }
  const {
    result: json,
    retries,
    ms,
  } = await retryWithBackoff(async () => {
    const result = await fetch(apiUrl("/v1/embeddings"), {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
        ...AuthHeaders(),
      },
      body: JSON.stringify({
        model: LLM_CONFIG.embeddingModel,
        input: texts.map((text) => text.replace(/\n/g, " ")),
      }),
    });
    if (!result.ok) {
      throw {
        retry: result.status === 429 || result.status >= 500,
        error: new Error(
          `Embedding failed with code ${result.status}: ${await result.text()}`
        ),
      };
    }
    return (await result.json()) as CreateEmbeddingResponse;
  });
  if (json.data.length !== texts.length) {
    console.error(json);
    throw new Error("Unexpected number of embeddings");
  }
  const allembeddings = json.data;
  allembeddings.sort((a, b) => a.index - b.index);
  return {
    ollama: false as const,
    embeddings: allembeddings.map(({ embedding }) => embedding),
    usage: json.usage?.total_tokens,
    retries,
    ms,
  };
}
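// Example usage (sketch; the input strings are placeholders):
// const { embeddings } = await fetchEmbeddingBatch(["first text", "second text"]);
// Each entry is expected to be a number[] of length LLM_CONFIG.embeddingDimension.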
export async function fetchEmbedding(text: string) {
  const { embeddings, ...stats } = await fetchEmbeddingBatch([text]);
  return { embedding: embeddings[0], ...stats };
}

export async function fetchModeration(content: string) {
  assertApiKey();
  const { result: flagged } = await retryWithBackoff(async () => {
    const result = await fetch(apiUrl("/v1/moderations"), {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
        ...AuthHeaders(),
      },
      body: JSON.stringify({
        input: content,
      }),
    });
    if (!result.ok) {
      throw {
        retry: result.status === 429 || result.status >= 500,
        error: new Error(
          `Moderation failed with code ${result.status}: ${await result.text()}`
        ),
      };
    }
    return (await result.json()) as { results: { flagged: boolean }[] };
  });
  return flagged;
}
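// Example usage (sketch; assumes the configured host serves an OpenAI-compatible
// /v1/moderations endpoint):
// const moderation = await fetchModeration("some user input");
// const isFlagged = moderation.results[0].flagged;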
export function assertApiKey() {
  if (!LLM_CONFIG.ollama && !apiKey()) {
    throw new Error(
      "\n Missing LLM_API_KEY in environment variables.\n\n" +
        (LLM_CONFIG.ollama ? "just" : "npx") +
        " convex env set LLM_API_KEY 'your-key'"
    );
  }
}

// Retry after this much time, based on the retry number.
const RETRY_BACKOFF = [1000, 10_000, 20_000]; // In ms
const RETRY_JITTER = 100; // In ms
type RetryError = { retry: boolean; error: any };

export async function retryWithBackoff<T>(
  fn: () => Promise<T>
): Promise<{ retries: number; result: T; ms: number }> {
  let i = 0;
  for (; i <= RETRY_BACKOFF.length; i++) {
    try {
      const start = Date.now();
      const result = await fn();
      const ms = Date.now() - start;
      return { result, retries: i, ms };
    } catch (e) {
      const retryError = e as RetryError;
      if (i < RETRY_BACKOFF.length) {
        if (retryError.retry) {
          console.log(
            `Attempt ${i + 1} failed, waiting ${RETRY_BACKOFF[i]}ms to retry...`,
            Date.now()
          );
          await new Promise((resolve) =>
            setTimeout(resolve, RETRY_BACKOFF[i] + RETRY_JITTER * Math.random())
          );
          continue;
        }
      }
      if (retryError.error) throw retryError.error;
      else throw e;
    }
  }
  throw new Error("Unreachable");
}
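// Callbacks passed to retryWithBackoff signal a retryable failure by throwing an
// object of the RetryError shape, e.g. (sketch):
// throw { retry: true, error: new Error("rate limited") };
// Anything thrown without `retry: true` is rethrown immediately.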
// Lifted from openai's package
export interface LLMMessage {
  /**
   * The contents of the message. `content` is required for all messages, and may be
   * null for assistant messages with function calls.
   */
  content: string | null;

  /**
   * The role of the messages author. One of `system`, `user`, `assistant`, or
   * `function`.
   */
  role: "system" | "user" | "assistant" | "function";

  /**
   * The name of the author of this message. `name` is required if role is
   * `function`, and it should be the name of the function whose response is in the
   * `content`. May contain a-z, A-Z, 0-9, and underscores, with a maximum length of
   * 64 characters.
   */
  name?: string;

  /**
   * The name and arguments of a function that should be called, as generated by the model.
   */
  function_call?: {
    // The name of the function to call.
    name: string;
    /**
     * The arguments to call the function with, as generated by the model in
     * JSON format. Note that the model does not always generate valid JSON,
     * and may hallucinate parameters not defined by your function schema.
     * Validate the arguments in your code before calling your function.
     */
    arguments: string;
  };
}

// Non-streaming chat completion response
interface CreateChatCompletionResponse {
  id: string;
  object: string;
  created: number;
  model: string;
  choices: {
    index?: number;
    message?: {
      role: "system" | "user" | "assistant";
      content: string;
    };
    finish_reason?: string;
  }[];
  usage?: {
    completion_tokens: number;
    prompt_tokens: number;
    total_tokens: number;
  };
}

interface CreateEmbeddingResponse {
  data: {
    index: number;
    object: string;
    embedding: number[];
  }[];
  model: string;
  object: string;
  usage: {
    prompt_tokens: number;
    total_tokens: number;
  };
}
export interface CreateChatCompletionRequest {
  /**
   * ID of the model to use.
   * @type {string}
   * @memberof CreateChatCompletionRequest
   */
  model: string;
  // | 'gpt-4'
  // | 'gpt-4-0613'
  // | 'gpt-4-32k'
  // | 'gpt-4-32k-0613'
  // | 'gpt-3.5-turbo'
  // | 'gpt-3.5-turbo-0613'
  // | 'gpt-3.5-turbo-16k' // <- our default
  // | 'gpt-3.5-turbo-16k-0613';
  /**
   * The messages to generate chat completions for, in the chat format:
   * https://platform.openai.com/docs/guides/chat/introduction
   * @type {Array<ChatCompletionRequestMessage>}
   * @memberof CreateChatCompletionRequest
   */
  messages: LLMMessage[];
  /**
   * What sampling temperature to use, between 0 and 2. Higher values like 0.8 will
   * make the output more random, while lower values like 0.2 will make it more
   * focused and deterministic. We generally recommend altering this or `top_p` but
   * not both.
   * @type {number}
   * @memberof CreateChatCompletionRequest
   */
  temperature?: number | null;
  /**
   * An alternative to sampling with temperature, called nucleus sampling, where
   * the model considers the results of the tokens with top_p probability mass.
   * So 0.1 means only the tokens comprising the top 10% probability mass are
   * considered. We generally recommend altering this or `temperature` but not both.
   * @type {number}
   * @memberof CreateChatCompletionRequest
   */
  top_p?: number | null;
  /**
   * How many chat completion choices to generate for each input message.
   * @type {number}
   * @memberof CreateChatCompletionRequest
   */
  n?: number | null;
  /**
   * If set, partial message deltas will be sent, like in ChatGPT. Tokens will be
   * sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format)
   * as they become available, with the stream terminated by a `data: [DONE]` message.
   * @type {boolean}
   * @memberof CreateChatCompletionRequest
   */
  stream?: boolean | null;
  /**
   *
   * @type {CreateChatCompletionRequestStop}
   * @memberof CreateChatCompletionRequest
   */
  stop?: Array<string> | string;
  /**
   * The maximum number of tokens allowed for the generated answer. By default,
   * the number of tokens the model can return will be (4096 - prompt tokens).
   * @type {number}
   * @memberof CreateChatCompletionRequest
   */
  max_tokens?: number;
  /**
   * Number between -2.0 and 2.0. Positive values penalize new tokens based on
   * whether they appear in the text so far, increasing the model's likelihood
   * to talk about new topics. See more information about frequency and
   * presence penalties:
   * https://platform.openai.com/docs/api-reference/parameter-details
   * @type {number}
   * @memberof CreateChatCompletionRequest
   */
  presence_penalty?: number | null;
  /**
   * Number between -2.0 and 2.0. Positive values penalize new tokens based on
   * their existing frequency in the text so far, decreasing the model's
   * likelihood to repeat the same line verbatim. See more information about
   * presence penalties:
   * https://platform.openai.com/docs/api-reference/parameter-details
   * @type {number}
   * @memberof CreateChatCompletionRequest
   */
  frequency_penalty?: number | null;
  /**
   * Modify the likelihood of specified tokens appearing in the completion.
   * Accepts a json object that maps tokens (specified by their token ID in the
   * tokenizer) to an associated bias value from -100 to 100. Mathematically,
   * the bias is added to the logits generated by the model prior to sampling.
   * The exact effect will vary per model, but values between -1 and 1 should
   * decrease or increase likelihood of selection; values like -100 or 100
   * should result in a ban or exclusive selection of the relevant token.
   * @type {object}
   * @memberof CreateChatCompletionRequest
   */
  logit_bias?: object | null;
  /**
   * A unique identifier representing your end-user, which can help OpenAI to
   * monitor and detect abuse. Learn more:
   * https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids
   * @type {string}
   * @memberof CreateChatCompletionRequest
   */
  user?: string;
  tools?: {
    // The type of the tool. Currently, only function is supported.
    type: "function";
    function: {
      /**
       * The name of the function to be called. Must be a-z, A-Z, 0-9, or
       * contain underscores and dashes, with a maximum length of 64.
       */
      name: string;
      /**
       * A description of what the function does, used by the model to choose
       * when and how to call the function.
       */
      description?: string;
      /**
       * The parameters the functions accepts, described as a JSON Schema
       * object. See the guide[1] for examples, and the JSON Schema reference[2]
       * for documentation about the format.
       * [1]: https://platform.openai.com/docs/guides/gpt/function-calling
       * [2]: https://json-schema.org/understanding-json-schema/
       * To describe a function that accepts no parameters, provide the value
       * {"type": "object", "properties": {}}.
       */
      parameters: object;
    };
  }[];
  /**
   * Controls which (if any) function is called by the model. `none` means the
   * model will not call a function and instead generates a message.
   * `auto` means the model can pick between generating a message or calling a
   * function. Specifying a particular function via
   * {"type": "function", "function": {"name": "my_function"}} forces the model
   * to call that function.
   *
   * `none` is the default when no functions are present.
   * `auto` is the default if functions are present.
   */
  tool_choice?:
    | "none" // none means the model will not call a function and instead generates a message.
    | "auto" // auto means the model can pick between generating a message or calling a function.
    // Specifies a tool the model should use. Use to force the model to call
    // a specific function.
    | {
        // The type of the tool. Currently, only function is supported.
        type: "function";
        function: { name: string };
      };
  // Replaced by "tools"
  // functions?: {
  //   /**
  //    * The name of the function to be called. Must be a-z, A-Z, 0-9, or
  //    * contain underscores and dashes, with a maximum length of 64.
  //    */
  //   name: string;
  //   /**
  //    * A description of what the function does, used by the model to choose
  //    * when and how to call the function.
  //    */
  //   description?: string;
  //   /**
  //    * The parameters the functions accepts, described as a JSON Schema
  //    * object. See the guide[1] for examples, and the JSON Schema reference[2]
  //    * for documentation about the format.
  //    * [1]: https://platform.openai.com/docs/guides/gpt/function-calling
  //    * [2]: https://json-schema.org/understanding-json-schema/
  //    * To describe a function that accepts no parameters, provide the value
  //    * {"type": "object", "properties": {}}.
  //    */
  //   parameters: object;
  // }[];
  // /**
  //  * Controls how the model responds to function calls. "none" means the model
  //  * does not call a function, and responds to the end-user. "auto" means the
  //  * model can pick between an end-user or calling a function. Specifying a
  //  * particular function via {"name": "my_function"} forces the model to call
  //  * that function.
  //  * - "none" is the default when no functions are present.
  //  * - "auto" is the default if functions are present.
  //  */
  // function_call?: 'none' | 'auto' | { name: string };
  /**
   * An object specifying the format that the model must output.
   *
   * Setting to { "type": "json_object" } enables JSON mode, which guarantees
   * the message the model generates is valid JSON.
   * *Important*: when using JSON mode, you must also instruct the model to
   * produce JSON yourself via a system or user message. Without this, the model
   * may generate an unending stream of whitespace until the generation reaches
   * the token limit, resulting in a long-running and seemingly "stuck" request.
   * Also note that the message content may be partially cut off if
   * finish_reason="length", which indicates the generation exceeded max_tokens
   * or the conversation exceeded the max context length.
   */
  response_format?: { type: "text" | "json_object" };
}
// Checks whether a suffix of s1 is a prefix of s2. For example,
// ('Hello', 'Kira:') -> false
// ('Hello Kira', 'Kira:') -> true
const suffixOverlapsPrefix = (s1: string, s2: string) => {
  for (let i = 1; i <= Math.min(s1.length, s2.length); i++) {
    const suffix = s1.substring(s1.length - i);
    const prefix = s2.substring(0, i);
    if (suffix === prefix) {
      return true;
    }
  }
  return false;
};
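// Minimal shape of the streamed chat-completion chunks consumed below. This is
// an assumed local type (the chunk type shipped by the provider SDK may be
// richer); only `choices[0].delta.content` is read.
interface ChatCompletionChunk {
  choices: {
    index?: number;
    delta: {
      role?: string;
      content?: string | null;
    };
    finish_reason?: string | null;
  }[];
}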
export class ChatCompletionContent {
  private readonly completion: AsyncIterable<ChatCompletionChunk>;
  private readonly stopWords: string[];

  constructor(
    completion: AsyncIterable<ChatCompletionChunk>,
    stopWords: string[]
  ) {
    this.completion = completion;
    this.stopWords = stopWords;
  }

  async *readInner() {
    for await (const chunk of this.completion) {
      yield chunk.choices[0].delta.content;
    }
  }

  // stop words in OpenAI api don't always work.
  // So we have to truncate on our side.
  async *read() {
    let lastFragment = "";
    for await (const data of this.readInner()) {
      lastFragment += data;
      let hasOverlap = false;
      for (const stopWord of this.stopWords) {
        const idx = lastFragment.indexOf(stopWord);
        if (idx >= 0) {
          yield lastFragment.substring(0, idx);
          return;
        }
        if (suffixOverlapsPrefix(lastFragment, stopWord)) {
          hasOverlap = true;
        }
      }
      if (hasOverlap) continue;
      yield lastFragment;
      lastFragment = "";
    }
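    // Flush any trailing text that was held back while checking for a stop word.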
    yield lastFragment;
  }

  async readAll() {
    let allContent = "";
    for await (const chunk of this.read()) {
      allContent += chunk;
    }
    return allContent;
  }
}
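// Example usage (sketch; `messages` is a placeholder prepared by the caller):
// const { content } = await chatCompletion({ messages, stream: true });
// for await (const part of content.read()) {
//   /* append `part` to a buffer or the UI */
// }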
export async function ollamaFetchEmbedding(text: string) {
  const { result } = await retryWithBackoff(async () => {
    const resp = await fetch(apiUrl("/api/embeddings"), {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
      },
      body: JSON.stringify({ model: LLM_CONFIG.embeddingModel, prompt: text }),
    });
    if (resp.status === 404) {
      const error = await resp.text();
      await tryPullOllama(LLM_CONFIG.embeddingModel, error);
      throw new Error(`Failed to fetch embeddings: ${resp.status}`);
    }
    return (await resp.json()).embedding as number[];
  });
  return { embedding: result };
}