Spaces:
Build error
Build error
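// 'ggmf' in hex.
// Default file version is the latest version.
// NOTE(review): the magic-number and file-version definitions these two comments describe are missing from this copy — confirm against the original header.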
// Fallback so these declarations are usable even when the export macro has not
// been defined yet. NOTE(review): the complete header presumably defines RWKV_API
// (e.g. for shared-library export) before this point — confirm.
#ifndef RWKV_API
#define RWKV_API
#endif

// `extern "C"` is C++-only syntax; guard it so this header compiles as both C and C++.
#ifdef __cplusplus
extern "C" {
#endif

// Represents an error encountered during a function call.
// Error values are built from two groups of constants:
// - categories, which are multiples of (1 << 8): RWKV_ERROR_ARGS .. RWKV_ERROR_CTX;
// - specific errors, which fit in the low 8 bits: RWKV_ERROR_ALLOC .. RWKV_ERROR_PARAM_MISSING.
// A reported value may combine a category with a specific error,
// e.g. (RWKV_ERROR_FILE | RWKV_ERROR_FILE_OPEN).
// The categories are consecutive integers shifted left by 8, NOT distinct bits:
// (err & RWKV_ERROR_ARGS) is also non-zero for RWKV_ERROR_MODEL (3 << 8).
// Extract the category with (err & ~0xFF) and compare it by value instead.
enum rwkv_error_flags {
    RWKV_ERROR_NONE = 0,

    // Categories (bits 8 and up).
    RWKV_ERROR_ARGS = 1 << 8,
    RWKV_ERROR_FILE = 2 << 8,
    RWKV_ERROR_MODEL = 3 << 8,
    RWKV_ERROR_MODEL_PARAMS = 4 << 8,
    RWKV_ERROR_GRAPH = 5 << 8,
    RWKV_ERROR_CTX = 6 << 8,

    // Specific errors (low 8 bits).
    RWKV_ERROR_ALLOC = 1,
    RWKV_ERROR_FILE_OPEN = 2,
    RWKV_ERROR_FILE_STAT = 3,
    RWKV_ERROR_FILE_READ = 4,
    RWKV_ERROR_FILE_WRITE = 5,
    RWKV_ERROR_FILE_MAGIC = 6,
    RWKV_ERROR_FILE_VERSION = 7,
    RWKV_ERROR_DATA_TYPE = 8,
    RWKV_ERROR_UNSUPPORTED = 9,
    RWKV_ERROR_SHAPE = 10,
    RWKV_ERROR_DIMENSION = 11,
    RWKV_ERROR_KEY = 12,
    RWKV_ERROR_DATA = 13,
    RWKV_ERROR_PARAM_MISSING = 14
};

// RWKV context that can be used for inference.
// A single rwkv_context must not be used by more than one thread at a time
// (only one eval may run on it at once); use rwkv_clone_context for parallel inference.
// An rwkv_context can be moved to a different thread between calls to rwkv_eval,
// and it does not need to be freed on the thread that created it.
struct rwkv_context;

// Sets whether errors are automatically printed to stderr.
// If this is set to false, you are responsible for calling rwkv_get_last_error manually if an operation fails.
// - ctx: the context to suppress error messages for.
//   If NULL, affects model load (rwkv_init_from_file) and quantization (rwkv_quantize_model_file) errors,
//   as well as the default for new contexts.
// - print_errors: whether error messages should be automatically printed.
RWKV_API void rwkv_set_print_errors(struct rwkv_context * ctx, bool print_errors);

// Gets whether errors are automatically printed to stderr.
// - ctx: the context to retrieve the setting for, or NULL for the global setting.
RWKV_API bool rwkv_get_print_errors(struct rwkv_context * ctx);

// Retrieves and clears the error flags.
// - ctx: the context to retrieve the error for, or NULL for the global error.
RWKV_API enum rwkv_error_flags rwkv_get_last_error(struct rwkv_context * ctx);

// Loads the model from a file and prepares it for inference.
// Returns NULL on any error. The returned context must be freed using rwkv_free.
// - model_file_path: path to model file in ggml format.
// - n_threads: count of threads to use, must be positive.
RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, const uint32_t n_threads);

// Creates a new context from an existing one.
// This allows running multiple rwkv_eval calls in parallel without loading the same model multiple times.
// Each rwkv_context can have only one eval running at a time.
// Every rwkv_context must be freed using rwkv_free.
// - ctx: context to be cloned.
// - n_threads: count of threads to use, must be positive.
RWKV_API struct rwkv_context * rwkv_clone_context(struct rwkv_context * ctx, const uint32_t n_threads);

// Offloads the specified count of model layers onto the GPU. Offloaded layers are evaluated using cuBLAS.
// Returns true if at least one layer was offloaded.
// If rwkv.cpp was compiled without cuBLAS support, this function is a no-op and always returns false.
// - ctx: context whose model layers should be offloaded.
// - n_layers: number of layers to offload; see rwkv_get_n_layer to offload the entire model.
RWKV_API bool rwkv_gpu_offload_layers(struct rwkv_context * ctx, const uint32_t n_layers);

// Evaluates the model for a single token.
// Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread.
// Returns false on any error.
// You can pass NULL to logits_out whenever logits are not needed. This can improve speed by ~10ms per iteration
// that you do not calculate logits.
// - ctx: context to evaluate with.
// - n_threads: thread count for this call. NOTE(review): presumably overrides the count given when the
//   context was created — confirm against the implementation.
// - token: next token index, in range 0 <= token < n_vocab.
// - state_in: FP32 buffer of size rwkv_get_state_len(), or NULL if this is a first pass.
// - state_out: FP32 buffer of size rwkv_get_state_len(). This buffer will be written to if non-NULL.
// - logits_out: FP32 buffer of size rwkv_get_logits_len(). This buffer will be written to if non-NULL.
RWKV_API bool rwkv_eval(struct rwkv_context * ctx, const int n_threads, const uint32_t token, const float * state_in, float * state_out, float * logits_out);

// Evaluates the model for a sequence of tokens.
// Uses a faster algorithm than rwkv_eval if you do not need the state and logits for every token. Best used with batch sizes of 64 or so.
// Has to build a computation graph on the first call for a given sequence length, but will reuse this cached graph for subsequent calls with the same sequence length.
// Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread.
// Returns false on any error.
// You can pass NULL to logits_out whenever logits are not needed. This can improve speed by ~10ms per iteration
// that you do not calculate logits.
// - ctx: context to evaluate with.
// - n_threads: thread count for this call. NOTE(review): presumably overrides the count given when the
//   context was created — confirm against the implementation.
// - tokens: pointer to an array of tokens. If NULL, the graph will be built and cached, but not executed: this can be useful for initialization.
// - sequence_len: number of tokens to read from the array.
// - state_in: FP32 buffer of size rwkv_get_state_len(), or NULL if this is a first pass.
// - state_out: FP32 buffer of size rwkv_get_state_len(). This buffer will be written to if non-NULL.
// - logits_out: FP32 buffer of size rwkv_get_logits_len(). This buffer will be written to if non-NULL.
RWKV_API bool rwkv_eval_sequence(struct rwkv_context * ctx, const int n_threads, const uint32_t * tokens, size_t sequence_len, const float * state_in, float * state_out, float * logits_out);

// Returns the number of tokens in the given model's vocabulary.
// Useful for telling 20B_tokenizer models (n_vocab = 50277) apart from World models (n_vocab = 65536).
RWKV_API size_t rwkv_get_n_vocab(const struct rwkv_context * ctx);

// Returns the number of elements in the given model's embedding.
// Useful for reading individual fields of a model's hidden state.
RWKV_API size_t rwkv_get_n_embed(const struct rwkv_context * ctx);

// Returns the number of layers in the given model.
// Useful for always offloading the entire model to GPU.
RWKV_API size_t rwkv_get_n_layer(const struct rwkv_context * ctx);

// Returns the number of float elements in a complete state for the given model.
// This is the number of elements you'll need to allocate for a call to rwkv_eval, rwkv_eval_sequence, or rwkv_init_state.
RWKV_API size_t rwkv_get_state_len(const struct rwkv_context * ctx);

// Returns the number of float elements in the logits output of a given model.
// This is currently always identical to n_vocab.
RWKV_API size_t rwkv_get_logits_len(const struct rwkv_context * ctx);

// Initializes the given state so that passing it to rwkv_eval or rwkv_eval_sequence would be identical to passing NULL.
// Useful in cases where tracking the first call to these functions may be annoying or expensive.
// State must be initialized for behavior to be defined; passing a zeroed state to rwkv.cpp functions will result in NaNs.
// - ctx: context whose model determines the state layout.
// - state: FP32 buffer of size rwkv_get_state_len() to initialize.
RWKV_API void rwkv_init_state(const struct rwkv_context * ctx, float * state);

// Frees all allocated memory and the context.
// Does not need to be called on the same thread that created the rwkv_context.
RWKV_API void rwkv_free(struct rwkv_context * ctx);

// Quantizes an FP32 or FP16 model to one of the quantized formats.
// Returns false on any error. Error messages are printed to stderr unless disabled
// via rwkv_set_print_errors(NULL, false); the error can be retrieved with rwkv_get_last_error(NULL).
// - model_file_path_in: path to model file in ggml format, must be either FP32 or FP16.
// - model_file_path_out: quantized model will be written here.
// - format_name: must be one of the available format names below.
// Available format names:
// - Q4_0
// - Q4_1
// - Q5_0
// - Q5_1
// - Q8_0
RWKV_API bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, const char * format_name);

// Returns the system information string.
// NOTE(review): ownership of the returned string is not documented here; presumably it is static
// and must not be freed — confirm.
RWKV_API const char * rwkv_get_system_info_string(void);

#ifdef __cplusplus
}
#endif