import torch


def continuous_tensor(inputs: torch.Tensor, seq_length: torch.LongTensor):
    """Convert batched tensor to continuous tensor.

    Args:
        inputs (Tensor): batched tensor.
        seq_length (Tensor): length of each sequence.

    Return:
        Tensor: continuous tensor.
    """
    assert inputs.dim() > 1
    # A batch whose sequence dimension is 1 is already contiguous;
    # just flatten it into a single row.
    if inputs.size(1) == 1:
        return inputs.reshape(1, -1)

    # Strip the padding from each sequence, then pack all sequences
    # into one row.
    inputs = [inp[:slen] for inp, slen in zip(inputs, seq_length)]
    inputs = torch.cat(inputs).unsqueeze(0)
    return inputs

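# Usage sketch for continuous_tensor (hypothetical values, not part of
# the original file): pack a zero-padded batch into a single packed row.
#
#   tokens = torch.tensor([[1, 2, 3, 0], [4, 5, 0, 0]])
#   lens = torch.tensor([3, 2])
#   continuous_tensor(tokens, lens)  # -> tensor([[1, 2, 3, 4, 5]])
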
def batch_tensor(inputs: torch.Tensor, seq_length: torch.LongTensor):
    """Convert continuous tensor to batched tensor.

    Args:
        inputs (Tensor): continuous tensor.
        seq_length (Tensor): length of each sequence.

    Return:
        Tensor: batched tensor.
    """
    from torch.nn.utils.rnn import pad_sequence

    # Start/end offset of each sequence inside the packed row.
    end_loc = seq_length.cumsum(0)
    start_loc = end_loc - seq_length

    # Slice the packed row back into per-sequence tensors, then
    # zero-pad them to a common length.
    inputs = [inputs[0, sloc:eloc] for sloc, eloc in zip(start_loc, end_loc)]
    inputs = pad_sequence(inputs, batch_first=True)
    return inputs

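# Round-trip sketch (hypothetical values, not part of the original file):
# batch_tensor inverts continuous_tensor, re-padding shorter sequences
# with zeros.
#
#   packed = torch.tensor([[1, 2, 3, 4, 5]])
#   lens = torch.tensor([3, 2])
#   batch_tensor(packed, lens)  # -> tensor([[1, 2, 3], [4, 5, 0]])
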
def page_cache(
    paged_cache: torch.Tensor,
    batched_cache: torch.Tensor,
    cache_length: torch.Tensor,
    block_offsets: torch.Tensor,
    permute_head: bool = True,
):
    """Convert batched cache to paged cache.

    Args:
        paged_cache (Tensor): Output paged cache.
        batched_cache (Tensor): Input batched cache.
        cache_length (Tensor): Cache length of each sequence.
        block_offsets (Tensor): Offsets of the blocks owned by each sequence.
        permute_head (bool): If True, permute the input cache from
            (batch, head, seq, dim) to (batch, seq, head, dim) before paging.
    """
    assert block_offsets.dim() == 2
    block_size = paged_cache.size(1)
    batch_size = batched_cache.size(0)
    if permute_head:
        batched_cache = batched_cache.permute(0, 2, 1, 3)

    # Copy each sequence into its assigned blocks, one block-sized
    # slice at a time.
    for b_idx in range(batch_size):
        cache_len = cache_length[b_idx]
        b_cache = batched_cache[b_idx]
        block_off = block_offsets[b_idx]
        block_off_idx = 0
        for s_start in range(0, cache_len, block_size):
            s_end = min(s_start + block_size, cache_len)
            s_len = s_end - s_start
            b_off = block_off[block_off_idx]
            paged_cache[b_off, :s_len] = b_cache[s_start:s_end]
            block_off_idx += 1
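

# Usage sketch for page_cache (hypothetical shapes, not part of the
# original file): scatter a batched KV cache of shape
# [batch, head, seq, dim] into a paged cache of shape
# [num_blocks, block_size, head, dim].
#
#   paged = torch.zeros(4, 2, 1, 8)           # 4 blocks, 2 tokens per block
#   batched = torch.randn(2, 1, 3, 8)         # 2 seqs, 1 head, 3 tokens max
#   lengths = torch.tensor([3, 2])
#   offsets = torch.tensor([[0, 1], [2, 3]])  # blocks assigned per sequence
#   page_cache(paged, batched, lengths, offsets)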