Fix split-by-shard bug: slice the data-file list to a multiple of world_size (`[:n]`) instead of indexing a single element (`[n]`)

This commit is contained in:
LiangSong 2023-05-08 14:03:05 +08:00
parent 4a1e7bb44b
commit ec2b4d6ee7
3 changed files with 37 additions and 1 deletions

View File

@ -0,0 +1,18 @@
compute_environment: LOCAL_MACHINE
deepspeed_config:
deepspeed_multinode_launcher: standard
gradient_clipping: 1.0
offload_optimizer_device: none
offload_param_device: none
zero3_init_flag: false
zero_stage: 2
distributed_type: DEEPSPEED
fsdp_config: {}
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
use_cpu: false

View File

@ -0,0 +1,18 @@
compute_environment: LOCAL_MACHINE
deepspeed_config:
deepspeed_multinode_launcher: standard
gradient_clipping: 1.0
offload_optimizer_device: none
offload_param_device: none
zero3_init_flag: true
zero_stage: 3
distributed_type: DEEPSPEED
fsdp_config: {}
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
use_cpu: false

View File

@ -184,7 +184,7 @@ def construct_dataset(
random.shuffle(all_data_files)
if world_size is not None:
num_shards = len(all_data_files)
all_data_files = all_data_files[num_shards // world_size * world_size]
all_data_files = all_data_files[:num_shards // world_size * world_size]
dataset = load_dataset(
"json", data_files=all_data_files, split="train", streaming=True
)