#pragma once
#include "../epilogue_params.cuh"
#include "ep_policy.cuh"
namespace epi {
// Params mapping
template<typename T> struct EpParamsT;
template<> struct EpParamsT<float>{ using type = EpParamsF32; };
template<> struct EpParamsT<half> { using type = EpParamsF16; };
template<typename Policy>
struct EpApply {
using T = typename Policy::ElemT;
using P = typename EpParamsT<T>::type;
__device__ static inline void run(const P& p,
int m, int n, int ix, int iy,
const PhiloxState& st,
unsigned long long elem_idx){
T v = p.x[ix];
// bias -> act -> dropout
v = Policy::BiasF::template apply<T>(v, p.bias, n);
v = Policy::ActF ::template apply<T>(v);
if constexpr (Policy::UseDrop) {
v = Policy::DropF::template apply<T>(v, st, elem_idx, p.p_drop, p.keep_scale);
}
// blend & optional residual
Policy::BlendF::template store<T,float>(p.alpha, p.beta, v, p.y, iy);
if constexpr (Policy::UseResid) {
p.y[iy] = Math<T>::add(p.y[iy], p.resid[iy]);
}
}
};
} // namespace epi
- EpParamsT<T> 로 T = {float, half} 에 따라 파라미터 구조 매핑
- EpApply<Policy>::run(...) 는 한 원소 단위 에필로그 처리 파이프라인
- v = p.x[ix]
- v = BiasF.apply<T>(v, p.bias, n)
- v = ActF.apply<T>(v)
- v = DropF.apply<T>(v, st, elem_idx, p.p_drop, p.keep_scale)
- BlendF.store<T, float>(p.alpha, p.beta, v, p.t, iy)
- p.y[iy] += p.resid[iy]
- 인덱싱 : ix 는 x 에서 읽기, iy 는 y 에 쓰기, n 은 열 인덱스, m 은 현재 코드 경로에서 미사용