Spaces:
Build error
Build error
File size: 9,062 Bytes
3e5595b 1e081f1 3e5595b 1e081f1 3e5595b 1e081f1 3e5595b 1e081f1 dc53b3a 3e5595b 1e081f1 3e5595b 1e081f1 dc53b3a 1e081f1 3e5595b 1e081f1 3e5595b 1e081f1 3e5595b 1e081f1 3e5595b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 |
#ifndef RWKV_H
#define RWKV_H
#include <stddef.h>
#include <stdint.h>
#include <stdbool.h>
// RWKV_API is the linkage/visibility decoration placed on every public function declaration below.
// It expands to a non-empty attribute only when RWKV_SHARED is defined.
#ifdef RWKV_SHARED
// Windows, excluding MinGW: use __declspec.
# if defined(_WIN32) && !defined(__MINGW32__)
// RWKV_BUILD selects dllexport; without it, the symbols are declared dllimport.
# ifdef RWKV_BUILD
# define RWKV_API __declspec(dllexport)
# else
# define RWKV_API __declspec(dllimport)
# endif
# else
// All other targets (including MinGW): mark the symbols with default visibility.
# define RWKV_API __attribute__ ((visibility ("default")))
# endif
#else
// RWKV_SHARED is not defined: no decoration.
# define RWKV_API
#endif
// Model file magic number: the ASCII characters 'ggmf' in hex.
#define RWKV_FILE_MAGIC 0x67676d66
// Known model file versions.
#define RWKV_FILE_VERSION_0 100
#define RWKV_FILE_VERSION_1 101
// Oldest and newest of the file versions defined above.
#define RWKV_FILE_VERSION_MIN RWKV_FILE_VERSION_0
#define RWKV_FILE_VERSION_MAX RWKV_FILE_VERSION_1
// Default file version is the latest version.
#define RWKV_FILE_VERSION RWKV_FILE_VERSION_MAX
#ifdef __cplusplus
extern "C" {
#endif
// Describes an error encountered during a function call.
// Values are treated as flags, so a reported value may combine more than one of these constants.
// The constants fall into two groups:
//   - categories, which occupy the bits above the low byte (0x0100, 0x0200, ...);
//   - detail codes, which occupy the low byte (1..14).
// Note that categories are consecutive numbers shifted left by 8 bits rather than independent bits
// (for example RWKV_ERROR_MODEL == (RWKV_ERROR_ARGS | RWKV_ERROR_FILE)), and the same holds for
// detail codes, so testing a single constant with a bitwise AND is ambiguous.
enum rwkv_error_flags {
    RWKV_ERROR_NONE          = 0,

    // Categories (bits 8 and up).
    RWKV_ERROR_ARGS          = 0x0100,
    RWKV_ERROR_FILE          = 0x0200,
    RWKV_ERROR_MODEL         = 0x0300,
    RWKV_ERROR_MODEL_PARAMS  = 0x0400,
    RWKV_ERROR_GRAPH         = 0x0500,
    RWKV_ERROR_CTX           = 0x0600,

    // Detail codes (low byte).
    RWKV_ERROR_ALLOC         = 0x01,
    RWKV_ERROR_FILE_OPEN     = 0x02,
    RWKV_ERROR_FILE_STAT     = 0x03,
    RWKV_ERROR_FILE_READ     = 0x04,
    RWKV_ERROR_FILE_WRITE    = 0x05,
    RWKV_ERROR_FILE_MAGIC    = 0x06,
    RWKV_ERROR_FILE_VERSION  = 0x07,
    RWKV_ERROR_DATA_TYPE     = 0x08,
    RWKV_ERROR_UNSUPPORTED   = 0x09,
    RWKV_ERROR_SHAPE         = 0x0A,
    RWKV_ERROR_DIMENSION     = 0x0B,
    RWKV_ERROR_KEY           = 0x0C,
    RWKV_ERROR_DATA          = 0x0D,
    RWKV_ERROR_PARAM_MISSING = 0x0E
};
// RWKV context that can be used for inference. The type is opaque: only this forward declaration is exposed.
// Functions that operate on rwkv_context are documented as thread-safe, with the exception of rwkv_eval and
// rwkv_eval_sequence, which are documented as not thread-safe: for parallel inference, give each thread its
// own context created with rwkv_clone_context.
// rwkv_context can be sent to different threads between calls to rwkv_eval.
// There is no requirement for rwkv_context to be freed on the creating thread.
struct rwkv_context;
// Sets whether errors are automatically printed to stderr.
// If this is set to false, you are responsible for calling rwkv_get_last_error manually if an operation fails.
// - ctx: the context to suppress error messages for.
//   If NULL, affects model load (rwkv_init_from_file) and quantization (rwkv_quantize_model_file) errors,
//   as well as the default for new contexts.
// - print_errors: whether error messages should be automatically printed.
RWKV_API void rwkv_set_print_errors(struct rwkv_context * ctx, bool print_errors);
// Gets whether errors are automatically printed to stderr.
// Returns true if automatic printing is enabled for the queried setting.
// - ctx: the context to retrieve the setting for, or NULL for the global setting.
RWKV_API bool rwkv_get_print_errors(struct rwkv_context * ctx);
// Retrieves and clears the error flags.
// Because the flags are cleared, store the returned value if you need to inspect it more than once.
// - ctx: the context to retrieve the error for, or NULL for the global error.
RWKV_API enum rwkv_error_flags rwkv_get_last_error(struct rwkv_context * ctx);
// Loads the model from a file and prepares it for inference.
// Returns NULL on any error.
// A context returned by this function must be freed using rwkv_free.
// - model_file_path: path to model file in ggml format.
// - n_threads: count of threads to use, must be positive.
RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, const uint32_t n_threads);
// Creates a new context from an existing one.
// This allows you to run multiple rwkv_eval calls in parallel without having to load the same model multiple times.
// Each rwkv_context can have one eval running at a time.
// Every rwkv_context must be freed using rwkv_free.
// NOTE(review): the return value on failure is not documented here; presumably NULL, as with rwkv_init_from_file — confirm.
// - ctx: context to be cloned.
// - n_threads: count of threads to use, must be positive.
RWKV_API struct rwkv_context * rwkv_clone_context(struct rwkv_context * ctx, const uint32_t n_threads);
// Offloads the specified count of model layers onto the GPU. Offloaded layers are evaluated using cuBLAS.
// Returns true if at least one layer was offloaded.
// If rwkv.cpp was compiled without cuBLAS support, this function is a no-op and always returns false.
// - ctx: the context whose model layers should be offloaded.
// - n_layers: count of layers to offload. To offload the entire model, pass rwkv_get_n_layer().
RWKV_API bool rwkv_gpu_offload_layers(struct rwkv_context * ctx, const uint32_t n_layers);
// Evaluates the model for a single token.
// Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread.
// Returns false on any error.
// You can pass NULL to logits_out whenever logits are not needed. This can improve speed by ~10ms for each
// iteration in which logits are not calculated.
// - ctx: the context to evaluate with.
// - n_threads: NOTE(review): not described in this header; presumably the count of threads to use for this
//   call — confirm against the implementation.
// - token: next token index, in range 0 <= token < n_vocab.
// - state_in: FP32 buffer of size rwkv_get_state_len(); or NULL, if this is a first pass.
// - state_out: FP32 buffer of size rwkv_get_state_len(). This buffer will be written to if non-NULL.
// - logits_out: FP32 buffer of size rwkv_get_logits_len(). This buffer will be written to if non-NULL.
RWKV_API bool rwkv_eval(struct rwkv_context * ctx, const int n_threads, const uint32_t token, const float * state_in, float * state_out, float * logits_out);
// Evaluates the model for a sequence of tokens.
// Uses a faster algorithm than rwkv_eval if you do not need the state and logits for every token. Best used with batch sizes of 64 or so.
// Has to build a computation graph on the first call for a given sequence length, but will use this cached graph for subsequent calls of the same sequence length.
// Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread.
// Returns false on any error.
// You can pass NULL to logits_out whenever logits are not needed. This can improve speed by ~10ms for each
// iteration in which logits are not calculated.
// - ctx: the context to evaluate with.
// - n_threads: NOTE(review): not described in this header; presumably the count of threads to use for this
//   call — confirm against the implementation.
// - tokens: pointer to an array of tokens. If NULL, the graph will be built and cached, but not executed: this can be useful for initialization.
// - sequence_len: number of tokens to read from the array.
// - state_in: FP32 buffer of size rwkv_get_state_len(), or NULL if this is a first pass.
// - state_out: FP32 buffer of size rwkv_get_state_len(). This buffer will be written to if non-NULL.
// - logits_out: FP32 buffer of size rwkv_get_logits_len(). This buffer will be written to if non-NULL.
RWKV_API bool rwkv_eval_sequence(struct rwkv_context * ctx, const int n_threads, const uint32_t * tokens, size_t sequence_len, const float * state_in, float * state_out, float * logits_out);
// Returns the number of tokens in the given model's vocabulary.
// Useful for telling 20B_tokenizer models (n_vocab = 50277) apart from World models (n_vocab = 65536).
// - ctx: the context whose model is queried.
RWKV_API size_t rwkv_get_n_vocab(const struct rwkv_context * ctx);
// Returns the number of elements in the given model's embedding.
// Useful for reading individual fields of a model's hidden state.
// - ctx: the context whose model is queried.
RWKV_API size_t rwkv_get_n_embed(const struct rwkv_context * ctx);
// Returns the number of layers in the given model.
// Useful for always offloading the entire model to GPU (see rwkv_gpu_offload_layers).
// - ctx: the context whose model is queried.
RWKV_API size_t rwkv_get_n_layer(const struct rwkv_context * ctx);
// Returns the number of float elements in a complete state for the given model.
// This is the number of elements (not bytes) you'll need to allocate for a call to rwkv_eval, rwkv_eval_sequence, or rwkv_init_state.
// - ctx: the context whose model is queried.
RWKV_API size_t rwkv_get_state_len(const struct rwkv_context * ctx);
// Returns the number of float elements (not bytes) in the logits output of a given model.
// This is currently always identical to n_vocab.
// - ctx: the context whose model is queried.
RWKV_API size_t rwkv_get_logits_len(const struct rwkv_context * ctx);
// Initializes the given state so that passing it to rwkv_eval or rwkv_eval_sequence would be identical to passing NULL.
// Useful in cases where tracking the first call to these functions may be annoying or expensive.
// State must be initialized for behavior to be defined; passing a zeroed state to rwkv.cpp functions will result in NaNs.
// - ctx: the context whose model the state belongs to.
// - state: FP32 buffer of size rwkv_get_state_len() to initialize.
RWKV_API void rwkv_init_state(const struct rwkv_context * ctx, float * state);
// Frees all allocated memory and the context.
// Does not need to be called on the same thread that created the rwkv_context.
// - ctx: the context to free. It must not be used after this call.
RWKV_API void rwkv_free(struct rwkv_context * ctx);
// Quantizes an FP32 or FP16 model to one of the quantized formats.
// Returns false on any error. Error messages are printed to stderr.
// - model_file_path_in: path to model file in ggml format, must be either FP32 or FP16.
// - model_file_path_out: quantized model will be written here.
// - format_name: must be one of the available format names below.
// Available format names:
// - Q4_0
// - Q4_1
// - Q5_0
// - Q5_1
// - Q8_0
RWKV_API bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, const char * format_name);
// Returns system information string.
// NOTE(review): ownership and lifetime of the returned string are not stated in this header; the const
// return type suggests the caller must not modify or free it — confirm against the implementation.
RWKV_API const char * rwkv_get_system_info_string(void);
#ifdef __cplusplus
}
#endif
#endif |