@@ -115,19 +115,21 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             proposer_worker = NGramWorker(**draft_worker_kwargs)
             proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
                                                   ngram_prompt_lookup_max)
-        elif draft_worker_kwargs[
-                "model_config"].hf_config.model_type == "mlp_speculator":
-            proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
-            disable_bonus_tokens = False
         else:
             draft_parallel_config: ParallelConfig = draft_worker_kwargs[
                 'parallel_config']
             draft_tp = draft_parallel_config.tensor_parallel_size
             target_tp = scorer_worker.parallel_config.tensor_parallel_size
 
-            if draft_tp == 1:
-                draft_worker_kwargs["model_runner_cls"] = TP1DraftModelRunner
-            proposer_worker = MultiStepWorker(**draft_worker_kwargs)
+            if draft_worker_kwargs[
+                    "model_config"].hf_config.model_type == "mlp_speculator":
+                disable_bonus_tokens = False
+                proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
+            else:
+                if draft_tp == 1:
+                    draft_worker_kwargs[
+                        "model_runner_cls"] = TP1DraftModelRunner
+                proposer_worker = MultiStepWorker(**draft_worker_kwargs)
 
             proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
                 proposer_worker, draft_tp, target_tp)
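For context, the net effect of this hunk is to move the mlp_speculator branch under the outer else, so MLPSpeculatorWorker now flows through SmallerTpProposerWorker.maybe_wrap_worker() the same way MultiStepWorker does. Below is a minimal, self-contained sketch of the resulting control flow; the stub classes and the select_proposer helper are hypothetical stand-ins for the real vLLM workers, and the draft_tp < target_tp wrapping condition is an assumption, not something stated in this diff.

# Hypothetical stubs; only the control flow mirrors the hunk above.
class _StubWorker:
    def __init__(self, **kwargs):
        self.kwargs = kwargs

class NGramWorker(_StubWorker):
    def set_ngram_window_size(self, lookup_min, lookup_max):
        self.window = (lookup_min, lookup_max)

class MLPSpeculatorWorker(_StubWorker):
    pass

class MultiStepWorker(_StubWorker):
    pass

class SmallerTpProposerWorker(_StubWorker):
    @classmethod
    def maybe_wrap_worker(cls, worker, draft_tp, target_tp):
        # Assumption: wrapping applies only when the draft model runs at a
        # smaller tensor-parallel degree than the target model.
        if draft_tp < target_tp:
            return cls(wrapped=worker)
        return worker

def select_proposer(model_type, ngram_prompt_lookup_max,
                    draft_tp, target_tp, **draft_worker_kwargs):
    """Mirror of the restructured branch: the mlp_speculator case now sits
    inside the else, so it too reaches maybe_wrap_worker()."""
    if ngram_prompt_lookup_max > 0:
        proposer = NGramWorker(**draft_worker_kwargs)
        proposer.set_ngram_window_size(1, ngram_prompt_lookup_max)
        return proposer

    if model_type == "mlp_speculator":
        proposer = MLPSpeculatorWorker(**draft_worker_kwargs)
    else:
        proposer = MultiStepWorker(**draft_worker_kwargs)
    # Both draft-worker flavors can now be wrapped for smaller draft TP.
    return SmallerTpProposerWorker.maybe_wrap_worker(proposer, draft_tp,
                                                     target_tp)

# Example: draft model at TP=1, target at TP=4 -> the MLP speculator gets
# wrapped, which the pre-change elif structure never allowed.
worker = select_proposer("mlp_speculator", 0, draft_tp=1, target_tp=4)
assert isinstance(worker, SmallerTpProposerWorker)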