Multiprocessing with PyTorch (multi-process model)

Background

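Bulk-embedding files into the vector store is painfully slow, because files can only be processed one at a time. Each format (pdf, ppt, doc, txt, etc.) is handled differently, and pdf is especially slow because it goes through OCR. So I wanted to optimize it.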

Problem analysis

There are quite a few pain points here.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import time

from framework import log
from vector import EmbeddingModel, get_vs_client, Knowledge
import torch.multiprocessing as mp

source_folder_path = 'xxxx'

"""
RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method

If multiprocessing.set_start_method() is called twice, it will raise RuntimeError("context has already been set"). This error can be suppressed by passing the force=True keyword argument to the function.
"""
mp.set_start_method('spawn', force=True)

embedding_model = 'm3e-base'
logger = log.get_file_logger(f'test', './log/test.log')
exist_files = []
embeddings = EmbeddingModel.load(embedding_model, logger)
vs_cli = get_vs_client(embeddings, 'doc_test', '127.0.0.1:32008', None, logger)
db = Knowledge(vs_cli, logger)


def handler_doc(file_path, i):
    docs = db.load_file(file_path, doc_id=f'{i + 1}', category_id=f'{i + 1}', category_path='')
    if docs is not None:
        docs = docs[:5]
        db.add_docs(docs)
        logger.info(f'{file_path}: {len(docs)} docs')


def load_file():
    all_files = []
    for root, dirs, files in os.walk(source_folder_path):
        for file in files:
            all_files.append([root, file])
    logger.info(f'{len(all_files)} doc')
    # docs = []
    processes = []
    args = []
    for i in range(len(all_files)):
        root, file = all_files[i]
        file_path = os.path.join(root, file)
        if os.path.exists(file_path) and file_path not in exist_files:
            logger.info(f'#{i + 1} {file_path}')
            # handler_doc(file_path, i)
            # args.append((file_path, i))

            p = mp.Process(target=handler_doc, name=file_path, args=(file_path, i),
                           daemon=True)
            processes.append(p)
            p.start()
    solutions = []
    # with mp.Pool(processes=2) as pool:
    #     futures_res = pool.imap(handler_doc, args)
    #     for i in futures_res:
    #         pass
    #     while args:
    #         s = args.pop(0)
    #         print(s)
    #         try:
    #             res = futures_res.next(timeout=20)
    #
    #             # If successful (no time out), append the result
    #             solutions.append((s, res))
    #         except mp.context.TimeoutError:
    #             print(s, "err")
    #             break
    # for s in solutions:
    #     print(s)

    TIMEOUT = 20
    start = time.time()
    while time.time() - start <= TIMEOUT:
        if not any(p.is_alive() for p in processes):
            # All the processes are done, break now.
            break

        time.sleep(.1)  # Just to avoid hogging the CPU
    else:
        # We only enter this if we didn't 'break' above.
        print("timed out, killing all processes")
        try:
            for p in processes:
                if p.is_alive():
                    print(p.name)
                p.terminate()
                p.join()
        except Exception as e:
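            # logger.error() needs a message; exc_info=True attaches the traceback.
            logger.error(f'failed to terminate processes: {e}', exc_info=True)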
            for p in processes:
                print(p.name)
                p.kill()


if __name__ == '__main__':
    """
    RuntimeError: 
        An attempt has been made to start a new process before the
        current process has finished its bootstrapping phase.

        This probably means that you are not using fork to start your
        child processes and you have forgotten to use the proper idiom
        in the main module:

            if __name__ == '__main__':
                freeze_support()
                ...

        The "freeze_support()" line can be omitted if the program
        is not going to be frozen to produce an executable.
    """
    # 
    load_file()
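The wait-then-kill loop at the end of the script can also be written with one shared deadline and `join(timeout)`. This is only a minimal sketch, not the code used above: `join_with_deadline` is a hypothetical helper name, and it relies only on `is_alive()`, `join(timeout)`, `terminate()` and `name`, which are part of the standard `multiprocessing.Process` API.

```python
import time


def join_with_deadline(processes, timeout):
    """Wait up to `timeout` seconds in total, then terminate stragglers.

    Returns the names of the processes that had to be terminated.
    """
    deadline = time.monotonic() + timeout
    for p in processes:
        # Each join only gets whatever is left of the shared time budget.
        remaining = max(0.0, deadline - time.monotonic())
        p.join(remaining)

    killed = []
    for p in processes:
        if p.is_alive():
            killed.append(p.name)
            p.terminate()
            p.join()  # reap the terminated process
    return killed
```

In the script above this would replace the `while ... else` block, for example `join_with_deadline(processes, TIMEOUT)`. Joining every process also avoids leaving finished children unreaped.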

To use CUDA with multiprocessing, the spawn start method is required.

The Python 3 documentation on Contexts and start methods describes three ways to start a process:

spawn

The parent process starts a fresh Python interpreter process. The child process will only inherit those resources necessary to run the process object’s run() method. In particular, unnecessary file descriptors and handles from the parent process will not be inherited. Starting a process using this method is rather slow compared to using fork or forkserver.

Available on POSIX and Windows platforms. The default on Windows and macOS.

fork

The parent process uses os.fork() to fork the Python interpreter. The child process, when it begins, is effectively identical to the parent process. All resources of the parent are inherited by the child process. Note that safely forking a multithreaded process is problematic.

Available on POSIX systems. Currently the default on POSIX except macOS.

Note

The default start method will change away from fork in Python 3.14. Code that requires fork should explicitly specify that via get_context() or set_start_method().

Changed in version 3.12: If Python is able to detect that your process has multiple threads, the os.fork() function that this start method calls internally will raise a DeprecationWarning. Use a different start method. See the os.fork() documentation for further explanation.

forkserver

When the program starts and selects the forkserver start method, a server process is spawned. From then on, whenever a new process is needed, the parent process connects to the server and requests that it fork a new process. The fork server process is single threaded unless system libraries or preloaded imports spawn threads as a side-effect so it is generally safe for it to use os.fork(). No unnecessary resources are inherited.

Available on POSIX platforms which support passing file descriptors over Unix pipes such as Linux.

spawn leads to some odd behavior

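  • With the code above, the model is actually loaded multiple times. Unlike fork, where the child directly inherits the parent's in-memory objects, spawn re-runs everything above the main guard from scratch in every child.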

  • The code that starts the processes must sit under if __name__ == '__main__': (protecting the entry point), as the Python documentation points out:

    Safe importing of main module

    Make sure that the main module can be safely imported by a new Python interpreter without causing unintended side effects (such as starting a new process).

    For example, using the spawn or forkserver start method running the following module would fail with a RuntimeError:

    from multiprocessing import Process
    
    def foo():
        print('hello')
    
    p = Process(target=foo)
    p.start()
    

    Instead one should protect the “entry point” of the program by using if __name__ == '__main__': as follows:

    from multiprocessing import Process, freeze_support, set_start_method
    
    def foo():
        print('hello')
    
    if __name__ == '__main__':
        freeze_support()
        set_start_method('spawn')
        p = Process(target=foo)
        p.start()
    

    (The freeze_support() line can be omitted if the program will be run normally instead of frozen.)

    This allows the newly spawned Python interpreter to safely import the module and then run the module’s foo() function.

    Similar restrictions apply if a pool or manager is created in the main module.

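    This is really annoying, because it makes it hard to call this code as a module.

    Why do we need to protect the entry point?

    When new processes are started from the main process with the spawn method, the entry point must be protected. The reason is that the child process automatically loads and imports our Python script, which is required for it to execute our custom code and functions. If the entry point is not guarded by the if-statement idiom that checks for the top-level environment, the script is simply executed again from the top instead of running only the intended child task. Protecting the entry point ensures the program is launched only once: the main process's tasks run only in the main process, never in the children.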

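Neither problem occurs with fork. See the PyTorch Multiprocessing best practices notes, although they do not explain this very clearly either. The first problem can still be worked around with serialization and a queue, but I did not know how to handle the second one.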
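For the first problem (the model being loaded again in every spawned child), one possible workaround is to load the model lazily inside the worker and feed it through queues. The sketch below is an illustration under assumptions, not the code from this post: `load_model`, `get_model` and `worker` are hypothetical names, the "model" is a dummy dict, and `queue.Queue` stands in for a multiprocessing queue so the example runs in a single process.

```python
import queue

_MODEL = None  # one cached model per process


def load_model():
    # Placeholder for an expensive call such as EmbeddingModel.load(...).
    return {"name": "dummy-model"}


def get_model():
    """Load the model on first use, inside the current process only."""
    global _MODEL
    if _MODEL is None:
        _MODEL = load_model()
    return _MODEL


def worker(input_queue, output_queue):
    """Consume (job_id, text) tuples until a None sentinel arrives."""
    model = get_model()
    while True:
        job = input_queue.get()
        if job is None:  # sentinel: no more work
            break
        job_id, text = job
        output_queue.put((job_id, f"{model['name']}:{len(text)}"))


# Single-process demonstration: queue.Queue stands in for ctx.Queue().
inq, outq = queue.Queue(), queue.Queue()
for job in [(0, "hello"), (1, "pdf text"), None]:
    inq.put(job)
worker(inq, outq)
results = [outq.get() for _ in range(2)]
print(results)
```

Because nothing expensive runs at import time, re-importing the module in a spawned child stays cheap, and each worker process pays the loading cost once. In a real setup the queues would presumably come from a spawn context and `worker` would be passed as the `target` of a process.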

A way out

sentence_transformers

from langchain.embeddings.huggingface import HuggingFaceEmbeddings

class EmbeddingModel:
    logger = None
    embeddings = None

    @classmethod
    def load(cls, embedding_model: str, logger=None, load_embedding=True, ):
....
            cls.embeddings = HuggingFaceEmbeddings(
                model_name=model_name,
                model_kwargs={'device': DEVICE},
                encode_kwargs={
                    'normalize_embeddings': True, 
                    'batch_size': 512,  # default is 32; a suitably larger batch_size speeds up processing
                    'show_progress_bar': True,
                },
            )
...
            return cls.embeddings
        return None

Looking at the source of langchain's HuggingFaceEmbeddings, there is also a multi_process parameter that enables multi-process GPU encoding:

   def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Compute doc embeddings using a HuggingFace transformer model.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        import sentence_transformers

        texts = list(map(lambda x: x.replace("\n", " "), texts))
        if self.multi_process:
            pool = self.client.start_multi_process_pool()
            # Strange: why isn't encode_kwargs passed here?
            embeddings = self.client.encode_multi_process(texts, pool)
            sentence_transformers.SentenceTransformer.stop_multi_process_pool(pool)
        else:
            embeddings = self.client.encode(texts, **self.encode_kwargs)

        return embeddings.tolist()

This only speeds up one stage of the file pipeline, the embedding encode step. Still, the approach is worth borrowing to deal with the slow loader.

start_multi_process_pool

import torch.multiprocessing as mp

def start_multi_process_pool(self, target_devices: List[str] = None):
       
        if target_devices is None:
            if torch.cuda.is_available():
                target_devices = ['cuda:{}'.format(i) for i in range(torch.cuda.device_count())]
            else:
                logger.info("CUDA is not available. Start 4 CPU worker")
                target_devices = ['cpu']*4

        logger.info("Start multi-process pool on devices: {}".format(', '.join(map(str, target_devices))))
        
        
        # Alternatively, you can use get_context() to obtain a context object. Context objects have the same API as
        # the multiprocessing module, and allow one to use multiple start methods in the same program.
        ctx = mp.get_context('spawn')
        input_queue = ctx.Queue()
        output_queue = ctx.Queue()
        processes = []

        for cuda_id in target_devices:
            p = ctx.Process(target=SentenceTransformer._encode_multi_process_worker, args=(cuda_id, self, input_queue, output_queue), daemon=True)
            p.start()
            processes.append(p)

        return {'input': input_queue, 'output': output_queue, 'processes': processes}
      

# Effectively one encode process per GPU; arguments and results travel through queues
def _encode_multi_process_worker(target_device: str, model, input_queue, results_queue):
        """
        Internal working process to encode sentences in multi-process setup
        """
        while True:
            try:
                id, batch_size, sentences = input_queue.get()
                embeddings = model.encode(sentences, device=target_device,  show_progress_bar=False, convert_to_numpy=True, batch_size=batch_size)
                results_queue.put([id, embeddings])
            except queue.Empty:
                break

As this shows, using get_context may solve the second problem above. I had not read the documentation carefully at the time.

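set_start_method() sets the global start method. It should be called at most once, in the program's main module, not in child processes, and it affects how every subsequently created process is started.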

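get_context() returns a context object for a specific start method. It can be called anywhere and is not constrained by the global start method. It also takes a string argument naming the start method to use.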
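The difference can be sketched with the standard library alone. This is a minimal illustration, assuming a hypothetical helper name `make_worker_context`; it uses the stdlib `multiprocessing` module rather than `torch.multiprocessing`, and it deliberately does not start any process.

```python
import multiprocessing as mp


def make_worker_context(method="spawn"):
    """Return a context bound to `method` without touching the global default."""
    return mp.get_context(method)


ctx = make_worker_context()
# ctx.Process and ctx.Queue mirror mp.Process and mp.Queue, but anything
# created through ctx uses the spawn start method.
print(ctx.get_start_method())  # prints "spawn"
```

A library function can therefore pick spawn for its own workers while leaving the caller's global start method untouched, which is what start_multi_process_pool does with `mp.get_context('spawn')`.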

According to this issue, it still has a problem: the freeze_support error is still raised.

Handling inputs and outputs: encode_multi_process

    def encode_multi_process(self, sentences: List[str], pool: Dict[str, object], batch_size: int = 32, chunk_size: int = None):
        """
        This method allows to run encode() on multiple GPUs. The sentences are chunked into smaller packages
        and sent to individual processes, which encode these on the different GPUs. This method is only suitable
        for encoding large sets of sentences
        """

        if chunk_size is None:
            chunk_size = min(math.ceil(len(sentences) / len(pool["processes"]) / 10), 5000)

        logger.debug(f"Chunk data into {math.ceil(len(sentences) / chunk_size)} packages of size {chunk_size}")

        input_queue = pool['input']
        last_chunk_id = 0
        chunk = []

        for sentence in sentences:
            chunk.append(sentence)
            if len(chunk) >= chunk_size:
                input_queue.put([last_chunk_id, batch_size, chunk])
                last_chunk_id += 1
                chunk = []

        if len(chunk) > 0:
            input_queue.put([last_chunk_id, batch_size, chunk])
            last_chunk_id += 1

        output_queue = pool['output']
        results_list = sorted([output_queue.get() for _ in range(last_chunk_id)], key=lambda x: x[0])
        embeddings = np.concatenate([result[1] for result in results_list])
        return embeddings
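The chunk-and-reorder bookkeeping in encode_multi_process can be looked at separately from the queues and GPUs. Below is a minimal sketch under assumptions: `default_chunk_size`, `make_chunks` and `merge_results` are hypothetical helper names that do not exist in sentence_transformers, plain lists stand in for embeddings, and the `max(1, ...)` guard is an addition so that empty input does not produce a chunk size of zero.

```python
import math


def default_chunk_size(n_sentences, n_workers, cap=5000):
    """Aim for about ten chunks per worker, capped at `cap` sentences."""
    return max(1, min(math.ceil(n_sentences / n_workers / 10), cap))


def make_chunks(sentences, chunk_size):
    """Split into (chunk_id, chunk) pairs; the id records the original order."""
    return [
        (chunk_id, sentences[start:start + chunk_size])
        for chunk_id, start in enumerate(range(0, len(sentences), chunk_size))
    ]


def merge_results(results):
    """Workers finish in any order, so sort by chunk_id before flattening."""
    ordered = sorted(results, key=lambda pair: pair[0])
    return [item for _, chunk in ordered for item in chunk]


chunks = make_chunks(list("abcdefg"), 3)
# Simulate out-of-order completion by reversing the chunk list.
restored = merge_results(list(reversed(chunks)))
print(restored)
```

The chunk id is what lets the parent rebuild the original order: results come off the output queue in completion order, and sorting by id before concatenating keeps each embedding aligned with its input sentence.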