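"""Collect router-training data from a causal LM on WikiText-2.

Runs an OPT or LLaMA model over the wikitext-2-raw-v1 train split and
streams, per layer, (a) router inputs (hidden states), (b) optional MLP
activations, and (c) optional per-head attention-output norms into
float16 memmap files, plus a metadata.json recording shapes and
per-layer sample counts. Use load_layer_data() to read the data back.
"""
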
import argparse
import json
import os

import numpy as np
import torch
from tqdm import tqdm

from hf_models.opt.modeling_opt import OPTForCausalLM
from hf_models.llama.modeling_llama import LlamaForCausalLM

from HybridTensor.benchmarks.opt_attn_sparse_topk_perplexity import build_data_loader
from HybridTensor.utils.activations import MODELS
from HybridTensor.utils.utils import extract_model_name


def load_layer_data(data_dir, layer_idx, data_type):
    """
    Load collected data for a specific layer and data type.

    Args:
        data_dir (str): Directory where data is stored.
        layer_idx (int): Layer index.
        data_type (str): One of 'hidden_states', 'mlp_activations', 'attn_norms'.

    Returns:
        np.ndarray: The data array of shape (num_samples, feature_size).
    """
    metadata_filename = os.path.join(data_dir, 'metadata.json')
    with open(metadata_filename, 'r') as f:
        metadata = json.load(f)

    num_layers = metadata['num_layers']
    hidden_size = metadata['hidden_size']
    num_heads = metadata['num_heads']
    max_samples = metadata['max_samples']

    if layer_idx < 0 or layer_idx >= num_layers:
        raise ValueError(f"Invalid layer_idx: {layer_idx}. Must be between 0 and {num_layers - 1}.")

    if data_type == 'hidden_states':
        sample_count = metadata['hidden_states_counters'][layer_idx]
        feature_size = hidden_size
    elif data_type == 'mlp_activations':
        sample_count = metadata['mlp_activations_counters'][layer_idx]
        # Prefer the recorded neuron count; fall back to the OPT-style 4x
        # expansion for metadata written before 'num_neurons' was saved.
        feature_size = metadata.get('num_neurons', hidden_size * 4)
    elif data_type == 'attn_norms':
        sample_count = metadata['attn_norms_counters'][layer_idx]
        feature_size = num_heads
    else:
        raise ValueError(f"Invalid data_type: {data_type}. Must be 'hidden_states', 'mlp_activations', or 'attn_norms'.")

    filename = os.path.join(data_dir, f'{data_type}_layer_{layer_idx}.dat')
    data_mmap = np.memmap(filename, dtype='float16', mode='r', shape=(max_samples, feature_size))

    # Copy only the rows that were actually written, then release the memmap.
    data = np.array(data_mmap[:sample_count])
    del data_mmap
    return data

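# Illustrative usage (hypothetical directory; assumes a finished collection run):
#   attn = load_layer_data('opt-6.7b_act_data', layer_idx=0, data_type='attn_norms')
#   attn.shape  # -> (samples_collected_for_layer_0, num_heads)
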
def initialize_data_structures(data_dir, num_layers, hidden_size, num_heads, num_neurons, max_samples,
                               mlp_activation=True, attn_norm=True):
    """
    Initialize memmap files and counters for hidden_states, mlp_activations, and attn_norms.

    Args:
        data_dir (str): Directory where data is stored.
        num_layers (int): Number of transformer layers.
        hidden_size (int): Hidden size of the model.
        num_heads (int): Number of attention heads.
        num_neurons (int): Number of MLP neurons (FFN width) per layer.
        max_samples (int): Maximum number of samples to collect.
        mlp_activation (bool): Whether to collect MLP activations.
        attn_norm (bool): Whether to collect attention norms.

    Returns:
        tuple: Lists of memmap files and counters for each data type.
    """
    hidden_states_files = []
    mlp_activations_files = []
    attn_norms_files = []

    hidden_states_counters = []
    mlp_activations_counters = []
    attn_norms_counters = []

    for layer_idx in range(num_layers):
        hs_filename = os.path.join(data_dir, f'hidden_states_layer_{layer_idx}.dat')
        hs_file = np.memmap(hs_filename, dtype='float16', mode='w+', shape=(max_samples, hidden_size))
        hidden_states_files.append(hs_file)
        hidden_states_counters.append(0)

        if mlp_activation:
            mlp_filename = os.path.join(data_dir, f'mlp_activations_layer_{layer_idx}.dat')
            mlp_file = np.memmap(mlp_filename, dtype='float16', mode='w+', shape=(max_samples, num_neurons))
            mlp_activations_files.append(mlp_file)
            mlp_activations_counters.append(0)

        if attn_norm:
            attn_filename = os.path.join(data_dir, f'attn_norms_layer_{layer_idx}.dat')
            attn_file = np.memmap(attn_filename, dtype='float16', mode='w+', shape=(max_samples, num_heads))
            attn_norms_files.append(attn_file)
            attn_norms_counters.append(0)

    return (
        hidden_states_files,
        hidden_states_counters,
        mlp_activations_files,
        mlp_activations_counters,
        attn_norms_files,
        attn_norms_counters,
    )

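# Sizing note: every memmap is preallocated at full size, so a run reserves
# roughly, per layer (float16 = 2 bytes per element):
#   hidden_states:   max_samples * hidden_size * 2 bytes
#   mlp_activations: max_samples * num_neurons * 2 bytes
#   attn_norms:      max_samples * num_heads   * 2 bytes
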
def process_hidden_states(layer_idx, hidden_states_layer, valid_token_indices, hidden_size, hidden_states_files, hidden_states_counters):
    """
    Process and store hidden states for a specific layer.
    """
    hs = hidden_states_layer.view(-1, hidden_size)
    # Keep the indices on the same device as the activations before gathering.
    hs_valid = hs[valid_token_indices.to(hs.device)]
    hs_counter = hidden_states_counters[layer_idx]
    hs_file = hidden_states_files[layer_idx]
    hs_file[hs_counter:hs_counter + hs_valid.shape[0], :] = hs_valid.cpu().numpy().astype('float16')
    hidden_states_counters[layer_idx] += hs_valid.shape[0]


def process_mlp_activations(layer_idx, mlp_activations_layer, valid_token_indices, hidden_size, mlp_activations_files, mlp_activations_counters):
    """
    Process and store MLP activations for a specific layer.
    """
    neuron_shape = mlp_activations_layer.shape[-1]
    mlp_activations_layer = mlp_activations_layer.view(-1, neuron_shape)

    # Keep the indices on the same device as the activations before gathering.
    mlp_valid = mlp_activations_layer[valid_token_indices.to(mlp_activations_layer.device)]
    mlp_counter = mlp_activations_counters[layer_idx]
    mlp_file = mlp_activations_files[layer_idx]
    mlp_file[mlp_counter:mlp_counter + mlp_valid.shape[0], :] = mlp_valid.cpu().numpy().astype('float16')
    mlp_activations_counters[layer_idx] += mlp_valid.shape[0]

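# Write pattern shared by the process_* helpers: each call appends a
# contiguous block of rows to the layer's memmap and advances its counter,
# so rows [0, counter) are always valid and later rows are uninitialized.
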
def process_attn_norms(layer_idx, attn_outputs_layer, valid_token_indices, num_heads, attn_norms_files, attn_norms_counters):
    """
    Process and store per-head attention output norms for a specific layer.
    """
    attn = attn_outputs_layer
    # L2 norm over the head dimension: one value per head per token.
    attn_norms = torch.norm(attn, dim=-1)
    attn_norms = attn_norms.view(-1, num_heads)
    attn_valid = attn_norms[valid_token_indices.to(attn_norms.device)]
    attn_counter = attn_norms_counters[layer_idx]
    attn_file = attn_norms_files[layer_idx]
    attn_file[attn_counter:attn_counter + attn_valid.shape[0], :] = attn_valid.cpu().numpy().astype('float16')
    attn_norms_counters[layer_idx] += attn_valid.shape[0]

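# Shape intuition, a minimal sketch (assumes attn_outputs_layer is
# (batch, seq, num_heads, head_dim), as view(-1, num_heads) above implies):
#   x = torch.randn(2, 3, 4, 8)   # (batch, seq, heads, head_dim)
#   n = torch.norm(x, dim=-1)     # (2, 3, 4): one norm per head per token
#   n.view(-1, 4)                 # (6, 4): rows are tokens, columns are heads
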
def process_batch(
    outputs,
    input_ids,
    attention_mask,
    total_samples,
    max_samples,
    num_layers,
    hidden_size,
    num_heads,
    hidden_states_files,
    hidden_states_counters,
    mlp_activations_files,
    mlp_activations_counters,
    attn_norms_files,
    attn_norms_counters,
    args
):
    """
    Process a batch of model outputs and update the data files.

    Returns:
        total_samples (int): Updated total number of samples processed.
        reached_max_samples (bool): Whether the maximum number of samples has been reached.
        tokens_to_process (int): Number of valid tokens actually stored from this batch.
    """
    hidden_states = outputs['router_inputs']
    mlp_activations = outputs['mlp_activations'] if args.mlp_activation else None
    attn_outputs = outputs['attn_outputs'] if args.attn_norm else None

    batch_size, seq_len = input_ids.shape

    # Flatten the mask so token positions index into (batch * seq_len) rows.
    attention_mask_flat = attention_mask.view(-1).bool()
    num_valid_tokens = attention_mask_flat.sum().item()

    # Clip the batch if it would push the total past max_samples.
    if total_samples + num_valid_tokens >= max_samples:
        tokens_to_process = max_samples - total_samples
        total_samples = max_samples
        reached_max_samples = True
    else:
        tokens_to_process = num_valid_tokens
        total_samples += num_valid_tokens
        reached_max_samples = False

    valid_token_indices = attention_mask_flat.nonzero(as_tuple=False).view(-1)
    valid_token_indices = valid_token_indices[:tokens_to_process]

    for layer_idx in range(num_layers):
        process_hidden_states(
            layer_idx,
            hidden_states[layer_idx],
            valid_token_indices,
            hidden_size,
            hidden_states_files,
            hidden_states_counters
        )
        if args.mlp_activation:
            process_mlp_activations(
                layer_idx,
                mlp_activations[layer_idx],
                valid_token_indices,
                hidden_size,
                mlp_activations_files,
                mlp_activations_counters
            )
        if args.attn_norm:
            process_attn_norms(
                layer_idx,
                attn_outputs[layer_idx],
                valid_token_indices,
                num_heads,
                attn_norms_files,
                attn_norms_counters
            )

    return total_samples, reached_max_samples, tokens_to_process

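# How process_batch selects tokens, with hypothetical values:
#   attention_mask = [[1, 1, 0],
#                     [1, 0, 0]]    -> flat bool mask over 6 token slots
#   valid_token_indices = [0, 1, 3] -> rows kept from each flattened
#                                      (batch * seq_len, features) tensor
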
def finalize_data_collection(
    data_dir,
    num_layers,
    hidden_size,
    num_heads,
    num_neurons,
    max_samples,
    hidden_states_files,
    mlp_activations_files,
    attn_norms_files,
    hidden_states_counters,
    mlp_activations_counters,
    attn_norms_counters,
    args
):
    """
    Finalize the data collection by flushing and closing memmap files and saving metadata.

    Args:
        data_dir (str): Directory where data is stored.
        num_layers (int): Number of transformer layers.
        hidden_size (int): Hidden size of the model.
        num_heads (int): Number of attention heads.
        num_neurons (int): Number of MLP neurons (FFN width) per layer.
        max_samples (int): Maximum number of samples to collect.
        hidden_states_files (list): Memmap files for hidden states.
        mlp_activations_files (list): Memmap files for MLP activations.
        attn_norms_files (list): Memmap files for attention norms.
        hidden_states_counters (list): Per-layer sample counts for hidden states.
        mlp_activations_counters (list): Per-layer sample counts for MLP activations.
        attn_norms_counters (list): Per-layer sample counts for attention norms.
    """
    for layer_idx in range(num_layers):
        hs_file = hidden_states_files[layer_idx]
        hs_file.flush()
        del hs_file

        if args.mlp_activation:
            mlp_file = mlp_activations_files[layer_idx]
            mlp_file.flush()
            del mlp_file

        if args.attn_norm:
            attn_file = attn_norms_files[layer_idx]
            attn_file.flush()
            del attn_file

    metadata = {
        'num_layers': num_layers,
        'hidden_size': hidden_size,
        'num_heads': num_heads,
        'num_neurons': num_neurons,
        'max_samples': max_samples,
        'hidden_states_counters': hidden_states_counters,
        'mlp_activations_counters': mlp_activations_counters,
        'attn_norms_counters': attn_norms_counters
    }

    metadata_filename = os.path.join(data_dir, 'metadata.json')
    with open(metadata_filename, 'w') as f:
        json.dump(metadata, f)

    print("Finalization complete. Metadata saved.")

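# Example metadata.json from a run (illustrative values):
#   {"num_layers": 32, "hidden_size": 4096, "num_heads": 32,
#    "num_neurons": 16384, "max_samples": 5000,
#    "hidden_states_counters": [5000, ...], ...}
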
def arg_parser():
    parser = argparse.ArgumentParser(description='Activation data collection for router training')
    parser.add_argument('--model_index', type=int, default=5, help='1-based index into MODELS of the model to evaluate')
    parser.add_argument('--batch_size', type=int, default=4, help='Batch size for evaluation')
    parser.add_argument('--max_length', type=int, default=512, help='Maximum sequence length')
    # Note: argparse's type=bool treats any non-empty string as True, so boolean
    # options use BooleanOptionalAction (--flag / --no-flag, Python 3.9+) instead.
    parser.add_argument('--data_collection', action=argparse.BooleanOptionalAction, default=False, help='Collect data for different activation thresholds')
    parser.add_argument('--device_map', type=str, default='cuda:0', help='Device to use for evaluation')
    parser.add_argument('--interactive', action=argparse.BooleanOptionalAction, default=False, help='Interactive mode for model selection')
    parser.add_argument('--data_dir', type=str, default='<PATH_TO_DATA_DIR>', help='Directory to store generated data')
    parser.add_argument('--max_samples', type=int, default=5000, help='Maximum number of samples to collect')
    parser.add_argument('--model_family', type=str, default='opt', choices=["opt", "llama"], help='Model family to evaluate')
    parser.add_argument('--mlp_activation', action=argparse.BooleanOptionalAction, default=False, help='Collect MLP activations')
    parser.add_argument('--attn_norm', action=argparse.BooleanOptionalAction, default=True, help='Collect attention norms')

    return parser.parse_args()

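# Example invocation (script name and paths are illustrative):
#   python collect_router_data.py --model_index 5 --model_family opt \
#       --batch_size 4 --max_length 512 --data_dir /data/router \
#       --no-mlp_activation --attn_norm
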
if __name__ == "__main__":
    args = arg_parser()
    model_name = MODELS[args.model_index - 1]
    batch_size = args.batch_size
    max_length = args.max_length
    data_collection = args.data_collection
    device_map = args.device_map

    if args.model_family == 'opt':
        model = OPTForCausalLM.from_pretrained(
            model_name, device_map=device_map, torch_dtype=torch.float16,
            attn_implementation="flash_attention_2", output_hidden_states=True, output_attentions=True,
            return_dict=True
        )
        num_neurons = model.config.ffn_dim            # OPT names the FFN width ffn_dim
    elif args.model_family == 'llama':
        model = LlamaForCausalLM.from_pretrained(
            model_name, device_map=device_map, torch_dtype=torch.float16,
            attn_implementation="flash_attention_2", output_hidden_states=True, output_attentions=True,
            return_dict=True
        )
        num_neurons = model.config.intermediate_size  # LLaMA names it intermediate_size

    data_loader = build_data_loader(
        model_name, "wikitext", "wikitext-2-raw-v1", batch_size, max_length, split='train'
    )
    if args.device_map == "auto":
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(device_map if torch.cuda.is_available() else 'cpu')

    model_name_clean = extract_model_name(model_name)
    folder_name = f"{model_name_clean}_act_data"
    data_dir = os.path.join(args.data_dir, folder_name)

    os.makedirs(data_dir, exist_ok=True)

    num_layers = model.config.num_hidden_layers
    hidden_size = model.config.hidden_size
    num_heads = model.config.num_attention_heads
    max_samples = args.max_samples

    print(f"Collecting data for model: {model_name}")
    print(f"Data directory: {data_dir}")
    print(f"Number of layers: {num_layers}")
    print(f"Hidden size: {hidden_size}")
    print(f"Number of heads: {num_heads}")
    print(f"Number of neurons: {num_neurons}")
    print(f"Max samples: {max_samples}")
    print(f"Collecting MLP activations: {args.mlp_activation}")
    print(f"Collecting attention norms: {args.attn_norm}")

    (hidden_states_files, hidden_states_counters, mlp_activations_files,
     mlp_activations_counters, attn_norms_files, attn_norms_counters) = initialize_data_structures(
        data_dir, num_layers, hidden_size, num_heads, num_neurons, max_samples,
        mlp_activation=args.mlp_activation, attn_norm=args.attn_norm)

    total_samples = 0
    reached_max_samples = False  # guards the final report if the loader is empty

    with torch.no_grad():
        with tqdm(total=max_samples, desc="Router training data collection") as pbar:
            for batch in data_loader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)

                outputs = model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True,
                                output_attentions=False, return_dict=True, output_mlp_activation=args.mlp_activation,
                                output_attn_output=args.attn_norm, output_router_inputs=True)

                total_samples, reached_max_samples, num_processed_tokens = process_batch(
                    outputs, input_ids, attention_mask,
                    total_samples, max_samples, num_layers,
                    hidden_size, num_heads, hidden_states_files,
                    hidden_states_counters, mlp_activations_files, mlp_activations_counters,
                    attn_norms_files, attn_norms_counters,
                    args=args)

                # Advance by the tokens actually stored so the bar stops at max_samples.
                pbar.update(num_processed_tokens)

                if reached_max_samples:
                    break

    finalize_data_collection(
        data_dir, num_layers, hidden_size,
        num_heads, num_neurons, max_samples, hidden_states_files,
        mlp_activations_files, attn_norms_files, hidden_states_counters,
        mlp_activations_counters, attn_norms_counters,
        args
    )

    if reached_max_samples:
        print(f"Reached maximum number of samples. total_samples = {total_samples}")
    print(f"Data collection complete. Data saved to {data_dir}")