Spaces:
Sleeping
Sleeping
| // meta information about KV cells that can be part of multiple sequences at the same time | |
| // TODO: add unit tests | |
| class llama_kv_cells_unified { | |
| public: | |
| void reset() { | |
| for (uint32_t i = 0; i < pos.size(); ++i) { | |
| pos[i] = -1; | |
| shift[i] = 0; | |
| seq[i].reset(); | |
| } | |
| has_shift = false; | |
| used.clear(); | |
| for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { | |
| seq_pos[s].clear(); | |
| } | |
| } | |
| void reset_shift() { | |
| has_shift = false; | |
| for (uint32_t i = 0; i < shift.size(); ++i) { | |
| shift[i] = 0; | |
| } | |
| } | |
| uint32_t size() const { | |
| return pos.size(); | |
| } | |
| void resize(uint32_t n) { | |
| pos.resize(n); | |
| shift.resize(n); | |
| seq.resize(n); | |
| reset(); | |
| } | |
| bool is_empty(uint32_t i) const { | |
| assert(i < pos.size()); | |
| assert((pos[i] < 0 && pos[i] == -1) || pos[i] >= 0); | |
| return pos[i] == -1; | |
| } | |
| uint32_t get_used() const { | |
| return used.size(); | |
| } | |
| // the index of the first cell that is used | |
| // return 0 if no cells are used | |
| uint32_t used_min() const { | |
| return used.empty() ? 0 : *used.begin(); | |
| } | |
| // the index of the last cell that is used + 1 | |
| // return 0 if no cells are used | |
| uint32_t used_max_p1() const { | |
| return used.empty() ? 0 : *used.rbegin() + 1; | |
| } | |
| bool get_has_shift() const { | |
| return has_shift; | |
| } | |
| // move cell isrc to idst (used during defrag) | |
| void mv(uint32_t isrc, uint32_t idst) { | |
| assert(isrc < pos.size()); | |
| assert(idst < pos.size()); | |
| assert(pos[idst] == -1); | |
| assert(pos[isrc] != -1); | |
| pos [idst] = pos [isrc]; | |
| shift[idst] = shift[isrc]; | |
| seq [idst] = seq [isrc]; | |
| pos [isrc] = -1; | |
| shift[isrc] = 0; | |
| seq [isrc].reset(); | |
| used.erase (isrc); | |
| used.insert(idst); | |
| } | |
| // copy the state of cells [i, i + n) (used for save/restore the state of the cells) | |
| llama_kv_cells_unified cp(uint32_t i, uint32_t n) const { | |
| assert(i + n <= pos.size()); | |
| llama_kv_cells_unified res; | |
| res.resize(n); | |
| for (uint32_t j = 0; j < n; ++j) { | |
| const auto idx = i + j; | |
| res.pos[j] = pos[idx]; | |
| res.seq[j] = seq[idx]; | |
| assert(shift[idx] == 0); | |
| } | |
| return res; | |
| } | |
| // copy the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1]) | |
| llama_kv_cells_unified cp(const std::vector<uint32_t> & idxs) const { | |
| llama_kv_cells_unified res; | |
| res.resize(idxs.size()); | |
| for (uint32_t j = 0; j < idxs.size(); ++j) { | |
| const auto idx = idxs[j]; | |
| res.pos[j] = pos[idx]; | |
| res.seq[j] = seq[idx]; | |
| assert(shift[idx] == 0); | |
| } | |
| return res; | |
| } | |
| // set the state of cells [i, i + other.pos.size()) (used for save/restore the state of the cells) | |
| void set(uint32_t i, const llama_kv_cells_unified & other) { | |
| assert(i + other.pos.size() <= pos.size()); | |
| for (uint32_t j = 0; j < other.pos.size(); ++j) { | |
| const auto idx = i + j; | |
| if (pos[idx] == -1 && other.pos[j] != -1) { | |
| used.insert(i + j); | |
| } | |
| if (pos[idx] != -1 && other.pos[j] == -1) { | |
| used.erase(i + j); | |
| } | |
| if (pos[idx] != -1) { | |
| seq_pos_rm(i + j); | |
| } | |
| pos[idx] = other.pos[j]; | |
| seq[idx] = other.seq[j]; | |
| if (pos[idx] != -1) { | |
| seq_pos_add(i + j); | |
| } | |
| assert(shift[idx] == 0); | |
| } | |
| } | |
| // set the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1]) | |
| void set(const std::vector<uint32_t> & idxs, const llama_kv_cells_unified & other) { | |
| assert(idxs.size() == other.pos.size()); | |
| for (uint32_t j = 0; j < other.pos.size(); ++j) { | |
| const auto idx = idxs[j]; | |
| if (pos[idx] == -1 && other.pos[j] != -1) { | |
| used.insert(idx); | |
| } | |
| if (pos[idx] != -1 && other.pos[j] == -1) { | |
| used.erase(idx); | |
| } | |
| if (pos[idx] != -1) { | |
| seq_pos_rm(idx); | |
| } | |
| pos[idx] = other.pos[j]; | |
| seq[idx] = other.seq[j]; | |
| if (pos[idx] != -1) { | |
| seq_pos_add(idx); | |
| } | |
| assert(shift[idx] == 0); | |
| } | |
| } | |
| // clear a non-empty cell | |
| void rm(uint32_t i) { | |
| assert(i < pos.size()); | |
| assert(pos[i] != -1); | |
| seq_pos_rm(i); | |
| seq[i].reset(); | |
| pos[i] = -1; | |
| shift[i] = 0; | |
| used.erase(i); | |
| } | |
| // note: call only if the cell has seq_id | |
| // return true if the cell becomes empty | |
| bool seq_rm(uint32_t i, llama_seq_id seq_id) { | |
| assert(i < pos.size()); | |
| assert(seq[i].test(seq_id)); | |
| assert(pos[i] != -1); | |
| assert(seq_id >= 0); | |
| seq[i].reset(seq_id); | |
| seq_pos_dec(seq_id, pos[i]); | |
| if (seq[i].none()) { | |
| pos[i] = -1; | |
| shift[i] = 0; | |
| used.erase(i); | |
| return true; | |
| } | |
| return false; | |
| } | |
| // return true if the cell becomes empty (i.e. it did not contain seq_id before the call) | |
| bool seq_keep(uint32_t i, llama_seq_id seq_id) { | |
| assert(i < pos.size()); | |
| if (seq[i].test(seq_id)) { | |
| seq_pos_rm(i); | |
| seq[i].reset(); | |
| seq[i].set(seq_id); | |
| seq_pos_inc(seq_id, pos[i]); | |
| return false; | |
| } | |
| if (seq[i].any()) { | |
| seq_pos_rm(i); | |
| seq[i].reset(); | |
| pos[i] = -1; | |
| shift[i] = 0; | |
| used.erase(i); | |
| return true; | |
| } | |
| assert(pos[i] == -1); | |
| return false; | |
| } | |
| // number of different sequences in the cell | |
| int seq_count(uint32_t i) const { | |
| assert(i < pos.size()); | |
| assert(pos[i] != -1); | |
| return seq[i].count(); | |
| } | |
| // check if the cell contains seq_id | |
| bool seq_has(uint32_t i, llama_seq_id seq_id) const { | |
| assert(i < pos.size()); | |
| assert(seq_id >= 0); | |
| return seq[i].test(seq_id); | |
| } | |
| // note: call only if the cell is not empty and the seq_id is not in the cell | |
| void seq_add(uint32_t i, llama_seq_id seq_id) { | |
| assert(i < pos.size()); | |
| assert(pos[i] != -1); | |
| assert(!seq[i].test(seq_id)); | |
| seq[i].set(seq_id); | |
| seq_pos_inc(seq_id, pos[i]); | |
| } | |
| // return the sequence id of this cell | |
| // note: call only for cells with exactly one sequence | |
| llama_seq_id seq_get(uint32_t i) const { | |
| assert(seq[i].count() == 1); | |
| for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { | |
| if (seq[i].test(s)) { | |
| return s; | |
| } | |
| } | |
| return -1; | |
| } | |
| // the minimum position of sequence seq_id currently present in any of the cells | |
| // return -1 if the sequence is not present | |
| llama_pos seq_pos_min(llama_seq_id seq_id) const { | |
| assert(seq_id >= 0); | |
| assert(seq_id < LLAMA_MAX_SEQ); | |
| if (seq_pos[seq_id].empty()) { | |
| return -1; | |
| } | |
| assert(seq_pos[seq_id].begin()->second > 0); | |
| return seq_pos[seq_id].begin()->first; | |
| } | |
| // the maximum position of sequence seq_id currently present in any of the cells | |
| // return -1 if the sequence is not present | |
| llama_pos seq_pos_max(llama_seq_id seq_id) const { | |
| assert(seq_id >= 0); | |
| assert(seq_id < LLAMA_MAX_SEQ); | |
| if (seq_pos[seq_id].empty()) { | |
| return -1; | |
| } | |
| assert(seq_pos[seq_id].rbegin()->second > 0); | |
| return seq_pos[seq_id].rbegin()->first; | |
| } | |
| // note: call only if the cell is not empty | |
| llama_pos pos_get(uint32_t i) const { | |
| assert(i < pos.size()); | |
| assert(pos[i] != -1); | |
| return pos[i]; | |
| } | |
| // note: call only if the cell is not empty | |
| llama_pos get_shift(uint32_t i) const { | |
| assert(i < pos.size()); | |
| assert(pos[i] != -1); | |
| return shift[i]; | |
| } | |
| // check if a cell is not empty and its position is within [p0, p1) | |
| bool pos_in(uint32_t i, llama_pos p0, llama_pos p1) const { | |
| assert(i < pos.size()); | |
| return pos[i] >= p0 && pos[i] < p1; | |
| } | |
| // set the position of an empty cell | |
| // does not modify "has_shift" | |
| // note: call only if the cell is empty | |
| void pos_set(uint32_t i, llama_pos p) { | |
| assert(i < pos.size()); | |
| assert(pos[i] == -1); | |
| assert(seq[i].none()); | |
| pos[i] = p; | |
| used.insert(i); | |
| } | |
| // pos[i] = pos[i] + d | |
| // sets "has_shift" to true | |
| // note: call only if the cell is not empty | |
| bool pos_add(uint32_t i, llama_pos d) { | |
| assert(i < pos.size()); | |
| assert(pos[i] != -1); | |
| seq_pos_rm(i); | |
| pos[i] += d; | |
| shift[i] += d; | |
| has_shift = true; | |
| if (pos[i] < 0) { | |
| seq[i].reset(); | |
| pos[i] = -1; | |
| shift[i] = 0; | |
| used.erase(i); | |
| return true; | |
| } | |
| seq_pos_add(i); | |
| return false; | |
| } | |
| // pos[i] = pos[i] / d | |
| // sets "has_shift" to true | |
| // note: call only if the cell is not empty | |
| void pos_div(uint32_t i, int d) { | |
| assert(i < pos.size()); | |
| assert(pos[i] != -1); | |
| const llama_pos p_old = pos[i]; | |
| seq_pos_rm(i); | |
| pos[i] /= d; | |
| shift[i] += p_old - pos[i]; | |
| seq_pos_add(i); | |
| has_shift = true; | |
| } | |
| private: | |
| bool has_shift = false; | |
| // set of indices of used cells (i.e. pos[i] != -1, allowed to not have any seq_id) | |
| std::set<uint32_t> used; | |
| std::vector<llama_pos> pos; | |
| // this array accumulates any applied shifts to the pos array since the last reset_shift() call | |
| // this is used to queue multiple updates to the pos array, which in the end can be applied in one go: | |
| // | |
| // cells.pos_add(x, shift_x); | |
| // cells.pos_div(y, shift_y); | |
| // ... | |
| // | |
| // if (cells.has_shift()) { | |
| // for (int i = 0; i < n; ++i) { | |
| // auto shift_i = cells.get_shift(i); | |
| // ... | |
| // } | |
| // cells.reset_shift(); | |
| // } | |
| // | |
| std::vector<llama_pos> shift; | |
| using seq_set_t = std::bitset<LLAMA_MAX_SEQ>; | |
| // the bitset seq[i] tells us which sequences are currently occupying the i-th cell | |
| std::vector<seq_set_t> seq; | |
| // the set seq_pos[s][p] tells us how many times the position p is currently present for sequence s | |
| // if the position p is not present, seq_pos[s][p] is not set | |
| // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache | |
| // | |
| // note that we cannot a use an std::set because in some cases a position can occur more than once for the same seq: | |
| // - during performing a cache reuse via (rm + add) | |
| // - some vision models have input embeddings with repeating positions | |
| // | |
| std::map<llama_pos, int> seq_pos[LLAMA_MAX_SEQ]; | |
| // helper functions for updating `seq_pos`, once cell at a time: | |
| void seq_pos_dec(llama_seq_id s, llama_pos p) { | |
| auto it = seq_pos[s].find(p); | |
| assert(it != seq_pos[s].end()); | |
| if (--it->second == 0) { | |
| seq_pos[s].erase(it); | |
| } | |
| } | |
| void seq_pos_inc(llama_seq_id s, llama_pos p) { | |
| seq_pos[s][p]++; | |
| } | |
| // remove cell i | |
| void seq_pos_rm(uint32_t i) { | |
| for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { | |
| if (seq[i].test(s)) { | |
| seq_pos_dec(s, pos[i]); | |
| } | |
| } | |
| } | |
| // add cell i | |
| void seq_pos_add(uint32_t i) { | |
| for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { | |
| if (seq[i].test(s)) { | |
| seq_pos_inc(s, pos[i]); | |
| } | |
| } | |
| } | |
| }; | |