add instruction-tuning

parent e1bd1766bc
commit a62ac2658f

configs/instruction_tuning_config.py (new file, 25 lines)
@@ -0,0 +1,25 @@
"""
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-30 21:38:07
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-30 21:39:40
FilePath: /Open-Llama/configs/instruction_tuning_config.py
Description:

Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
max_length = 1024
train_batch_size = 2
num_training_steps = 37500
num_warmup_steps = 100
initializer_range = 1e-2
lr = 2e-4
weight_decay = 1e-1
tokenizer_model_path = "configs/10w_vocab_wudao5_pile10.model"
patterns = ["data/instruction_data/part-*.jsonl.zst"]
# global step
log_interval = 50
eval_interval = 500
save_interval = 1000
work_dir = "data/saved_ckpt/"
ckpt_path = "data/saved_ckpt/30000.pt"
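The tuning script added later in this commit pulls these values in with a star import, so each name above becomes a plain module-level global there. A minimal, hypothetical sketch (not part of the commit) of reading the config programmatically:

# Hypothetical sketch -- not part of the commit.
import configs.instruction_tuning_config as cfg

print(cfg.max_length, cfg.train_batch_size, cfg.lr)  # 1024 2 0.0002
print(cfg.patterns, cfg.ckpt_path)                   # shard glob and the pretrained checkpoint to start from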
							
								
								
									

data/preprocess_instruction.py (new file, 61 lines)
@@ -0,0 +1,61 @@
"""
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-30 20:52:10
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-30 20:52:12
FilePath: /Open-Llama/data/preprocess_instruction.py
Description:

Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
import json
import zstandard as zstd
from datasets import load_dataset

dataset = load_dataset("yizhongw/self_instruct")
write_path = "data/instruction_data/part-self_instruct-{}.jsonl.zst"
total_num = 0
file_num = 0
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
for line in dataset["train"]:
    line = json.dumps(line)
    if total_num % 1024 == 0 and total_num > 0:
        file_num += 1
        wfp.close()
        wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
    wfp.write(line.encode("utf-8"))
    wfp.write(b"\n")
    total_num += 1
wfp.close()

dataset = load_dataset("BelleGroup/generated_train_0.5M_CN")
write_path = "data/instruction_data/part-belle_0.5M-{}.jsonl.zst"
total_num = 0
file_num = 0
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
for line in dataset["train"]:
    line = json.dumps(line)
    if total_num % 1024 == 0 and total_num > 0:
        file_num += 1
        wfp.close()
        wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
    wfp.write(line.encode("utf-8"))
    wfp.write(b"\n")
    total_num += 1
wfp.close()

dataset = load_dataset("BelleGroup/generated_train_1M_CN")
write_path = "data/instruction_data/part-belle_1M-{}.jsonl.zst"
total_num = 0
file_num = 0
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
for line in dataset["train"]:
    line = json.dumps(line)
    if total_num % 1024 == 0 and total_num > 0:
        file_num += 1
        wfp.close()
        wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
    wfp.write(line.encode("utf-8"))
    wfp.write(b"\n")
    total_num += 1
wfp.close()
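The three export loops above are identical apart from the dataset name and the output prefix; the same sharding logic can be captured in one helper, sketched below. The helper name dump_dataset is hypothetical and not part of the commit.

# Hypothetical refactor of the loops above -- same 1024-record sharding, not part of the commit.
import json
import zstandard as zstd
from datasets import load_dataset

def dump_dataset(name, write_path, shard_size=1024):
    # Write dataset["train"] as zstd-compressed JSONL shards of `shard_size` records each.
    dataset = load_dataset(name)
    file_num = 0
    wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
    for total_num, line in enumerate(dataset["train"]):
        if total_num % shard_size == 0 and total_num > 0:
            file_num += 1
            wfp.close()
            wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
        wfp.write(json.dumps(line).encode("utf-8"))
        wfp.write(b"\n")
    wfp.close()

dump_dataset("yizhongw/self_instruct", "data/instruction_data/part-self_instruct-{}.jsonl.zst")
dump_dataset("BelleGroup/generated_train_0.5M_CN", "data/instruction_data/part-belle_0.5M-{}.jsonl.zst")
dump_dataset("BelleGroup/generated_train_1M_CN", "data/instruction_data/part-belle_1M-{}.jsonl.zst")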
							
								
								
									

dataset/data_loader.py (new file, 192 lines)
@@ -0,0 +1,192 @@
"""
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-30 20:58:16
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-30 21:00:49
FilePath: /Open-Llama/dataset/data_loader.py
Description:

Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
import math
import torch


def pretrain_collate_fn_gen(tokenizer, segment_max_length=1024, padding="longest"):
    """
    Organize data into tensors by padding based on the preset maximum length.
    """
    pad_id = tokenizer.pad_id

    def pretrain_collate_fn(batch):
        if padding == "longest":
            max_length = max([len(i) for i in batch])
        elif padding == "max_length":
            max_length = segment_max_length
        else:
            raise Exception("Invalid argument for padding: {}".format(padding))
        input_ids = []
        for i in batch:
            input_len = len(i)
            input_ids.append(i + [pad_id] * (max_length - input_len))
        inputs = {
            "input_ids": torch.tensor(input_ids, dtype=torch.int64),
        }
        return inputs

    return pretrain_collate_fn


class BySequenceLengthDataset(torch.utils.data.IterableDataset):
    """
    experimental
    """

    def __init__(
        self, generator, batch_size, accelerator=None, bucket_size=16, max_length=1024
    ):
        super().__init__()
        self.generator = generator
        self.batch_size = batch_size
        self.bucket_size = bucket_size
        self.bucket_num = math.ceil(max_length / bucket_size)
        self.buckets = [[] for _ in range(self.bucket_num)]
        self.bucket_idx = None
        self.accelerator = accelerator
        if self.accelerator is not None:
            self.buckets_ele_num = torch.tensor(
                [0] * self.bucket_num, dtype=torch.int64, device=accelerator.device
            )
            self.buckets_indexes = torch.arange(
                self.bucket_num, device=accelerator.device
            )
        self.finished = False
        self.has_no_same_bucket = False
        self.rest = None

    def __iter__(self):
        if self.batch_size <= 1:
            return self.generator

        def bucket_iter():
            while True:
                if self.bucket_idx is not None:
                    sample = self.buckets[self.bucket_idx].pop()
                    if len(self.buckets[self.bucket_idx]) == 0:
                        self.bucket_idx = None
                    yield sample
                try:
                    sample = next(self.generator)
                except StopIteration:
                    break
                sample_len = len(sample) - 1
                bucket_idx = sample_len // self.bucket_size
                if len(self.buckets[bucket_idx]) == self.batch_size - 1:
                    self.bucket_idx = bucket_idx
                    yield sample
                else:
                    self.buckets[bucket_idx].append(sample)

        def parallel_bucket_iter():
            while True:
                if self.bucket_idx is not None:
                    sample = self.buckets[self.bucket_idx].pop()
                    self.buckets_ele_num[self.bucket_idx] -= 1
                    buckets_ele_num = self.accelerator.gather(self.buckets_ele_num)
                    buckets_ele_num = buckets_ele_num.reshape(
                        self.accelerator.num_processes, self.bucket_num
                    )
                    min_buckets_ele_num = buckets_ele_num.min(dim=0)[0]
                    if min_buckets_ele_num[self.bucket_idx] <= 0:
                        self.bucket_idx = None
                    yield sample
                else:
                    if self.finished:
                        if self.has_no_same_bucket:
                            if self.rest is None:
                                self.rest = []
                                for bucket in self.buckets:
                                    for i in bucket:
                                        self.rest.append(i)
                            elif len(self.rest) > 0:
                                yield self.rest.pop()
                            else:
                                # Data exhausted: end the generator. (Raising StopIteration
                                # inside a generator becomes a RuntimeError under PEP 479.)
                                return
                        else:
                            buckets_ele_num = self.accelerator.gather(
                                self.buckets_ele_num
                            )
                            buckets_ele_num = buckets_ele_num.view(
                                self.accelerator.num_processes, self.bucket_num
                            )
                            min_buckets_ele_num = buckets_ele_num.min(dim=0)[0]
                            valid_bucket_idx = self.buckets_indexes[
                                min_buckets_ele_num >= self.batch_size
                            ]
                            if len(valid_bucket_idx) > 0:
                                self.bucket_idx = valid_bucket_idx[0].cpu().item()
                            else:
                                self.has_no_same_bucket = True
                    else:
                        try:
                            sample = next(self.generator)
                        except StopIteration:
                            self.finished = True
                            continue
                        sample_len = len(sample) - 1
                        bucket_idx = sample_len // self.bucket_size
                        self.buckets[bucket_idx].append(sample)
                        self.buckets_ele_num[bucket_idx] += 1
                        buckets_ele_num = self.accelerator.gather(
                            self.buckets_ele_num
                        ).cpu()
                        buckets_ele_num = buckets_ele_num.view(
                            self.accelerator.num_processes, self.bucket_num
                        )
                        min_buckets_ele_num = buckets_ele_num.min(dim=0)[0]
                        valid_bucket_idx = self.buckets_indexes[
                            min_buckets_ele_num >= self.batch_size
                        ]
                        if len(valid_bucket_idx) > 0:
                            self.bucket_idx = valid_bucket_idx[0].cpu().item()

        if self.accelerator:
            return parallel_bucket_iter()
        else:
            return bucket_iter()


if __name__ == "__main__":
    import sentencepiece as spm
    from datasets import IterableDataset
    from torch.utils.data import DataLoader

    from dataset.pretrain_dataset import preprocess_wudao_gen, preprocess_the_pile_gen

    from dataset.tokenizer import Tokenizer
    from dataset.data_iter import create_shard_kwargs, create_data_iter

    sp_model = spm.SentencePieceProcessor(
        model_file="configs/10w_vocab_wudao5_pile10.model"
    )
    tokenizer = Tokenizer(sp_model)
    patterns = ["data/pretrain_data/part-*.jsonl.zst"]
    paths = create_shard_kwargs(patterns)
    transform_dict = {
        "wudao": preprocess_wudao_gen(tokenizer),
        "pile": preprocess_the_pile_gen(tokenizer),
    }
    data_set = IterableDataset.from_generator(
        create_data_iter, gen_kwargs={"paths": paths, "transform_dict": transform_dict}
    )
    train_loader = DataLoader(
        data_set,
        batch_size=8,
        num_workers=4,
        collate_fn=pretrain_collate_fn_gen(tokenizer),
        drop_last=True,
    )
    for batch in train_loader:
        for k, v in batch.items():
            print(k, v.shape)
        break
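The new padding="longest" option only pays off when the samples in a batch have similar lengths, which is what BySequenceLengthDataset (marked experimental above) provides: it emits runs of batch_size samples drawn from the same length bucket. A hedged sketch, written as a continuation of the __main__ block above and reusing its tokenizer and data_set; the two batch sizes must match so that each DataLoader batch corresponds to one bucket run.

    # Hedged sketch (not part of the commit): bucket samples so each batch only pads
    # to the longest sequence in its own bucket.
    bucketed = BySequenceLengthDataset(
        iter(data_set), batch_size=8, bucket_size=16, max_length=1024
    )
    bucket_loader = DataLoader(
        bucketed,
        batch_size=8,   # must equal the batch_size given to BySequenceLengthDataset
        num_workers=0,  # a single worker keeps the shared bucket state consistent
        collate_fn=pretrain_collate_fn_gen(tokenizer, padding="longest"),
        drop_last=True,
    )
    for batch in bucket_loader:
        print(batch["input_ids"].shape)  # second dim varies with the bucket's longest sample
        break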
							
								
								
									

dataset/instruction_dataset.py (new file, 75 lines)
@@ -0,0 +1,75 @@
"""
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-30 21:02:00
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-30 21:02:06
FilePath: /Open-Llama/dataset/instruction_dataset.py
Description:

Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""


def preprocess_self_instruction_gen(tokenizer, segment_max_length=1024):
    def preprocess_self_instruction(line):
        """
        The format of the data is roughly as follows.
        {'prompt': 'Explain the origin of life on earth. Output:', 'completion': 'Life on Earth is believed to have'}
        Format the pair into the user:/system: template and tokenize it;
        segment_max_length is currently unused.
        """
        prompt = line["prompt"]
        if prompt.endswith("Output:"):
            prompt = prompt[:-7]
        total = "user:{}<s>system:{}".format(prompt.strip(), line["completion"].strip())
        out = tokenizer(total)
        input_ids = out["input_ids"]
        return [input_ids]

    return preprocess_self_instruction


def preprocess_belle_gen(tokenizer, segment_max_length=1024):
    def preprocess_belle(line):
        """
        The format of the data is roughly as follows.
        {'input': 'some instruction text', 'target': 'the expected response'}
        Format the pair into the user:/system: template and tokenize it;
        segment_max_length is currently unused.
        """
        prompt = line["input"].replace("\\n", "")
        prompt = prompt.strip()

        completion = line["target"].replace("\\n", "")
        completion = completion.strip()
        total = "user:{}<s>system:{}".format(prompt, completion)
        out = tokenizer(total)
        input_ids = out["input_ids"]
        return [input_ids]

    return preprocess_belle


if __name__ == "__main__":
    import sentencepiece as spm
    from datasets import IterableDataset

    from dataset.tokenizer import Tokenizer
    from dataset.data_iter import create_shard_kwargs, create_data_iter

    sp_model = spm.SentencePieceProcessor(
        model_file="configs/10w_vocab_wudao5_pile10.model"
    )
    tokenizer = Tokenizer(sp_model)
    patterns = ["data/instruction_data/part-belle_1M*.jsonl.zst"]
    paths = create_shard_kwargs(patterns)
    transform_dict = {
        "belle_1M": preprocess_belle_gen(tokenizer),
        "belle_0.5M": preprocess_belle_gen(tokenizer),
        "self_instruct": preprocess_self_instruction_gen(tokenizer),
    }
    data_set = IterableDataset.from_generator(
        create_data_iter, gen_kwargs={"paths": paths, "transform_dict": transform_dict}
    )
    for i, sample in enumerate(data_set):
        print(sample, sp_model.Decode(sample))
        if i == 20:
            break
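For reference, the string that preprocess_belle hands to the tokenizer looks like the following; the record values are made up for illustration, and both preprocessors share the same user:/system: template.

# Illustration only -- a made-up record, showing the template string built above.
line = {"input": "Name three fruits.", "target": "Apple, banana, orange."}
prompt = line["input"].replace("\\n", "").strip()
completion = line["target"].replace("\\n", "").strip()
print("user:{}<s>system:{}".format(prompt, completion))
# user:Name three fruits.<s>system:Apple, banana, orange.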

dataset/pretrain_dataset.py (modified)
@@ -9,7 +9,6 @@ Description:
 Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
 """
 import math
-import torch
 
 
 def preprocess_wudao_gen(tokenizer, segment_max_length=1024):
@@ -48,59 +47,9 @@ def preprocess_the_pile_gen(tokenizer, segment_max_length=1024):
     return preprocess_the_pile
 
 
-def pretrain_collate_fn_gen(tokenizer, segment_max_length=1024):
-    """
-    Organize data into tensors by padding based on the preset maximum length.
-    """
-    pad_id = tokenizer.pad_id
-
-    def pretrain_collate_fn(batch):
-        input_ids = []
-        for i in batch:
-            input_len = len(i)
-            input_ids.append(i + [pad_id] * (segment_max_length - input_len))
-        inputs = {
-            "input_ids": torch.tensor(input_ids, dtype=torch.int64),
-        }
-        return inputs
-
-    return pretrain_collate_fn
-
-
-class BucketBySequenceLengthDataset(torch.utils.data.IterableDataset):
-    def __init__(self, generator, batch_size, bucket_size=32, max_length=1024):
-        super().__init__()
-        self.generator = generator
-        self.batch_size = batch_size
-        self.bucket_size = bucket_size
-        self.bucket_num = math.ceil(max_length / bucket_size)
-        self.buckets = [[] for _ in range(self.bucket_num)]
-        self.bucket_idx = None
-
-    def __iter__(self):
-        if self.batch_size <= 1:
-            return self.generator
-        def bucket_iter():
-            if self.bucket_idx is not None:
-                sample = self.buckets[self.bucket_idx].pop()
-                if len(self.buckets[self.bucket_idx]) == 0:
-                    self.bucket_idx = None
-                yield sample
-            sample = next(self.generator) - 1
-            sample_len = len(sample)
-            bucket_idx = sample_len // self.bucket_size
-            if len(self.buckets[bucket_idx]) == self.batch_size - 1:
-                self.bucket_idx = bucket_idx
-                yield sample
-            else:
-                self.buckets[bucket_idx].append(sample)
-        return bucket_iter()
-
-
 if __name__ == "__main__":
     import sentencepiece as spm
     from datasets import IterableDataset
-    from torch.utils.data import DataLoader
 
     from dataset.tokenizer import Tokenizer
     from dataset.data_iter import create_shard_kwargs, create_data_iter
@@ -118,14 +67,6 @@ if __name__ == "__main__":
     data_set = IterableDataset.from_generator(
         create_data_iter, gen_kwargs={"paths": paths, "transform_dict": transform_dict}
     )
-    train_loader = DataLoader(
-        data_set,
-        batch_size=8,
-        num_workers=4,
-        collate_fn=pretrain_collate_fn_gen(tokenizer),
-        drop_last=True,
-    )
-    for batch in train_loader:
-        for k, v in batch.items():
-            print(k, v.shape)
+    for sample in data_set:
+        print(sample)
         break
							
								
								
									

inctruction_tuning.py (new file, 180 lines)
@@ -0,0 +1,180 @@
"""
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-30 21:35:01
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-30 21:40:03
FilePath: /Open-Llama/inctruction_tuning.py
Description:

Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
import os
import time
import wandb
import torch
import random
import sentencepiece as spm
from torchinfo import summary
from accelerate import Accelerator
from datasets import IterableDataset
from torch.utils.data import DataLoader
from deepspeed.ops.adam import FusedAdam
from transformers import LlamaForCausalLM, LlamaConfig, get_cosine_schedule_with_warmup

from dataset.validation import val_set
from dataset.tokenizer import Tokenizer
from dataset.data_iter import create_shard_kwargs, create_data_iter
from dataset.data_loader import pretrain_collate_fn_gen
from dataset.instruction_dataset import (
    preprocess_belle_gen,
    preprocess_self_instruction_gen,
)
from configs.instruction_tuning_config import *

accelerator = Accelerator()

if accelerator.is_main_process:
    wandb.init(project="LLAMA Instruction")

log_interval *= accelerator.gradient_accumulation_steps
eval_interval *= accelerator.gradient_accumulation_steps
save_interval *= accelerator.gradient_accumulation_steps

sp_model = spm.SentencePieceProcessor(model_file=tokenizer_model_path)
tokenizer = Tokenizer(sp_model)

paths = create_shard_kwargs(patterns, repeat=3)
random.shuffle(paths)
transform_dict = {
    "belle_1M": preprocess_belle_gen(tokenizer, max_length),
    "belle_0.5M": preprocess_belle_gen(tokenizer, max_length),
    "self_instruct": preprocess_self_instruction_gen(tokenizer, max_length),
}
data_set = IterableDataset.from_generator(
    create_data_iter,
    gen_kwargs={
        "paths": paths,
        "transform_dict": transform_dict,
        "process_index": accelerator.process_index,
        "num_processes": accelerator.num_processes,
    },
)
train_loader = DataLoader(
    data_set,
    batch_size=train_batch_size,
    # If num_workers is greater than 1, duplicate data may occur.
    num_workers=0,
    collate_fn=pretrain_collate_fn_gen(tokenizer, max_length),
    drop_last=True,
)
# a smaller initializer_range makes training more stable
# add stable embedding to the token embedding
raw_model = LlamaForCausalLM(
    LlamaConfig(
        vocab_size=tokenizer.vocab_size,
        initializer_range=initializer_range,
        pad_token_id=tokenizer.pad_id,
        rms_norm_eps=1e-5,
        hidden_dropout_prob=0.1,
        attention_dropout_prob=0.1,
        use_stable_embedding=True,
        shared_input_output_embedding=True,
    )
)
ckpt = torch.load(ckpt_path, map_location="cpu")
raw_model.load_state_dict(ckpt)
raw_model.eval()
with torch.no_grad():
    summary(raw_model.cuda(), input_data=torch.ones(1, 64, dtype=torch.int64).cuda())
no_decay = ["bias", "LayerNorm.weight", "layernorm.weight"]
optimizer_grouped_parameters = [
    {
        "params": [
            p
            for n, p in raw_model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        "weight_decay": weight_decay,
    },
    {
        "params": [
            p
            for n, p in raw_model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        "weight_decay": 0.0,
    },
]
optim = FusedAdam(optimizer_grouped_parameters, lr=lr, betas=(0.9, 0.95))
optim.zero_grad()
factor = accelerator.num_processes / accelerator.gradient_accumulation_steps
scheduler = get_cosine_schedule_with_warmup(
    optim,
    num_warmup_steps=num_warmup_steps * factor,
    num_training_steps=num_training_steps * factor,
)

_, model, optim, scheduler = accelerator.prepare(
    train_loader, raw_model, optim, scheduler
)
print("start training...")
train_loader_iter = iter(train_loader)
global_step = 0
start_time = time.time()
for data_step in range(num_training_steps):
    model.train()
    with accelerator.accumulate(model):
        batch = next(train_loader_iter)
        for k, v in batch.items():
            batch[k] = v.to(accelerator.device, non_blocking=True)
        out = model(**batch, labels=batch["input_ids"])
        total_loss = out.loss
        losses = {"total_loss": total_loss}
        accelerator.backward(total_loss)
        optim.step()
        scheduler.step()
        optim.zero_grad()
        if accelerator.sync_gradients:
            global_step += 1
    if data_step % log_interval == 0 and data_step > 0 and accelerator.is_main_process:
        cost_time = time.time() - start_time
        start_time = time.time()
        tokens = train_batch_size * log_interval * max_length
        wandb.log({"Training/Token per second per gpu": tokens / cost_time})
        for k, v in losses.items():
            wandb.log({"Losses/{}".format(k): v})
        current_lr = optim.param_groups[0]["lr"]
        wandb.log({"Training/LR": current_lr})
        if optim.scaler is not None:
            wandb.log({"Training/Loss Scale": optim.scaler.get_scale()})
        wandb.log({"Training/Data Step": data_step})
        wandb.log({"Training/Global Step": global_step})
        accelerator.print(
            "Global Step: {}, Data Step: {}, Loss: {}, Token per second per gpu: {}".format(
                global_step, data_step, losses["total_loss"], tokens / cost_time
            )
        )
    if data_step % eval_interval == 0 and accelerator.is_main_process:
        text_table = wandb.Table(columns=["question", "pred"])
        model.eval()
        with torch.no_grad():
            for data in val_set:
                raw_inputs = data
                inputs_len = len(raw_inputs)
                inputs = tokenizer(
                    raw_inputs, return_tensors=True, add_special_tokens=False
                )
                for k, v in inputs.items():
                    inputs[k] = v.to(accelerator.device)
                pred = model.generate(
                    **inputs, max_new_tokens=256, do_sample=True, repetition_penalty=2.0
                )
                pred = tokenizer.decode(pred.cpu())[0]
                pred = pred[inputs_len:]
                text_table.add_data(raw_inputs, pred)
        wandb.log({"Predictions on {}".format(global_step): text_table})
    if data_step % save_interval == 0 and data_step > 0 and accelerator.is_main_process:
        if not os.path.isdir(work_dir):
            os.mkdir(work_dir)
        torch.save(raw_model.state_dict(), "{}/{}.pt".format(work_dir, global_step))
wandb.finish()
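A hedged inference sketch, not part of the commit: it assumes the same patched transformers fork that accepts use_stable_embedding and shared_input_output_embedding, reuses the custom Tokenizer, and loads a checkpoint written by the save branch above (the file name 1000.pt is illustrative). The script's own eval loop feeds raw validation prompts; here the user:/system: training template is applied instead.

# Hedged inference sketch -- not part of the commit; checkpoint name and prompt template
# usage are assumptions, and the config kwargs mirror the ones used in the script above.
import torch
import sentencepiece as spm
from transformers import LlamaForCausalLM, LlamaConfig
from dataset.tokenizer import Tokenizer

sp_model = spm.SentencePieceProcessor(model_file="configs/10w_vocab_wudao5_pile10.model")
tokenizer = Tokenizer(sp_model)
model = LlamaForCausalLM(
    LlamaConfig(
        vocab_size=tokenizer.vocab_size,
        pad_token_id=tokenizer.pad_id,
        rms_norm_eps=1e-5,
        use_stable_embedding=True,
        shared_input_output_embedding=True,
    )
)
model.load_state_dict(torch.load("data/saved_ckpt/1000.pt", map_location="cpu"))
model.eval()

inputs = tokenizer(
    "user:Explain the origin of life on earth.<s>system:",
    return_tensors=True,
    add_special_tokens=False,
)
with torch.no_grad():
    pred = model.generate(
        **inputs, max_new_tokens=256, do_sample=True, repetition_penalty=2.0
    )
print(tokenizer.decode(pred.cpu())[0])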

@@ -24,12 +24,12 @@ from transformers import LlamaForCausalLM, LlamaConfig, get_cosine_schedule_with
 from dataset.validation import val_set
 from dataset.tokenizer import Tokenizer
 from dataset.data_iter import create_shard_kwargs, create_data_iter
+from dataset.data_loader import pretrain_collate_fn_gen
 from dataset.pretrain_dataset import (
     preprocess_the_pile_gen,
     preprocess_wudao_gen,
-    pretrain_collate_fn_gen,
 )
-from configs.train_config import *
+from configs.pretrain_config import *
 
 accelerator = Accelerator()
 
@@ -62,7 +62,7 @@ train_loader = DataLoader(
     data_set,
     batch_size=train_batch_size,
     # If num_workers is greater than 1, duplicate data may occur.
-    num_workers=1,
+    num_workers=0,
     collate_fn=pretrain_collate_fn_gen(tokenizer, max_length),
     drop_last=True,
 )
@@ -124,7 +124,7 @@ for data_step in range(num_training_steps):
         batch = next(train_loader_iter)
         for k, v in batch.items():
             batch[k] = v.to(accelerator.device, non_blocking=True)
-        out = model(**batch, labels=batch['input_ids'])
+        out = model(**batch, labels=batch["input_ids"])
         total_loss = out.loss
         losses = {"total_loss": total_loss}
         accelerator.backward(total_loss)
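The num_workers change echoes the comment kept in the hunk above: a generator-backed IterableDataset is copied into every DataLoader worker, so more than one worker replays the same stream. A small, self-contained demonstration with a plain torch IterableDataset (an assumption for illustration, not the Hugging Face dataset used in the scripts):

# Each worker process gets its own copy of the iterable dataset, so extra workers duplicate data.
from torch.utils.data import DataLoader, IterableDataset

class Stream(IterableDataset):
    # Four toy samples; every worker re-runs this same iterator.
    def __iter__(self):
        return iter(range(4))

if __name__ == "__main__":
    print([int(x) for x in DataLoader(Stream(), batch_size=None, num_workers=0)])  # [0, 1, 2, 3]
    print([int(x) for x in DataLoader(Stream(), batch_size=None, num_workers=2)])  # duplicated, typically [0, 0, 1, 1, 2, 2, 3, 3]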