diff --git a/include/mitsuba/core/thread.h b/include/mitsuba/core/thread.h index 5e46d273..f1da81fb 100644 --- a/include/mitsuba/core/thread.h +++ b/include/mitsuba/core/thread.h @@ -147,6 +147,9 @@ public: /// Shut down the threading system static void staticShutdown(); + /// Initialize Mitsuba's threading system for simultaneous use of OpenMP + static void initializeOpenMP(); + MTS_DECLARE_CLASS() protected: /// Virtual destructor diff --git a/src/libcore/thread.cpp b/src/libcore/thread.cpp index 7a952167..e2f4caac 100644 --- a/src/libcore/thread.cpp +++ b/src/libcore/thread.cpp @@ -39,6 +39,21 @@ protected: virtual ~MainThread() { } }; +class OpenMPThread : public Thread { +public: + OpenMPThread() : Thread("main") { + } + + virtual void run() { + Log(EError, "The OpenMP thread is already running!"); + } + + MTS_DECLARE_CLASS() +protected: + virtual ~OpenMPThread() { } +}; + + ThreadLocal *Thread::m_self = NULL; #if defined(__LINUX__) || defined(__OSX__) @@ -278,6 +293,26 @@ void Thread::staticShutdown() { #endif } +void Thread::initializeOpenMP() { + ref logger = Thread::getThread()->getLogger(); + ref fResolver = Thread::getThread()->getFileResolver(); + + #pragma omp parallel + { + Thread *thread = Thread::getThread(); + if (!thread) { + thread = new OpenMPThread(); + thread->m_running = true; + thread->m_thread = pthread_self(); + thread->m_joinMutex = new Mutex(); + thread->m_joined = false; + thread->m_fresolver = fResolver; + thread->m_logger = logger; + m_self->set(thread); + } + } +} + Thread::~Thread() { if (m_running) Log(EWarn, "Destructor called while Thread '%s' is still running", m_name.c_str()); @@ -285,4 +320,5 @@ Thread::~Thread() { MTS_IMPLEMENT_CLASS(Thread, true, Object) MTS_IMPLEMENT_CLASS(MainThread, false, Thread) +MTS_IMPLEMENT_CLASS(OpenMPThread, false, Thread) MTS_NAMESPACE_END diff --git a/src/librender/gatherproc.cpp b/src/librender/gatherproc.cpp index 19a621bc..1419efaa 100644 --- a/src/librender/gatherproc.cpp +++ b/src/librender/gatherproc.cpp @@ -124,8 +124,8 @@ public: m_workResult->put(PhotonMap::Photon(its.p, its.geoFrame.n, -its.toWorld(its.wi), weight, depth)); } - void handleMediumInteraction(int depth, bool caustic, const MediumSamplingRecord &mRec, const Vector &wi, - const Spectrum &weight) { + void handleMediumInteraction(int depth, bool caustic, const MediumSamplingRecord &mRec, Float time, + const Vector &wi, const Spectrum &weight) { if (m_type == GatherPhotonProcess::EVolumePhotons && depth > 1) m_workResult->put(PhotonMap::Photon(mRec.p, Normal(), -wi, weight, depth)); }