From 49118aad42848d946c6539a95437db0898a1fd57 Mon Sep 17 00:00:00 2001 From: LiangSong Date: Thu, 27 Apr 2023 23:42:11 +0800 Subject: [PATCH] update header config and add padding to concat_multiple_sequence --- chat_server.py | 2 +- configs/instruct_config.yaml | 3 ++- configs/pretrain_config.yaml | 5 +++-- dataset/dataset.py | 38 +++++++++++++++++------------------- solver/trainer.py | 2 +- train_lm.py | 6 ++++-- 6 files changed, 29 insertions(+), 27 deletions(-) diff --git a/chat_server.py b/chat_server.py index 2361687..213a8ea 100644 --- a/chat_server.py +++ b/chat_server.py @@ -2,7 +2,7 @@ Author: LiangSong(sl12160010@gmail.com) Date: 2023-04-06 22:30:10 LastEditors: LiangSong(sl12160010@gmail.com) -LastEditTime: 2023-04-26 23:58:23 +LastEditTime: 2023-04-27 20:34:58 FilePath: /Open-Llama/chat_server.py Description: diff --git a/configs/instruct_config.yaml b/configs/instruct_config.yaml index 23787f6..50bfb51 100644 --- a/configs/instruct_config.yaml +++ b/configs/instruct_config.yaml @@ -3,7 +3,7 @@ data: data: mixed: "data/instruction_data/part-*.jsonl.zst" pad_to_max: False - sequence_sample_mode: "sample" + sequence_sample_mode: "none" concat_multiple_sequence: True num_sequences: 50 seq_length: 2048 @@ -24,6 +24,7 @@ train: ckpt: "data/llama_raw_ckpt/7B/extended.pth" train_num_workers: 16 gradient_accumulation_steps: 1 + prefetch_factor: 100 # global step log_interval: 50 eval_interval: 500 diff --git a/configs/pretrain_config.yaml b/configs/pretrain_config.yaml index acd9791..94dca87 100644 --- a/configs/pretrain_config.yaml +++ b/configs/pretrain_config.yaml @@ -3,9 +3,9 @@ data: data: mixed: "data/pretrain_data/part-*.jsonl.zst" pad_to_max: False - sequence_sample_mode: "sample" + sequence_sample_mode: "none" concat_multiple_sequence: True - num_sequences: 20 + num_sequences: 10 seq_length: 2048 tokenizer_model_path: "configs/llama_tokenizer_extended.model" model: @@ -24,6 +24,7 @@ train: ckpt: "data/llama_raw_ckpt/7B/extended.pth" train_num_workers: 16 gradient_accumulation_steps: 12 + prefetch_factor: 100 # global step log_interval: 5 eval_interval: 200 diff --git a/dataset/dataset.py b/dataset/dataset.py index 99b2eb0..47415ec 100644 --- a/dataset/dataset.py +++ b/dataset/dataset.py @@ -2,12 +2,13 @@ Author: LiangSong(sl12160010@gmail.com) Date: 2023-04-24 20:05:21 LastEditors: LiangSong(sl12160010@gmail.com) -LastEditTime: 2023-04-24 20:05:59 +LastEditTime: 2023-04-27 22:19:37 FilePath: /Open-Llama/dataset/dataset.py Description: Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved. 
""" +import math import torch import random from glob import glob @@ -98,13 +99,6 @@ def split_multiturn(batch): return {"text": batch["text"][0].split("[multiturn_sep]")} -def truncation_gen(seq_length): - def truncation(line): - return {"input_ids": line["input_ids"][:seq_length]} - - return truncation - - def sample_sequence_gen(seq_length, eos_token_id): def sample_sequence(line): doc_length = line["input_ids"].shape[0] @@ -134,16 +128,15 @@ def split_sequence_gen(seq_length): return split_sequence -def concat_multiple_sequence_gen(seq_length): +def concat_multiple_sequence_gen(seq_length, pad_token_id): def concat_multiple_sequence(batch): concat_input_ids = torch.cat(batch["input_ids"], dim=0) - input_ids = [] - while len(concat_input_ids) >= (1 + len(input_ids)) * seq_length: - input_ids.append( - concat_input_ids[ - len(input_ids) * seq_length : (1 + len(input_ids)) * seq_length - ] - ) + length = concat_input_ids.shape[0] + chunks = math.ceil(length / seq_length) + pad_length = chunks * seq_length - length + pad = torch.ones(pad_length, dtype=concat_input_ids.dtype) * pad_token_id + concat_input_ids = torch.cat([concat_input_ids, pad], dim=0) + input_ids = torch.chunk(concat_input_ids, chunks) return {"input_ids": input_ids} return concat_multiple_sequence @@ -170,6 +163,8 @@ def construct_dataset(dataset_config, tokenizer, return_raw_text=False): dataset = load_dataset( "json", data_files=data_files, split="train", streaming=True ) + # shuffle + dataset = dataset.shuffle() # 文本预处理转换为统一格式 if dataset_config["mode"] == "pretrain": dataset = dataset.map(pretrain_transform, batched=True, batch_size=1) @@ -198,10 +193,12 @@ def construct_dataset(dataset_config, tokenizer, return_raw_text=False): return full_dataset seq_length = dataset_config["seq_length"] + pad_to_max = dataset_config.get("pad_to_max", True) sequence_sample_mode = dataset_config.get("sequence_sample_mode", "truncation") truncation = sequence_sample_mode == "truncation" + concat_multiple_sequence = dataset_config.get("concat_multiple_sequence", False) # tokenize - if dataset_config.get("pad_to_max", True): + if pad_to_max: full_dataset = full_dataset.map( lambda x: tokenizer( x["text"], @@ -232,11 +229,12 @@ def construct_dataset(dataset_config, tokenizer, return_raw_text=False): elif sequence_sample_mode == "none": pass elif sequence_sample_mode == "sample": + assert pad_to_max or concat_multiple_sequence full_dataset = full_dataset.map( sample_sequence_gen(seq_length, tokenizer.eos_token_id) ) elif sequence_sample_mode == "split": - assert not dataset_config.get("concat_multiple_sequence", False) + assert not concat_multiple_sequence full_dataset = full_dataset.map( split_sequence_gen(seq_length), batched=True, batch_size=1 ) @@ -246,10 +244,10 @@ def construct_dataset(dataset_config, tokenizer, return_raw_text=False): ) # concat multiple sequence - if dataset_config.get("concat_multiple_sequence", False): + if concat_multiple_sequence: num_sequences = dataset_config["num_sequences"] full_dataset = full_dataset.map( - concat_multiple_sequence_gen(seq_length), + concat_multiple_sequence_gen(seq_length, tokenizer.pad_token_id), batched=True, batch_size=num_sequences, drop_last_batch=True, diff --git a/solver/trainer.py b/solver/trainer.py index a10cbeb..a84f29f 100644 --- a/solver/trainer.py +++ b/solver/trainer.py @@ -2,7 +2,7 @@ Author: LiangSong(sl12160010@gmail.com) Date: 2023-04-24 20:05:21 LastEditors: LiangSong(sl12160010@gmail.com) -LastEditTime: 2023-04-26 23:06:55 +LastEditTime: 2023-04-27 20:34:47 FilePath: 
/Open-Llama/solver/trainer.py Description: diff --git a/train_lm.py b/train_lm.py index 09699a7..fb0d9d2 100644 --- a/train_lm.py +++ b/train_lm.py @@ -2,8 +2,8 @@ Author: LiangSong(sl12160010@gmail.com) Date: 2023-04-12 19:12:42 LastEditors: LiangSong(sl12160010@gmail.com) -LastEditTime: 2023-04-26 23:05:47 -FilePath: /Open-Llama/pretrain.py +LastEditTime: 2023-04-27 23:08:47 +FilePath: /Open-Llama/train_lm.py Description: Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved. @@ -50,6 +50,8 @@ def main(argv): train_dataset, batch_size=config["train"]["train_batch_size"], num_workers=config["train"]["train_num_workers"], + prefetch_factor=config["train"].get("prefetch_factor", 2), + pin_memory=True, ) # smaller initializer_range make training more stable # add stabel embedding to token embedding
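
The headline change here is in concat_multiple_sequence_gen: instead of dropping the tail that is shorter than seq_length, the concatenated token stream is now right-padded with pad_token_id up to a multiple of seq_length and split into equal chunks, so no tokens are discarded. Below is a minimal standalone sketch of that behavior. The function body mirrors the patched dataset/dataset.py; the toy seq_length=8 and pad_token_id=0 are made up for illustration (the real values come from seq_length: 2048 in the YAML configs and tokenizer.pad_token_id, which is assumed to be defined).

import math
import torch

def concat_multiple_sequence_gen(seq_length, pad_token_id):
    # Mirrors the patched function in dataset/dataset.py.
    def concat_multiple_sequence(batch):
        # Join every sequence in the mapped batch into one long token stream.
        concat_input_ids = torch.cat(batch["input_ids"], dim=0)
        length = concat_input_ids.shape[0]
        # Round up so the final partial chunk is kept instead of dropped.
        chunks = math.ceil(length / seq_length)
        pad_length = chunks * seq_length - length
        # Right-pad with pad_token_id so the stream splits evenly into chunks.
        pad = torch.ones(pad_length, dtype=concat_input_ids.dtype) * pad_token_id
        concat_input_ids = torch.cat([concat_input_ids, pad], dim=0)
        input_ids = torch.chunk(concat_input_ids, chunks)
        return {"input_ids": input_ids}

    return concat_multiple_sequence

if __name__ == "__main__":
    # Toy example: 5 + 13 = 18 tokens with seq_length 8 gives 3 chunks and 6 pad tokens.
    fn = concat_multiple_sequence_gen(seq_length=8, pad_token_id=0)
    out = fn({"input_ids": [torch.arange(1, 6), torch.arange(6, 19)]})
    for chunk in out["input_ids"]:
        print(chunk.tolist())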
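
The train_lm.py hunk threads the new prefetch_factor setting through to the DataLoader and pins host memory. A small sketch of just that construction, with a dummy in-memory dataset and made-up config values standing in for the real construct_dataset output and parsed YAML settings:

import torch
from torch.utils.data import DataLoader, TensorDataset

# Dummy stand-ins for the real streaming dataset and parsed YAML config.
train_dataset = TensorDataset(torch.randint(0, 32000, (64, 2048)))
config = {"train": {"train_batch_size": 2, "train_num_workers": 16, "prefetch_factor": 100}}

train_loader = DataLoader(
    train_dataset,
    batch_size=config["train"]["train_batch_size"],
    num_workers=config["train"]["train_num_workers"],
    # Each worker keeps prefetch_factor batches loaded ahead of the training
    # loop; the option only applies when num_workers > 0, and .get() falls
    # back to PyTorch's default of 2 for configs that do not set it.
    prefetch_factor=config["train"].get("prefetch_factor", 2),
    # Pinned (page-locked) host memory allows asynchronous host-to-GPU copies.
    pin_memory=True,
)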