Files
ANSCORE/core/ANSLibsLoader/EPLoader.cpp

493 lines
19 KiB
C++

// EPLoader.cpp
// Dynamic ONNX Runtime EP loader.
// Loads onnxruntime.dll at runtime — no onnxruntime.lib linkage required.
//
// Moved from ONNXEngine/ to ANSLibsLoader/.
// Compile this file in EXACTLY ONE project (ANSLibsLoader.dll).
// That project MUST define ANSLIBSLOADER_EXPORTS in its Preprocessor Definitions.
//
// Windows: LoadLibraryExW + AddDllDirectory + GetProcAddress
// Linux: dlopen (RTLD_NOW | RTLD_GLOBAL) + dlsym
#include "EPLoader.h"
#include "DynLibUtils.h"
// ORT C++ headers — included ONLY in this translation unit.
// ORT_API_MANUAL_INIT must be defined project-wide (Preprocessor Definitions)
// in every project that includes ORT headers, so all translation units see
// Ort::Global<void>::api_ as an extern rather than a default-constructed object.
#include <onnxruntime_cxx_api.h>
#include <algorithm>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>
#ifdef _WIN32
# ifndef WIN32_LEAN_AND_MEAN
# define WIN32_LEAN_AND_MEAN
# endif
# include <windows.h>
#else
# include <dlfcn.h>
#endif
namespace ANSCENTER {
// ── Static member definitions ────────────────────────────────────────────
// Storage for EPLoader's static data members. These definitions are emitted
// only when ANSLIBSLOADER_EXPORTS is defined, so they exist in exactly one
// binary (see the note at the top of this file).
#ifdef ANSLIBSLOADER_EXPORTS
// Locked by Initialize() and Shutdown() while they modify the loader state.
std::mutex EPLoader::s_mutex;
// Set to true at the end of a successful Initialize(); reset by Shutdown().
bool EPLoader::s_initialized = false;
// Description of the active EP (type, library dir, fromSubdir); filled in by
// Initialize() and reset to a default-constructed EPInfo by Shutdown().
EPInfo EPLoader::s_info;
# ifdef _WIN32
// Full path of the aliased ORT DLL copy that LoadOrtDll() loaded from the
// staging directory; cleared by Shutdown().
std::string EPLoader::s_temp_ort_path;
// Staging directory created by LoadOrtDll(); deleted and cleared by Shutdown().
std::string EPLoader::s_temp_dir;
# endif
#endif
// ── File-scope state ─────────────────────────────────────────────────────
// Handle of the loaded ONNX Runtime library: the value returned by
// LoadLibraryExW (Windows) or dlopen (elsewhere) in LoadOrtDll(). It is
// released and reset to nullptr by EPLoader::Shutdown(). nullptr means
// "not loaded".
#ifdef _WIN32
static HMODULE s_ort_module = nullptr;
#else
static void* s_ort_module = nullptr;
#endif
// OrtApi table obtained from the loaded library (assigned in LoadOrtDll() and
// GetOrtApi()); reset to nullptr by EPLoader::Shutdown().
static const OrtApi* s_ort_api = nullptr;
// ════════════════════════════════════════════════════════════════════════
// File-scope helpers (anonymous namespace — not exported)
// ════════════════════════════════════════════════════════════════════════
namespace {
// ── GetOrtApi ────────────────────────────────────────────────────────
// Returns the OrtApi table of the already-loaded ORT library.
//
// The cached pointer (s_ort_api) is returned when present. Otherwise the
// table is resolved from the loaded module through its exported
// OrtGetApiBase entry point, cached, and returned.
//
// The requested API version is the smaller of ORT_API_VERSION (the headers
// this file was compiled against) and the minor component of the library's
// "major.minor" version string. If that string cannot be parsed,
// ORT_API_VERSION is requested as-is.
//
// Throws std::runtime_error when the library is not loaded, when the
// OrtGetApiBase symbol is missing or returns null, or when GetApi() yields
// no table for the negotiated version.
const OrtApi* GetOrtApi()
{
    if (s_ort_module == nullptr) {
        throw std::runtime_error(
            "[EPLoader] ORT DLL not loaded — call EPLoader::Current() first.");
    }

    // Fast path: the table has already been resolved.
    if (s_ort_api != nullptr) {
        return s_ort_api;
    }

    // Same function-pointer type on every platform; only the symbol lookup
    // call differs.
    using GetApiBaseFn = const OrtApiBase* (ORT_API_CALL*)();
#ifdef _WIN32
    const auto get_api_base = reinterpret_cast<GetApiBaseFn>(
        GetProcAddress(s_ort_module, "OrtGetApiBase"));
#else
    const auto get_api_base = reinterpret_cast<GetApiBaseFn>(
        dlsym(s_ort_module, "OrtGetApiBase"));
#endif
    if (get_api_base == nullptr) {
        throw std::runtime_error(
            "[EPLoader] OrtGetApiBase symbol not found in loaded ORT DLL.");
    }

    const OrtApiBase* api_base = get_api_base();
    if (api_base == nullptr) {
        throw std::runtime_error(
            "[EPLoader] OrtGetApiBase() returned null.");
    }

    // Derive the highest API level of the library from its version string.
    int dll_api_ceiling = ORT_API_VERSION;
    const char* version_text = api_base->GetVersionString();
    int ver_major = 0;
    int ver_minor = 0;
    if (version_text != nullptr &&
        sscanf(version_text, "%d.%d", &ver_major, &ver_minor) == 2) {
        dll_api_ceiling = ver_minor;
    }

    const int negotiated_api = std::min(ORT_API_VERSION, dll_api_ceiling);
    const OrtApi* resolved = api_base->GetApi(negotiated_api);
    if (resolved == nullptr) {
        throw std::runtime_error(
            "[EPLoader] No compatible ORT API version found in loaded DLL.");
    }

    s_ort_api = resolved;
    return resolved;
}
} // anonymous namespace
// ════════════════════════════════════════════════════════════════════════
// EPLoader public / private methods
// ════════════════════════════════════════════════════════════════════════
// Returns the platform-specific file name of the ONNX Runtime shared library:
// "onnxruntime.dll" on Windows, "libonnxruntime.dylib" on Apple platforms and
// "libonnxruntime.so" everywhere else. The returned pointer refers to a
// string literal and must not be freed.
const char* OrtDllName()
{
#if defined(_WIN32)
    static constexpr const char* kLibraryFile = "onnxruntime.dll";
#elif defined(__APPLE__)
    static constexpr const char* kLibraryFile = "libonnxruntime.dylib";
#else
    static constexpr const char* kLibraryFile = "libonnxruntime.so";
#endif
    return kLibraryFile;
}
// Maps an engine type to the name of its EP subdirectory (used below "ep/"
// by ResolveEPDir): NVIDIA_GPU → "cuda", AMD_GPU → "directml",
// OPENVINO_GPU → "openvino". CPU and every other value map to "cpu".
// The returned pointer refers to a string literal.
const char* EPLoader::SubdirName(EngineType type)
{
    if (type == EngineType::NVIDIA_GPU)   return "cuda";
    if (type == EngineType::AMD_GPU)      return "directml";
    if (type == EngineType::OPENVINO_GPU) return "openvino";

    // EngineType::CPU and anything unrecognised share the CPU directory.
    return "cpu";
}
// Returns the enumerator name of `type` as a string literal, for log output.
// Values that are not one of the enumerators handled here yield "UNKNOWN".
const char* EPLoader::EngineTypeName(EngineType type)
{
    if (type == EngineType::NVIDIA_GPU)   return "NVIDIA_GPU";
    if (type == EngineType::AMD_GPU)      return "AMD_GPU";
    if (type == EngineType::OPENVINO_GPU) return "OPENVINO_GPU";
    if (type == EngineType::CPU)          return "CPU";
    if (type == EngineType::AUTO_DETECT)  return "AUTO_DETECT";
    return "UNKNOWN";
}
/*static*/
// Chooses the directory the ORT library should be loaded from.
//
// Probe order:
//   1. <shared_dir>/ep/<SubdirName(type)>/<OrtDllName()>  → returns the subdir
//   2. <shared_dir>/<OrtDllName()>                        → returns shared_dir
//      (flat layout kept for backward compatibility)
// If neither probe finds the library, a warning is written to std::cerr and
// the subdir from step 1 is returned anyway, leaving the failure to be
// reported by LoadOrtDll.
std::string EPLoader::ResolveEPDir(const std::string& shared_dir,
                                   EngineType type)
{
    // Preferred layout: one subdirectory per execution provider.
    const std::string ep_root       = DynLib::JoinPath(shared_dir, "ep");
    const std::string ep_subdir     = DynLib::JoinPath(ep_root, SubdirName(type));
    const std::string subdir_target = DynLib::JoinPath(ep_subdir, OrtDllName());
    const bool in_subdir = DynLib::FileExists(subdir_target);
    if (in_subdir) {
        std::cout << "[EPLoader] EP subdir found: " << ep_subdir << std::endl;
        return ep_subdir;
    }

    // Legacy layout: the library sits directly in shared_dir.
    const std::string flat_target = DynLib::JoinPath(shared_dir, OrtDllName());
    const bool in_shared_dir = DynLib::FileExists(flat_target);
    if (in_shared_dir) {
        std::cout << "[EPLoader] EP subdir not found — "
                     "using flat Shared/ (backward compat): "
                  << shared_dir << std::endl;
        return shared_dir;
    }

    // Nothing found: report both locations that were probed.
    std::cerr << "[EPLoader] WARNING: " << OrtDllName() << " not found in:\n"
              << " " << ep_subdir << "\n"
              << " " << shared_dir << "\n"
              << " LoadOrtDll will fail with a clear error."
              << std::endl;
    return ep_subdir;
}
// Determines the engine type for this machine by delegating to
// ANSLicenseHelper::CheckHardwareInformation(), logging the start of the
// probe and its result to std::cout.
EngineType EPLoader::AutoDetect()
{
    std::cout << "[EPLoader] Auto-detecting hardware..." << std::endl;

    ANSLicenseHelper hardware_probe;
    EngineType engine = hardware_probe.CheckHardwareInformation();

    std::cout << "[EPLoader] Detected: " << EngineTypeName(engine) << std::endl;
    return engine;
}
/*static*/
// Loads the ONNX Runtime library for the requested execution provider and
// records the result in s_info. Only the first successful call does any work;
// later calls return the stored EPInfo and ignore their arguments.
//
// Parameters:
//   shared_dir - directory that holds either the "ep/<name>/" subdirectories
//                or, for the legacy flat layout, the ORT library itself.
//   preferred  - engine type to load; EngineType::AUTO_DETECT selects one via
//                AutoDetect().
// Returns: a reference to the static EPInfo describing the active EP.
// Throws:  whatever LoadOrtDll throws (std::runtime_error). In that case
//          s_initialized stays false, so a later call may try again.
const EPInfo& EPLoader::Initialize(const std::string& shared_dir,
                                   EngineType preferred)
{
    // s_initialized is a plain bool that is written under s_mutex, so it must
    // also be read under s_mutex: an unlocked "fast path" read would be a data
    // race with a concurrent Initialize()/Shutdown(). The lock is taken first
    // and the flag is checked exactly once.
    std::lock_guard<std::mutex> lock(s_mutex);
    if (s_initialized) return s_info;

    std::cout << "[EPLoader] Initializing..." << std::endl;
    std::cout << "[EPLoader] Shared dir : " << shared_dir << std::endl;
    std::cout << "[EPLoader] Preferred EP : " << EngineTypeName(preferred) << std::endl;

    const EngineType type = (preferred == EngineType::AUTO_DETECT)
        ? AutoDetect() : preferred;
    const std::string ep_dir = ResolveEPDir(shared_dir, type);

    // When the EP lives in a subdirectory (e.g. ep/openvino/), provider
    // DLLs may depend on runtime libraries that live in the parent
    // shared_dir (e.g. openvino.dll). Inject shared_dir into the DLL
    // search path so Windows can resolve those dependencies.
    const bool from_subdir = (ep_dir != shared_dir);
    if (from_subdir)
        DynLib::InjectDllSearchPath(shared_dir);

    LoadOrtDll(ep_dir);

    // Publish the result only after the library has loaded successfully.
    s_info.type = type;
    s_info.libraryDir = ep_dir;
    s_info.fromSubdir = from_subdir;
    s_initialized = true;

    std::cout << "[EPLoader] Ready. EP=" << EngineTypeName(type)
              << " dir=" << ep_dir << std::endl;
    return s_info;
}
/*static*/
// Returns the EPInfo of the active execution provider. If the loader has not
// been initialized yet, it is initialized first with DEFAULT_SHARED_DIR and
// EngineType::AUTO_DETECT; any exception from Initialize() propagates.
const EPInfo& EPLoader::Current()
{
    if (s_initialized) {
        return s_info;
    }
    return Initialize(DEFAULT_SHARED_DIR, EngineType::AUTO_DETECT);
}
/*static*/
// Reports whether Initialize() has completed successfully and Shutdown() has
// not been called since. The flag is read without taking s_mutex.
bool EPLoader::IsInitialized()
{
    const bool ready = s_initialized;
    return ready;
}
void* EPLoader::GetOrtApiRaw()
{
Current();
const OrtApi* api = GetOrtApi();
if (!api)
throw std::runtime_error(
"[EPLoader] GetOrtApiRaw: OrtApi not available.");
return static_cast<void*>(const_cast<OrtApi*>(api));
}
#ifdef _WIN32
// ── MakeTempDir ──────────────────────────────────────────────────────────
static std::string MakeTempDir()
{
char tmp[MAX_PATH] = {};
GetTempPathA(MAX_PATH, tmp);
std::string dir = std::string(tmp)
+ "anscenter_ort_"
+ std::to_string(GetCurrentProcessId());
CreateDirectoryA(dir.c_str(), nullptr);
return dir;
}
// ── CopyDirToTemp ─────────────────────────────────────────────────────────
static void CopyDirToTemp(const std::string& ep_dir,
const std::string& temp_dir)
{
std::string pattern = ep_dir + "\\*.dll";
WIN32_FIND_DATAA fd{};
HANDLE hFind = FindFirstFileA(pattern.c_str(), &fd);
if (hFind == INVALID_HANDLE_VALUE) {
std::cerr << "[EPLoader] WARNING: No DLLs found in ep_dir: "
<< ep_dir << std::endl;
return;
}
int copied = 0, skipped = 0;
do {
std::string src = ep_dir + "\\" + fd.cFileName;
std::string dst = temp_dir + "\\" + fd.cFileName;
if (CopyFileA(src.c_str(), dst.c_str(), /*bFailIfExists=*/FALSE)) {
++copied;
std::cout << "[EPLoader] copied: " << fd.cFileName << std::endl;
}
else {
DWORD err = GetLastError();
if (err != ERROR_FILE_EXISTS && err != ERROR_ALREADY_EXISTS) {
std::cerr << "[EPLoader] WARNING: could not copy "
<< fd.cFileName
<< " (err=" << err << ") — skipping." << std::endl;
}
++skipped;
}
} while (FindNextFileA(hFind, &fd));
FindClose(hFind);
std::cout << "[EPLoader] Staged " << copied << " DLL(s) to temp dir"
<< (skipped ? " (" + std::to_string(skipped) + " already present)" : "")
<< std::endl;
}
// ── DeleteTempDir ─────────────────────────────────────────────────────────
static void DeleteTempDir(const std::string& dir)
{
if (dir.empty()) return;
std::string pattern = dir + "\\*";
WIN32_FIND_DATAA fd{};
HANDLE hFind = FindFirstFileA(pattern.c_str(), &fd);
if (hFind != INVALID_HANDLE_VALUE) {
do {
if (strcmp(fd.cFileName, ".") == 0 ||
strcmp(fd.cFileName, "..") == 0) continue;
std::string f = dir + "\\" + fd.cFileName;
DeleteFileA(f.c_str());
} while (FindNextFileA(hFind, &fd));
FindClose(hFind);
}
RemoveDirectoryA(dir.c_str());
std::cout << "[EPLoader] Temp staging dir deleted: " << dir << std::endl;
}
#endif // _WIN32
// ── LoadOrtDll ───────────────────────────────────────────────────────────
// Loads the ONNX Runtime library found in `ep_dir` and binds its API table to
// s_ort_api and Ort::Global<void>::api_. Does nothing when a library is
// already loaded.
//
// Windows: every DLL in ep_dir is first staged into a per-process temp
// directory, the ORT DLL is copied there under a process-unique alias name,
// and that alias is what gets loaded.
// Other platforms: the library is opened in place with dlopen.
//
// Throws std::runtime_error when the alias cannot be created, the library
// cannot be loaded, OrtGetApiBase is missing or returns null, or no API table
// is available for the negotiated version. On every failure path the function
// leaves no partial state behind: the library is unloaded again, the staging
// directory is removed and s_ort_module stays nullptr, so a later call starts
// from scratch instead of reporting "already loaded".
void EPLoader::LoadOrtDll(const std::string& ep_dir)
{
    if (s_ort_module) {
        std::cout << "[EPLoader] ORT DLL already loaded — skipping." << std::endl;
        return;
    }

    // Inject ep_dir into the DLL search path so cudart64_*.dll and other
    // CUDA runtime DLLs that are NOT copied to temp are still found.
    DynLib::InjectDllSearchPath(ep_dir);
    const std::string src_path = DynLib::JoinPath(ep_dir, OrtDllName());
    std::cout << "[EPLoader] ORT source : " << src_path << std::endl;

#ifdef _WIN32
    // ── Windows: stage ALL ep_dir DLLs into a process-unique temp folder ────
    const std::string temp_dir = MakeTempDir();
    std::cout << "[EPLoader] Staging dir : " << temp_dir << std::endl;
    CopyDirToTemp(ep_dir, temp_dir);

    const std::string ort_alias = temp_dir + "\\anscenter_ort_"
        + std::to_string(GetCurrentProcessId()) + ".dll";
    const std::string ort_in_temp = temp_dir + "\\" + OrtDllName();
    if (!CopyFileA(ort_in_temp.c_str(), ort_alias.c_str(), /*bFailIfExists=*/FALSE)) {
        const DWORD err = GetLastError();
        if (err != ERROR_FILE_EXISTS && err != ERROR_ALREADY_EXISTS) {
            // Do not leave the staged copies behind when loading is aborted.
            DeleteTempDir(temp_dir);
            throw std::runtime_error(
                "[EPLoader] Failed to create ORT alias in temp dir.\n"
                " src : " + ort_in_temp + "\n"
                " dst : " + ort_alias + "\n"
                " err : " + std::to_string(err));
        }
    }
    std::cout << "[EPLoader] ORT alias : " << ort_alias << std::endl;
    DynLib::InjectDllSearchPath(temp_dir);

    // Diagnostic only: warn when a system-wide onnxruntime.dll is already
    // resident in this process.
    HMODULE hExisting = GetModuleHandleA("onnxruntime.dll");
    if (hExisting) {
        char existingPath[MAX_PATH] = {};
        GetModuleFileNameA(hExisting, existingPath, MAX_PATH);
        std::string existingStr(existingPath);
        // tolower must not receive a negative char (possible for non-ASCII
        // path bytes), so go through unsigned char.
        std::transform(existingStr.begin(), existingStr.end(),
                       existingStr.begin(),
                       [](unsigned char c) { return static_cast<char>(::tolower(c)); });
        if (existingStr.find("system32") != std::string::npos ||
            existingStr.find("syswow64") != std::string::npos) {
            std::cerr << "[EPLoader] WARNING: System ORT is resident ("
                      << existingPath << ") - our alias overrides it." << std::endl;
        }
    }

    // Widen the alias path for LoadLibraryExW. The path was produced by the
    // ANSI ("A") Win32 functions above, so it is encoded in the process ANSI
    // code page; it must be converted with CP_ACP. Treating it as UTF-8 would
    // corrupt any non-ASCII character (e.g. in the user's temp path).
    const int wlen = MultiByteToWideChar(CP_ACP, 0, ort_alias.c_str(), -1, nullptr, 0);
    std::wstring walias(static_cast<size_t>(wlen > 0 ? wlen - 1 : 0), L'\0');
    if (wlen > 0)
        MultiByteToWideChar(CP_ACP, 0, ort_alias.c_str(), -1, walias.data(), wlen);

    s_ort_module = LoadLibraryExW(
        walias.c_str(),
        nullptr,
        LOAD_WITH_ALTERED_SEARCH_PATH);
    if (!s_ort_module) {
        const DWORD err = GetLastError();
        DeleteTempDir(temp_dir);
        throw std::runtime_error(
            "[EPLoader] LoadLibraryExW failed: " + ort_alias +
            "\n Windows error code: " + std::to_string(err) +
            "\n Ensure all CUDA runtime DLLs (cudart64_*.dll etc.)"
            "\n exist in: " + ep_dir);
    }
    s_temp_ort_path = ort_alias;
    s_temp_dir = temp_dir;
#else
    // ── Linux / macOS ─────────────────────────────────────────────────────
    s_ort_module = dlopen(src_path.c_str(), RTLD_NOW | RTLD_GLOBAL);
    if (!s_ort_module) {
        // dlerror() may return null; never feed that to std::string.
        const char* reason = dlerror();
        throw std::runtime_error(
            std::string("[EPLoader] dlopen failed: ") + src_path +
            "\n " + (reason ? reason : "unknown error"));
    }
#endif

    // ── Bootstrap ORT C++ API ─────────────────────────────────────────────
    // From here on the library is loaded. Any failure must undo that, or the
    // next call would see s_ort_module set and skip loading with no API bound.
    const OrtApiBase* base = nullptr;
    int targetApi = 0;
    try {
        using OrtGetApiBase_fn = const OrtApiBase* (ORT_API_CALL*)();
#ifdef _WIN32
        auto fn = reinterpret_cast<OrtGetApiBase_fn>(
            GetProcAddress(s_ort_module, "OrtGetApiBase"));
#else
        auto fn = reinterpret_cast<OrtGetApiBase_fn>(
            dlsym(s_ort_module, "OrtGetApiBase"));
#endif
        if (!fn)
            throw std::runtime_error(
                "[EPLoader] OrtGetApiBase not exported — is this a genuine onnxruntime build?\n"
                " path: " + src_path);

        base = fn();
        if (!base)
            throw std::runtime_error(
                "[EPLoader] OrtGetApiBase() returned null from: " + src_path);

        // ── Version negotiation ───────────────────────────────────────────
        // The minor component of the "major.minor" version string is taken as
        // the highest API level the library offers; request the lower of that
        // and the level of the headers this file was built with.
        int dllMaxApi = ORT_API_VERSION;
        {
            const char* verStr = base->GetVersionString();
            int major = 0, minor = 0;
            if (verStr && sscanf(verStr, "%d.%d", &major, &minor) == 2)
                dllMaxApi = minor;
        }
        targetApi = std::min(ORT_API_VERSION, dllMaxApi);
        if (targetApi < ORT_API_VERSION) {
            std::cerr << "[EPLoader] WARNING: ORT DLL version "
                      << base->GetVersionString()
                      << " supports up to API " << dllMaxApi
                      << " but headers expect API " << ORT_API_VERSION << ".\n"
                      << " Using API " << targetApi
                      << ". Consider upgrading onnxruntime.dll in ep/ to match SDK headers."
                      << std::endl;
        }

        const OrtApi* api = base->GetApi(targetApi);
        if (!api)
            throw std::runtime_error(
                "[EPLoader] GetApi(" + std::to_string(targetApi) +
                ") returned null — the DLL may be corrupt.");

        s_ort_api = api;
        Ort::Global<void>::api_ = api;
    }
    catch (...) {
        // Roll back to the "not loaded" state, then let the caller see the
        // original exception.
#ifdef _WIN32
        FreeLibrary(s_ort_module);
        s_ort_module = nullptr;
        DeleteTempDir(s_temp_dir);
        s_temp_dir.clear();
        s_temp_ort_path.clear();
#else
        dlclose(s_ort_module);
        s_ort_module = nullptr;
#endif
        throw;
    }

    std::cout << "[EPLoader] ORT loaded successfully." << std::endl;
    std::cout << "[EPLoader] ORT DLL version : " << base->GetVersionString() << std::endl;
    std::cout << "[EPLoader] ORT header API : " << ORT_API_VERSION << std::endl;
    std::cout << "[EPLoader] ORT active API : " << targetApi << std::endl;
}
// ── Shutdown ──────────────────────────────────────────────────────────────
// Unloads the ORT library (if one is loaded) and returns the loader to its
// uninitialized state, so that Initialize() can run again. On Windows the
// temp staging directory is removed as well. Runs under s_mutex.
// NOTE(review): Ort::Global<void>::api_ is not reset here, so it keeps
// pointing at the table of the released library — confirm that no ORT C++
// object is used after Shutdown().
void EPLoader::Shutdown()
{
    const std::lock_guard<std::mutex> guard(s_mutex);

    if (s_ort_module != nullptr) {
        std::cout << "[EPLoader] Unloading ORT DLL..." << std::endl;
#ifdef _WIN32
        FreeLibrary(s_ort_module);
        // DeleteTempDir ignores an empty path, so no separate emptiness check
        // is needed before calling it.
        DeleteTempDir(s_temp_dir);
        s_temp_dir.clear();
        s_temp_ort_path.clear();
#else
        dlclose(s_ort_module);
#endif
        s_ort_module = nullptr;
    }

    // Reset the bookkeeping whether or not a library was loaded.
    s_ort_api = nullptr;
    s_initialized = false;
    s_info = EPInfo{};

    std::cout << "[EPLoader] Shutdown complete." << std::endl;
}
} // namespace ANSCENTER