fix split by shard bug
This commit is contained in:
parent
4a1e7bb44b
commit
ec2b4d6ee7
18
configs/accelerate_configs/ds_stage2.yaml
Normal file
18
configs/accelerate_configs/ds_stage2.yaml
Normal file
|
@@ -0,0 +1,18 @@
|
||||||
|
compute_environment: LOCAL_MACHINE
|
||||||
|
deepspeed_config:
|
||||||
|
deepspeed_multinode_launcher: standard
|
||||||
|
gradient_clipping: 1.0
|
||||||
|
offload_optimizer_device: none
|
||||||
|
offload_param_device: none
|
||||||
|
zero3_init_flag: false
|
||||||
|
zero_stage: 2
|
||||||
|
distributed_type: DEEPSPEED
|
||||||
|
fsdp_config: {}
|
||||||
|
machine_rank: 0
|
||||||
|
main_training_function: main
|
||||||
|
mixed_precision: bf16
|
||||||
|
num_machines: 1
|
||||||
|
num_processes: 8
|
||||||
|
rdzv_backend: static
|
||||||
|
same_network: true
|
||||||
|
use_cpu: false
|
18
configs/accelerate_configs/ds_stage3.yaml
Normal file
18
configs/accelerate_configs/ds_stage3.yaml
Normal file
|
@@ -0,0 +1,18 @@
|
||||||
|
compute_environment: LOCAL_MACHINE
|
||||||
|
deepspeed_config:
|
||||||
|
deepspeed_multinode_launcher: standard
|
||||||
|
gradient_clipping: 1.0
|
||||||
|
offload_optimizer_device: none
|
||||||
|
offload_param_device: none
|
||||||
|
zero3_init_flag: true
|
||||||
|
zero_stage: 3
|
||||||
|
distributed_type: DEEPSPEED
|
||||||
|
fsdp_config: {}
|
||||||
|
machine_rank: 0
|
||||||
|
main_training_function: main
|
||||||
|
mixed_precision: bf16
|
||||||
|
num_machines: 1
|
||||||
|
num_processes: 8
|
||||||
|
rdzv_backend: static
|
||||||
|
same_network: true
|
||||||
|
use_cpu: false
|
|
@@ -184,7 +184,7 @@ def construct_dataset(
|
||||||
random.shuffle(all_data_files)
|
random.shuffle(all_data_files)
|
||||||
if world_size is not None:
|
if world_size is not None:
|
||||||
num_shards = len(all_data_files)
|
num_shards = len(all_data_files)
|
||||||
all_data_files = all_data_files[num_shards // world_size * world_size]
|
all_data_files = all_data_files[:num_shards // world_size * world_size]
|
||||||
dataset = load_dataset(
|
dataset = load_dataset(
|
||||||
"json", data_files=all_data_files, split="train", streaming=True
|
"json", data_files=all_data_files, split="train", streaming=True
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user